Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0292c90ca1 | |||
| 617c7dc6b9 | |||
| 542b79dacf | |||
| 0a25b3bd31 | |||
| 32bc781326 | |||
| 316ad63f76 | |||
| e37077a963 | |||
| eff9b60434 | |||
| 9bcddad91b |
@@ -0,0 +1,76 @@
|
|||||||
|
name: Build CUDA image (fork)

# Builds this fork's llama-swap (serial scheduler + embedded UI) from source and
# layers it on a pinned llama.cpp CUDA server base, then pushes to the Gitea
# container registry, e.g. gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
#
# Requires repo secrets: REGISTRY_USER, REGISTRY_PASSWORD (push to the registry).

on:
  workflow_dispatch:
    inputs:
      llama_swap_version:
        description: "llama-swap version label (image tag prefix)"
        required: false
        default: "v230"
      llamacpp_build:
        description: "llama.cpp CUDA server build (base image tag suffix)"
        required: false
        default: "b9821"
  # A push to main that changes the build definition itself (this workflow or
  # the Containerfile) kicks off a fresh image build.
  push:
    branches: [main]
    paths:
      - ".gitea/workflows/build-cuda-image.yml"
      - "docker/fork-cuda.Containerfile"

env:
  REGISTRY: gitea.stevedudenhoeffer.com

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      # Derives the image name, tag and build args once so later steps share them.
      - name: Compute image metadata
        id: meta
        # Expression values are handed to the script as environment variables
        # instead of being pasted into the script text, so a crafted input
        # cannot be interpreted as shell code. The `||` fallbacks supply the
        # defaults on push events, where no dispatch inputs exist.
        env:
          LS_VER: ${{ inputs.llama_swap_version || 'v230' }}
          LCPP: ${{ inputs.llamacpp_build || 'b9821' }}
          IMAGE_REPO: ${{ github.repository }}
        run: |
          {
            echo "image=${REGISTRY}/${IMAGE_REPO}"
            echo "tag=${LS_VER}-cuda-${LCPP}"
            echo "base_tag=server-cuda-${LCPP}"
            echo "ls_version=${LS_VER}"
            echo "build_date=$(date -u +%Y-%m-%dT%H:%M:%SZ)"
          } >> "$GITHUB_OUTPUT"

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to Gitea registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ secrets.REGISTRY_USER }}
          password: ${{ secrets.REGISTRY_PASSWORD }}

      - name: Build and push
        uses: docker/build-push-action@v6
        with:
          context: .
          file: docker/fork-cuda.Containerfile
          push: true
          provenance: false
          build-args: |
            BASE_TAG=${{ steps.meta.outputs.base_tag }}
            LS_VERSION=${{ steps.meta.outputs.ls_version }}
            GIT_HASH=${{ github.sha }}
            BUILD_DATE=${{ steps.meta.outputs.build_date }}
          tags: ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}

      - name: Summary
        # Same rule as above: the pushed reference derives from user input, so
        # it reaches the shell through the environment.
        env:
          IMAGE_REF: ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}
        run: |
          echo "Pushed ${IMAGE_REF}" >> "$GITHUB_STEP_SUMMARY"
|
||||||
@@ -5,16 +5,14 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
|||||||
## Tech stack
|
## Tech stack
|
||||||
|
|
||||||
- golang
|
- golang
|
||||||
- typescript, vite and svelt5 for UI (located in ui/)
|
- typescript, vite and svelte 5 for UI (located in ui-svelte/)
|
||||||
|
|
||||||
## Workflow Tasks
|
## Workflow Tasks
|
||||||
|
|
||||||
- when summarizing changes only include details that require further action
|
- when summarizing changes only include details that require further action
|
||||||
- just say "Done." when there is no further action
|
|
||||||
- use the github CLI `gh` to create pull requests and work with github
|
|
||||||
- Rules for creating pull requests:
|
- Rules for creating pull requests:
|
||||||
- keep them short and focused on changes.
|
- keep them short and focused on changes
|
||||||
- never include a test plan
|
- skip the test plan
|
||||||
- write the summary using the same style rules as commit message
|
- write the summary using the same style rules as commit message
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
@@ -30,7 +28,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
|||||||
### Commit message example format:
|
### Commit message example format:
|
||||||
|
|
||||||
```
|
```
|
||||||
proxy: add new feature
|
internal/server: add new feature
|
||||||
|
|
||||||
Add new feature that implements functionality X and Y.
|
Add new feature that implements functionality X and Y.
|
||||||
|
|
||||||
|
|||||||
+21
-2
@@ -572,6 +572,24 @@
|
|||||||
"default": {},
|
"default": {},
|
||||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||||
},
|
},
|
||||||
|
"upstream": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Controls behaviour of the /upstream passthrough endpoint. Recommended to only use in special use cases; leaving it as the default will typically be the best experience.",
|
||||||
|
"properties": {
|
||||||
|
"ignorePaths": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"default": [
|
||||||
|
".*\\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$"
|
||||||
|
],
|
||||||
|
"description": "List of RE2 compatible regular expressions. Any request to a path matching any of the regular expressions will be ignored and not trigger a swap. When not specified, defaults to a pattern matching common static-asset suffixes (.js, .json, .css, .png, .gif, .jpg, .jpeg, .ico, .txt)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"default": {}
|
||||||
|
},
|
||||||
"routing": {
|
"routing": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
|
"description": "Canonical routing/scheduling configuration. Alternative to the legacy top-level 'groups'/'matrix' keys; a config must not use both styles.",
|
||||||
@@ -583,10 +601,11 @@
|
|||||||
"use": {
|
"use": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": [
|
"enum": [
|
||||||
|
"serial",
|
||||||
"fifo"
|
"fifo"
|
||||||
],
|
],
|
||||||
"default": "fifo",
|
"default": "serial",
|
||||||
"description": "Scheduler to use. Only 'fifo' is currently supported."
|
"description": "Scheduler to use. 'serial' (default on this fork): strict one-model-at-a-time, requests run in exact arrival order, switching models evicts every other model first. 'fifo': throughput-oriented, batches same-model requests and allows parallel/co-resident models."
|
||||||
},
|
},
|
||||||
"settings": {
|
"settings": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|||||||
+25
-3
@@ -134,6 +134,18 @@ apiKeys:
|
|||||||
- "${env.API_KEY_1}"
|
- "${env.API_KEY_1}"
|
||||||
- "${env.API_KEY_2}"
|
- "${env.API_KEY_2}"
|
||||||
|
|
||||||
|
# upstream: controls behaviour of the /upstream passthrough endpoint
|
||||||
|
# - optional, default: empty dictionary
|
||||||
|
# - recommended to only use in special use cases. Leaving it as the
|
||||||
|
# default will typically be the best experience
|
||||||
|
upstream:
|
||||||
|
# ignorePaths: list of RE2 compatible regular expressions
|
||||||
|
# - default: (see below)
|
||||||
|
# - any request to a path matching any of the regular expressions
|
||||||
|
# will be ignored and not trigger a swap
|
||||||
|
ignorePaths:
|
||||||
|
- '.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$'
|
||||||
|
|
||||||
# models: a dictionary of model configurations
|
# models: a dictionary of model configurations
|
||||||
# - required
|
# - required
|
||||||
# - each key is the model's ID, used in API requests
|
# - each key is the model's ID, used in API requests
|
||||||
@@ -544,11 +556,21 @@ routing:
|
|||||||
# expands to: [L]
|
# expands to: [L]
|
||||||
full: "L"
|
full: "L"
|
||||||
|
|
||||||
# scheduler: how queued requests are ordered.
|
# scheduler: how queued requests are ordered and run.
|
||||||
# The default and only valid scheduler is "fifo"
|
# - optional, default on this fork: "serial"
|
||||||
|
# - valid values:
|
||||||
|
# - "serial": strict one-model-at-a-time. Requests run in exact arrival
|
||||||
|
# order; only one request runs at a time; switching to a different model
|
||||||
|
# evicts every other running model first so a single model occupies memory
|
||||||
|
# at a time. This ignores group/matrix co-residency entirely. The "fifo"
|
||||||
|
# settings below (priority) do not apply.
|
||||||
|
# - "fifo": throughput-oriented. Same-model requests are batched to reduce
|
||||||
|
# swaps and a model serves up to its concurrencyLimit in parallel; models
|
||||||
|
# in non-exclusive groups can run concurrently. Requests may be reordered.
|
||||||
scheduler:
|
scheduler:
|
||||||
use: fifo
|
use: serial
|
||||||
settings:
|
settings:
|
||||||
|
# fifo settings only apply when use: fifo
|
||||||
fifo:
|
fifo:
|
||||||
# priority: a dictionary of model ID -> priority
|
# priority: a dictionary of model ID -> priority
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
# CUDA llama-swap image compiled FROM THIS FORK's source (which carries the
# serial scheduler) and layered on a pinned llama.cpp CUDA server base. Result:
#   gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
#
# BASE_TAG picks the llama.cpp CUDA runtime + llama-server build, for example
# "server-cuda-b9821". The llama-swap binary (Svelte UI embedded) is built from
# the repo during the image build, so no GitHub release is needed.
#
# Run from the repo root, which is the build context:
#   docker build -f docker/fork-cuda.Containerfile \
#     --build-arg BASE_TAG=server-cuda-b9821 -t llama-swap:v230-cuda-b9821 .

ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
ARG BASE_TAG=server-cuda-b9821

# ---- Stage 1: Svelte UI bundle (later embedded into the Go binary) ----
FROM node:22-bookworm-slim AS ui
WORKDIR /src/ui-svelte
# Dependency manifests go in first so the install layer is cached across source
# edits. .npmrc must be present before `npm ci`: it sets legacy-peer-deps=true,
# which the project needs (tailwind/vite peer ranges); without it the strict
# resolver aborts with ERESOLVE.
COPY ui-svelte/package.json ui-svelte/package-lock.json ui-svelte/.npmrc ./
RUN npm ci
COPY ui-svelte/ ./
# `npm run build` runs `vite build --emptyOutDir`, and vite.config.ts emits to
# ../internal/server/ui_dist -- the directory //go:embed reads in stage 2.
RUN mkdir -p /src/internal/server \
    && npm run build

# ---- Stage 2: llama-swap binary with the UI embedded ----
FROM golang:1.26-bookworm AS build
WORKDIR /src
# Module download is its own layer so source churn does not refetch modules.
COPY go.mod go.sum ./
RUN go mod download
COPY . .
# Replace the committed placeholder with the UI built in stage 1 so that
# //go:embed ui_dist ships the real assets.
COPY --from=ui /src/internal/server/ui_dist/ ./internal/server/ui_dist/
ARG LS_VERSION=v230
ARG GIT_HASH=unknown
ARG BUILD_DATE=unknown
RUN CGO_ENABLED=0 GOOS=linux go build \
    -ldflags="-X main.version=${LS_VERSION} -X main.commit=${GIT_HASH} -X main.date=${BUILD_DATE}" \
    -o /out/llama-swap .

# ---- Stage 3: runtime image on top of the pinned llama.cpp CUDA base ----
FROM ${BASE_IMAGE}:${BASE_TAG}

# Root is the default user, matching the upstream `vNNN-cuda-bNNNN`
# (non-suffixed) image that ragnaros pulls today: root is needed to reach the
# mounted docker socket for container-backed models (sd-server). Pass UID/GID
# build args to produce a non-root variant.
ARG UID=0
ARG GID=0
ARG USER_HOME=/root
ENV HOME=$USER_HOME

# Create the `app` user (and group, when GID is non-zero) only for non-root
# builds, then make sure the home and /app directories exist with that owner.
RUN set -eux; \
    if [ "$UID" -ne 0 ]; then \
        if [ "$GID" -ne 0 ]; then \
            groupadd --system --gid "$GID" app; \
        fi; \
        useradd --system --uid "$UID" --gid "$GID" --home "$USER_HOME" app; \
    fi; \
    mkdir -p "$HOME" /app; \
    chown -R "$UID:$GID" "$HOME" /app

COPY --from=build --chown=$UID:$GID /out/llama-swap /app/llama-swap
COPY --chown=$UID:$GID docker/config.example.yaml /app/config.yaml

USER $UID:$GID
WORKDIR /app
ENV PATH="/app:${PATH}"

HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/billziss-gh/golib/shlex"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Handle trailing backslashes by replacing with space
|
||||||
|
if strings.HasSuffix(trimmed, "\\") {
|
||||||
|
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||||
|
} else {
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// put it back together
|
||||||
|
cmdStr = strings.Join(cleanedLines, "\n")
|
||||||
|
|
||||||
|
// Split the command into arguments
|
||||||
|
var args []string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
args = shlex.Windows.Split(cmdStr)
|
||||||
|
} else {
|
||||||
|
args = shlex.Posix.Split(cmdStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the command is not empty
|
||||||
|
if len(args) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty command")
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StripComments removes every line of cmdStr whose first non-blank character
// is '#'. All remaining lines are returned unchanged (leading and trailing
// whitespace included), joined back together with newlines.
func StripComments(cmdStr string) string {
	lines := strings.Split(cmdStr, "\n")

	// Filter in place: lines is a fresh slice owned by this function, and the
	// write index never overtakes the read index.
	kept := lines[:0]
	for _, raw := range lines {
		if isComment := strings.HasPrefix(strings.TrimSpace(raw), "#"); !isComment {
			kept = append(kept, raw)
		}
	}

	return strings.Join(kept, "\n")
}
|
||||||
+3
-661
@@ -2,16 +2,9 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
|
||||||
"runtime"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/billziss-gh/golib/shlex"
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -85,12 +78,6 @@ type GroupConfig struct {
|
|||||||
Members []string `yaml:"members"`
|
Members []string `yaml:"members"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
|
||||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
|
||||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
|
||||||
)
|
|
||||||
|
|
||||||
// set default values for GroupConfig
|
// set default values for GroupConfig
|
||||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
type rawGroupConfig GroupConfig
|
type rawGroupConfig GroupConfig
|
||||||
@@ -163,6 +150,9 @@ type Config struct {
|
|||||||
|
|
||||||
// support remote peers, see issue #433, #296
|
// support remote peers, see issue #433, #296
|
||||||
Peers PeerDictionaryConfig `yaml:"peers"`
|
Peers PeerDictionaryConfig `yaml:"peers"`
|
||||||
|
|
||||||
|
// upstream controls behaviour of the /upstream passthrough endpoint
|
||||||
|
Upstream UpstreamConfig `yaml:"upstream"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
|
// RoutingConfig is the canonical, normalized routing/scheduling configuration.
|
||||||
@@ -221,424 +211,6 @@ func LoadConfig(path string) (Config, error) {
|
|||||||
return LoadConfigFromReader(file)
|
return LoadConfigFromReader(file)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
yamlStr := string(data)
|
|
||||||
|
|
||||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
|
||||||
// This is safe because env values are simple strings without YAML formatting
|
|
||||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal into full Config with defaults
|
|
||||||
config := Config{
|
|
||||||
HealthCheckTimeout: 120,
|
|
||||||
StartPort: 5800,
|
|
||||||
LogLevel: "info",
|
|
||||||
LogTimeFormat: "",
|
|
||||||
LogToStdout: LogToStdoutProxy,
|
|
||||||
MetricsMaxInMemory: 1000,
|
|
||||||
CaptureBuffer: 5,
|
|
||||||
GlobalTTL: 0,
|
|
||||||
}
|
|
||||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.HealthCheckTimeout < 15 {
|
|
||||||
config.HealthCheckTimeout = 15
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply defaults for performance config when section is missing
|
|
||||||
if config.Performance.Every == 0 {
|
|
||||||
config.Performance.Every = 5 * time.Second
|
|
||||||
}
|
|
||||||
if err = config.Performance.Validate(); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("performance: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.StartPort < 1 {
|
|
||||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.GlobalTTL < 0 {
|
|
||||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch config.LogToStdout {
|
|
||||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
|
||||||
default:
|
|
||||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate the aliases map
|
|
||||||
config.aliases = make(map[string]string)
|
|
||||||
for modelName, modelConfig := range config.Models {
|
|
||||||
for _, alias := range modelConfig.Aliases {
|
|
||||||
if _, found := config.aliases[alias]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
|
||||||
}
|
|
||||||
config.aliases[alias] = modelName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate global macros
|
|
||||||
for _, macro := range config.Macros {
|
|
||||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get and sort all model IDs for consistent port assignment
|
|
||||||
modelIds := make([]string, 0, len(config.Models))
|
|
||||||
for modelId := range config.Models {
|
|
||||||
modelIds = append(modelIds, modelId)
|
|
||||||
}
|
|
||||||
sort.Strings(modelIds)
|
|
||||||
|
|
||||||
nextPort := config.StartPort
|
|
||||||
for _, modelId := range modelIds {
|
|
||||||
modelConfig := config.Models[modelId]
|
|
||||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
|
||||||
|
|
||||||
// Strip comments from command fields
|
|
||||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
|
||||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
|
||||||
|
|
||||||
// set model TTL to globalTTL if it is the default value
|
|
||||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
|
||||||
modelConfig.UnloadAfter = config.GlobalTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelConfig.UnloadAfter < 0 {
|
|
||||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate model macros
|
|
||||||
for _, macro := range modelConfig.Macros {
|
|
||||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
|
||||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
|
||||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
|
||||||
mergedMacros = append(mergedMacros, config.Macros...)
|
|
||||||
|
|
||||||
// Add model macros (override globals with same name)
|
|
||||||
for _, entry := range modelConfig.Macros {
|
|
||||||
found := false
|
|
||||||
for i, existing := range mergedMacros {
|
|
||||||
if existing.Name == entry.Name {
|
|
||||||
mergedMacros[i] = entry
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
mergedMacros = append(mergedMacros, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Substitute remaining macros in model fields (LIFO order)
|
|
||||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
|
||||||
entry := mergedMacros[i]
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
|
||||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
|
||||||
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
|
||||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
|
||||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
|
||||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
|
||||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
|
||||||
|
|
||||||
// Substitute macros in SetParamsByID keys and values
|
|
||||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
|
||||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
|
||||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
|
||||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
|
||||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
newParamMap, ok := newValAny.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
|
||||||
}
|
|
||||||
newSetParamsByID[newKey] = newParamMap
|
|
||||||
}
|
|
||||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Substitute in metadata (type-preserving)
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
modelConfig.Metadata = result.(map[string]any)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle PORT macro - only allocate if cmd uses it
|
|
||||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
|
||||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
|
||||||
if cmdHasPort || proxyHasPort {
|
|
||||||
if !cmdHasPort && proxyHasPort {
|
|
||||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
|
||||||
}
|
|
||||||
|
|
||||||
macroSlug := "${PORT}"
|
|
||||||
macroStr := fmt.Sprintf("%v", nextPort)
|
|
||||||
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
|
||||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
|
||||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
|
||||||
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
modelConfig.Metadata = result.(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
nextPort++
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate no unknown macros remain
|
|
||||||
fieldMap := map[string]string{
|
|
||||||
"cmd": modelConfig.Cmd,
|
|
||||||
"cmdStop": modelConfig.CmdStop,
|
|
||||||
"proxy": modelConfig.Proxy,
|
|
||||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
|
||||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
|
||||||
"name": modelConfig.Name,
|
|
||||||
"description": modelConfig.Description,
|
|
||||||
}
|
|
||||||
|
|
||||||
for fieldName, fieldValue := range fieldMap {
|
|
||||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
macroName := match[1]
|
|
||||||
if macroName == "PID" && fieldName == "cmdStop" {
|
|
||||||
continue // replaced at runtime
|
|
||||||
}
|
|
||||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
|
||||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
|
||||||
}
|
|
||||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate SetParamsByID keys and values
|
|
||||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
|
||||||
}
|
|
||||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
|
||||||
for key := range modelConfig.Filters.SetParamsByID {
|
|
||||||
if key == modelId {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, exists := config.Models[key]; exists {
|
|
||||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
|
||||||
}
|
|
||||||
if existingModel, exists := config.aliases[key]; exists {
|
|
||||||
if existingModel != modelId {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
|
||||||
}
|
|
||||||
continue // already registered as explicit alias for this model
|
|
||||||
}
|
|
||||||
config.aliases[key] = modelId
|
|
||||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelConfig.SendLoadingState == nil {
|
|
||||||
v := config.SendLoadingState
|
|
||||||
modelConfig.SendLoadingState = &v
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Models[modelId] = modelConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
|
||||||
// the new `routing.router` block are mutually exclusive: a config may use
|
|
||||||
// either style, never both.
|
|
||||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
|
||||||
rtr := config.Routing.Router
|
|
||||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
|
||||||
|
|
||||||
if hasTopLevel && hasRouting {
|
|
||||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasTopLevel {
|
|
||||||
// Both groups and matrix may be defined under routing.router.settings;
|
|
||||||
// routing.router.use selects which one is active, so there is no conflict.
|
|
||||||
rs := config.Routing.Router.Settings
|
|
||||||
switch config.Routing.Router.Use {
|
|
||||||
case "matrix":
|
|
||||||
if rs.Matrix == nil {
|
|
||||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
|
||||||
}
|
|
||||||
config.Matrix = rs.Matrix
|
|
||||||
case "group", "":
|
|
||||||
config.Groups = rs.Groups
|
|
||||||
default:
|
|
||||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// groups XOR matrix
|
|
||||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Matrix != nil {
|
|
||||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
|
||||||
}
|
|
||||||
config.Matrix.ExpandedSets = expandedSets
|
|
||||||
} else {
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
|
||||||
|
|
||||||
// Validate group members
|
|
||||||
memberUsage := make(map[string]string)
|
|
||||||
for groupID, groupConfig := range config.Groups {
|
|
||||||
prevSet := make(map[string]bool)
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
if _, found := prevSet[member]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
|
||||||
}
|
|
||||||
prevSet[member] = true
|
|
||||||
|
|
||||||
if existingGroup, exists := memberUsage[member]; exists {
|
|
||||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
|
||||||
}
|
|
||||||
memberUsage[member] = groupID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
|
||||||
// and new-style configs converge here. The Matrix pointer is shared so
|
|
||||||
// ExpandedSets stays in one place.
|
|
||||||
if config.Matrix != nil {
|
|
||||||
config.Routing.Router.Use = "matrix"
|
|
||||||
} else {
|
|
||||||
config.Routing.Router.Use = "group"
|
|
||||||
}
|
|
||||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
|
||||||
config.Routing.Router.Settings.Groups = config.Groups
|
|
||||||
|
|
||||||
if config.Routing.Scheduler.Use == "" {
|
|
||||||
config.Routing.Scheduler.Use = "fifo"
|
|
||||||
}
|
|
||||||
if config.Routing.Scheduler.Use != "fifo" {
|
|
||||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
|
|
||||||
}
|
|
||||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
|
||||||
if _, found := config.RealModelName(modelID); !found {
|
|
||||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up hooks preload
|
|
||||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
|
||||||
var toPreload []string
|
|
||||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
|
||||||
modelID = strings.TrimSpace(modelID)
|
|
||||||
if modelID == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if real, found := config.RealModelName(modelID); found {
|
|
||||||
toPreload = append(toPreload, real)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config.Hooks.OnStartup.Preload = toPreload
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate API keys (env macros already substituted at string level)
|
|
||||||
for i, apikey := range config.RequiredAPIKeys {
|
|
||||||
if apikey == "" {
|
|
||||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
|
||||||
}
|
|
||||||
if strings.Contains(apikey, " ") {
|
|
||||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
|
||||||
}
|
|
||||||
config.RequiredAPIKeys[i] = apikey
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process peers with global macro substitution
|
|
||||||
for peerName, peerConfig := range config.Peers {
|
|
||||||
// Substitute global macros (LIFO order)
|
|
||||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
|
||||||
entry := config.Macros[i]
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
|
||||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
|
||||||
|
|
||||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
|
||||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
|
||||||
|
|
||||||
// Substitute in setParams (type-preserving)
|
|
||||||
if len(peerConfig.Filters.SetParams) > 0 {
|
|
||||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
|
||||||
}
|
|
||||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate no unknown macros remain
|
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
|
||||||
}
|
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
|
||||||
}
|
|
||||||
if len(peerConfig.Filters.SetParams) > 0 {
|
|
||||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config.Peers[peerName] = peerConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewrites the yaml to include a default group with any orphaned models
|
// rewrites the yaml to include a default group with any orphaned models
|
||||||
func AddDefaultGroupToConfig(config Config) Config {
|
func AddDefaultGroupToConfig(config Config) Config {
|
||||||
|
|
||||||
@@ -683,233 +255,3 @@ func AddDefaultGroupToConfig(config Config) Config {
|
|||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
|
||||||
var cleanedLines []string
|
|
||||||
for _, line := range strings.Split(cmdStr, "\n") {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
// Skip comment lines
|
|
||||||
if strings.HasPrefix(trimmed, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Handle trailing backslashes by replacing with space
|
|
||||||
if strings.HasSuffix(trimmed, "\\") {
|
|
||||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
|
||||||
} else {
|
|
||||||
cleanedLines = append(cleanedLines, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// put it back together
|
|
||||||
cmdStr = strings.Join(cleanedLines, "\n")
|
|
||||||
|
|
||||||
// Split the command into arguments
|
|
||||||
var args []string
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
args = shlex.Windows.Split(cmdStr)
|
|
||||||
} else {
|
|
||||||
args = shlex.Posix.Split(cmdStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the command is not empty
|
|
||||||
if len(args) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty command")
|
|
||||||
}
|
|
||||||
|
|
||||||
return args, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StripComments removes every line of cmdStr whose first non-whitespace
// character is '#'. All other lines are returned unchanged (including their
// original leading/trailing whitespace), joined with newlines.
func StripComments(cmdStr string) string {
	lines := strings.Split(cmdStr, "\n")
	kept := make([]string, 0, len(lines))
	for _, raw := range lines {
		if isComment := strings.HasPrefix(strings.TrimSpace(raw), "#"); !isComment {
			kept = append(kept, raw)
		}
	}
	return strings.Join(kept, "\n")
}
|
|
||||||
|
|
||||||
// validateMacro validates macro name and value constraints
|
|
||||||
func validateMacro(name string, value any) error {
|
|
||||||
if len(name) >= 64 {
|
|
||||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
|
||||||
}
|
|
||||||
if !macroNameRegex.MatchString(name) {
|
|
||||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate that value is a scalar type
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
// Check for self-reference
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", name)
|
|
||||||
if strings.Contains(v, macroSlug) {
|
|
||||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
|
||||||
}
|
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
|
||||||
// These types are allowed
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch name {
|
|
||||||
case "PORT", "MODEL_ID":
|
|
||||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
|
||||||
func validateNestedForUnknownMacros(value any, context string) error {
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
macroName := match[1]
|
|
||||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
|
||||||
}
|
|
||||||
// Check for unsubstituted env macros
|
|
||||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
|
||||||
for _, match := range envMatches {
|
|
||||||
varName := match[1]
|
|
||||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case map[string]any:
|
|
||||||
for _, val := range v {
|
|
||||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case []any:
|
|
||||||
for _, val := range v {
|
|
||||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Scalar types don't contain macros
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// substituteMacroInValue replaces one macro throughout a nested value and
// returns the rewritten copy.
//
//   - A string that is exactly ${macroName} is replaced by macroValue itself,
//     so the macro's original type is preserved.
//   - Any other string has each ${macroName} occurrence replaced by the "%v"
//     formatting of macroValue.
//   - map[string]any and []any are rebuilt with every element processed
//     recursively.
//   - All other values are returned unchanged.
//
// It handles a single macro per call so the caller can control the order in
// which macros are applied.
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
	placeholder := fmt.Sprintf("${%s}", macroName)
	rendered := fmt.Sprintf("%v", macroValue)

	switch node := value.(type) {
	case string:
		if node == placeholder {
			// Whole-value reference: hand back the typed value, not its text.
			return macroValue, nil
		}
		// ReplaceAll returns the input untouched when there is nothing to replace.
		return strings.ReplaceAll(node, placeholder, rendered), nil

	case map[string]any:
		rebuilt := make(map[string]any)
		for k, child := range node {
			replaced, err := substituteMacroInValue(child, macroName, macroValue)
			if err != nil {
				return nil, err
			}
			rebuilt[k] = replaced
		}
		return rebuilt, nil

	case []any:
		rebuilt := make([]any, len(node))
		for idx := range node {
			replaced, err := substituteMacroInValue(node[idx], macroName, macroValue)
			if err != nil {
				return nil, err
			}
			rebuilt[idx] = replaced
		}
		return rebuilt, nil
	}

	// Non-string scalars (and nil) pass through untouched.
	return value, nil
}
|
|
||||||
|
|
||||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
|
||||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
|
||||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
|
||||||
// (which strips comments) and only checking the comment-free version for macros.
|
|
||||||
func substituteEnvMacros(s string) (string, error) {
|
|
||||||
// Unmarshal and remarshal to strip YAML comments
|
|
||||||
var raw any
|
|
||||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
|
||||||
// If YAML is invalid, fall back to scanning the original string
|
|
||||||
// so the user gets the env var error rather than a confusing YAML parse error
|
|
||||||
return substituteEnvMacrosInString(s, s)
|
|
||||||
}
|
|
||||||
clean, err := yaml.Marshal(raw)
|
|
||||||
if err != nil {
|
|
||||||
return substituteEnvMacrosInString(s, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
return substituteEnvMacrosInString(s, string(clean))
|
|
||||||
}
|
|
||||||
|
|
||||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
|
||||||
// them in target. This separation allows scanning comment-free YAML while
|
|
||||||
// substituting in the original string.
|
|
||||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
|
||||||
result := target
|
|
||||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
fullMatch := match[0] // ${env.VAR_NAME}
|
|
||||||
varName := match[1] // VAR_NAME
|
|
||||||
|
|
||||||
value, exists := os.LookupEnv(varName)
|
|
||||||
if !exists {
|
|
||||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sanitize the value for safe YAML substitution
|
|
||||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
result = strings.ReplaceAll(result, fullMatch, value)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sanitizeEnvValueForYAML prepares an environment variable's value for being
// spliced into YAML text.
//
// Values containing a newline, carriage return or NUL byte are rejected with
// an error naming varName. Otherwise every backslash is doubled and every
// double quote is prefixed with a backslash, which is the escaping expected
// inside a double-quoted YAML string. In an unquoted position those escape
// characters end up in the value literally.
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
	if strings.IndexAny(value, "\n\r\x00") >= 0 {
		return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
	}

	// Single pass: backslashes and quotes are escaped independently, which
	// matches escaping backslashes first and quotes second.
	escaper := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	return escaper.Replace(value), nil
}
|
|
||||||
|
|||||||
@@ -266,6 +266,9 @@ groups:
|
|||||||
"mthree": "model3",
|
"mthree": "model3",
|
||||||
},
|
},
|
||||||
Groups: expectedGroups,
|
Groups: expectedGroups,
|
||||||
|
Upstream: UpstreamConfig{
|
||||||
|
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||||
|
},
|
||||||
Routing: RoutingConfig{
|
Routing: RoutingConfig{
|
||||||
Router: RouterConfig{
|
Router: RouterConfig{
|
||||||
Use: "group",
|
Use: "group",
|
||||||
@@ -274,7 +277,7 @@ groups:
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Scheduler: SchedulerConfig{
|
Scheduler: SchedulerConfig{
|
||||||
Use: "fifo",
|
Use: "serial",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -777,22 +777,27 @@ func TestConfig_APIKeys_Invalid(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "blank spaces only",
|
name: "blank spaces only",
|
||||||
content: `apiKeys: [" "]`,
|
content: `apiKeys: [" "]`,
|
||||||
expectedErr: "api key cannot contain spaces: ` `",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "contains leading space",
|
name: "contains leading space",
|
||||||
content: `apiKeys: [" key123"]`,
|
content: `apiKeys: [" key123"]`,
|
||||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "contains trailing space",
|
name: "contains trailing space",
|
||||||
content: `apiKeys: ["key123 "]`,
|
content: `apiKeys: ["key123 "]`,
|
||||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "contains middle space",
|
name: "contains middle space",
|
||||||
content: `apiKeys: ["key 123"]`,
|
content: `apiKeys: ["key 123"]`,
|
||||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "space in second key reports correct index",
|
||||||
|
content: `apiKeys: ["valid-key", "bad key"]`,
|
||||||
|
expectedErr: "apiKeys[1]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty in list with valid keys",
|
name: "empty in list with valid keys",
|
||||||
@@ -1567,7 +1572,7 @@ groups:
|
|||||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||||
// default group injected for orphaned models (none here) still leaves g1
|
// default group injected for orphaned models (none here) still leaves g1
|
||||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
||||||
@@ -1626,7 +1631,7 @@ func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
|
|||||||
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
||||||
|
|||||||
@@ -255,6 +255,9 @@ groups:
|
|||||||
"mthree": "model3",
|
"mthree": "model3",
|
||||||
},
|
},
|
||||||
Groups: expectedGroups,
|
Groups: expectedGroups,
|
||||||
|
Upstream: UpstreamConfig{
|
||||||
|
IgnorePaths: DefaultUpstreamIgnorePaths(),
|
||||||
|
},
|
||||||
Routing: RoutingConfig{
|
Routing: RoutingConfig{
|
||||||
Router: RouterConfig{
|
Router: RouterConfig{
|
||||||
Use: "group",
|
Use: "group",
|
||||||
@@ -263,7 +266,7 @@ groups:
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Scheduler: SchedulerConfig{
|
Scheduler: SchedulerConfig{
|
||||||
Use: "fifo",
|
Use: "serial",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,441 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
yamlStr := string(data)
|
||||||
|
|
||||||
|
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||||
|
// This is safe because env values are simple strings without YAML formatting
|
||||||
|
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal into full Config with defaults
|
||||||
|
config := Config{
|
||||||
|
HealthCheckTimeout: 120,
|
||||||
|
StartPort: 5800,
|
||||||
|
LogLevel: "info",
|
||||||
|
LogTimeFormat: "",
|
||||||
|
LogToStdout: LogToStdoutProxy,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
CaptureBuffer: 5,
|
||||||
|
GlobalTTL: 0,
|
||||||
|
}
|
||||||
|
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.HealthCheckTimeout < 15 {
|
||||||
|
config.HealthCheckTimeout = 15
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults for performance config when section is missing
|
||||||
|
if config.Performance.Every == 0 {
|
||||||
|
config.Performance.Every = 5 * time.Second
|
||||||
|
}
|
||||||
|
if err = config.Performance.Validate(); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("performance: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.StartPort < 1 {
|
||||||
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.GlobalTTL < 0 {
|
||||||
|
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply default for upstream.ignorePaths when not specified. The default
|
||||||
|
// matches common static-asset suffixes so they do not trigger a swap.
|
||||||
|
if len(config.Upstream.IgnorePaths) == 0 {
|
||||||
|
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch config.LogToStdout {
|
||||||
|
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate the aliases map
|
||||||
|
config.aliases = make(map[string]string)
|
||||||
|
for modelName, modelConfig := range config.Models {
|
||||||
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
if _, found := config.aliases[alias]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||||
|
}
|
||||||
|
config.aliases[alias] = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate global macros
|
||||||
|
for _, macro := range config.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get and sort all model IDs for consistent port assignment
|
||||||
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
|
for modelId := range config.Models {
|
||||||
|
modelIds = append(modelIds, modelId)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIds)
|
||||||
|
|
||||||
|
nextPort := config.StartPort
|
||||||
|
for _, modelId := range modelIds {
|
||||||
|
modelConfig := config.Models[modelId]
|
||||||
|
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||||
|
|
||||||
|
// Strip comments from command fields
|
||||||
|
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||||
|
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||||
|
|
||||||
|
// set model TTL to globalTTL if it is the default value
|
||||||
|
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||||
|
modelConfig.UnloadAfter = config.GlobalTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelConfig.UnloadAfter < 0 {
|
||||||
|
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate model macros
|
||||||
|
for _, macro := range modelConfig.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
mergedMacros = append(mergedMacros, config.Macros...)
|
||||||
|
|
||||||
|
// Add model macros (override globals with same name)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Substitute remaining macros in model fields (LIFO order)
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||||
|
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute macros in SetParamsByID keys and values
|
||||||
|
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||||
|
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||||
|
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||||
|
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||||
|
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
newParamMap, ok := newValAny.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||||
|
}
|
||||||
|
newSetParamsByID[newKey] = newParamMap
|
||||||
|
}
|
||||||
|
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Substitute in metadata (type-preserving)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle PORT macro - only allocate if cmd uses it
|
||||||
|
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||||
|
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||||
|
if cmdHasPort || proxyHasPort {
|
||||||
|
if !cmdHasPort && proxyHasPort {
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
macroSlug := "${PORT}"
|
||||||
|
macroStr := fmt.Sprintf("%v", nextPort)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||||
|
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||||
|
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate no unknown macros remain
|
||||||
|
fieldMap := map[string]string{
|
||||||
|
"cmd": modelConfig.Cmd,
|
||||||
|
"cmdStop": modelConfig.CmdStop,
|
||||||
|
"proxy": modelConfig.Proxy,
|
||||||
|
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||||
|
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||||
|
"name": modelConfig.Name,
|
||||||
|
"description": modelConfig.Description,
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, fieldValue := range fieldMap {
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
if macroName == "PID" && fieldName == "cmdStop" {
|
||||||
|
continue // replaced at runtime
|
||||||
|
}
|
||||||
|
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||||
|
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate SetParamsByID keys and values
|
||||||
|
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||||
|
}
|
||||||
|
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||||
|
for key := range modelConfig.Filters.SetParamsByID {
|
||||||
|
if key == modelId {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := config.Models[key]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||||
|
}
|
||||||
|
if existingModel, exists := config.aliases[key]; exists {
|
||||||
|
if existingModel != modelId {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||||
|
}
|
||||||
|
continue // already registered as explicit alias for this model
|
||||||
|
}
|
||||||
|
config.aliases[key] = modelId
|
||||||
|
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelConfig.SendLoadingState == nil {
|
||||||
|
v := config.SendLoadingState
|
||||||
|
modelConfig.SendLoadingState = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Models[modelId] = modelConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||||
|
// the new `routing.router` block are mutually exclusive: a config may use
|
||||||
|
// either style, never both.
|
||||||
|
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||||
|
rtr := config.Routing.Router
|
||||||
|
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||||
|
|
||||||
|
if hasTopLevel && hasRouting {
|
||||||
|
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasTopLevel {
|
||||||
|
// Both groups and matrix may be defined under routing.router.settings;
|
||||||
|
// routing.router.use selects which one is active, so there is no conflict.
|
||||||
|
rs := config.Routing.Router.Settings
|
||||||
|
switch config.Routing.Router.Use {
|
||||||
|
case "matrix":
|
||||||
|
if rs.Matrix == nil {
|
||||||
|
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||||
|
}
|
||||||
|
config.Matrix = rs.Matrix
|
||||||
|
case "group", "":
|
||||||
|
config.Groups = rs.Groups
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// groups XOR matrix
|
||||||
|
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Matrix != nil {
|
||||||
|
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||||
|
}
|
||||||
|
config.Matrix.ExpandedSets = expandedSets
|
||||||
|
} else {
|
||||||
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
|
||||||
|
// Validate group members
|
||||||
|
memberUsage := make(map[string]string)
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||||
|
// and new-style configs converge here. The Matrix pointer is shared so
|
||||||
|
// ExpandedSets stays in one place.
|
||||||
|
if config.Matrix != nil {
|
||||||
|
config.Routing.Router.Use = "matrix"
|
||||||
|
} else {
|
||||||
|
config.Routing.Router.Use = "group"
|
||||||
|
}
|
||||||
|
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||||
|
config.Routing.Router.Settings.Groups = config.Groups
|
||||||
|
|
||||||
|
// This fork defaults to the "serial" scheduler: one model loaded at a time,
|
||||||
|
// requests served in strict arrival order. Set use: fifo for the upstream
|
||||||
|
// throughput-oriented behavior that batches same-model requests.
|
||||||
|
if config.Routing.Scheduler.Use == "" {
|
||||||
|
config.Routing.Scheduler.Use = "serial"
|
||||||
|
}
|
||||||
|
switch config.Routing.Scheduler.Use {
|
||||||
|
case "fifo", "serial":
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo, serial)", config.Routing.Scheduler.Use)
|
||||||
|
}
|
||||||
|
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||||
|
if _, found := config.RealModelName(modelID); !found {
|
||||||
|
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up hooks preload
|
||||||
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
var toPreload []string
|
||||||
|
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if real, found := config.RealModelName(modelID); found {
|
||||||
|
toPreload = append(toPreload, real)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Hooks.OnStartup.Preload = toPreload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate API keys (env macros already substituted at string level)
|
||||||
|
for i, apikey := range config.RequiredAPIKeys {
|
||||||
|
if apikey == "" {
|
||||||
|
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||||
|
}
|
||||||
|
if strings.Contains(apikey, " ") {
|
||||||
|
return Config{}, fmt.Errorf("apiKeys[%d]: api key cannot contain spaces", i)
|
||||||
|
}
|
||||||
|
config.RequiredAPIKeys[i] = apikey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process peers with global macro substitution
|
||||||
|
for peerName, peerConfig := range config.Peers {
|
||||||
|
// Substitute global macros (LIFO order)
|
||||||
|
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||||
|
entry := config.Macros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||||
|
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in setParams (type-preserving)
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||||
|
}
|
||||||
|
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate no unknown macros remain
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Peers[peerName] = peerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Precompiled patterns shared by macro validation and substitution.
var (
	// macroNameRegex matches a complete macro name: one or more letters,
	// digits, underscores or hyphens.
	macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
	// macroPatternRegex finds ${NAME} references; capture group 1 is NAME.
	// The character class has no '.', so ${env.VAR} is not matched by it.
	macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
	// envMacroRegex finds ${env.VAR} references; capture group 1 is VAR.
	envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
)
|
||||||
|
|
||||||
|
// validateMacro validates macro name and value constraints
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
if len(name) >= 64 {
|
||||||
|
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||||
|
}
|
||||||
|
if !macroNameRegex.MatchString(name) {
|
||||||
|
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that value is a scalar type
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check for self-reference
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||||
|
// These types are allowed
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "PORT", "MODEL_ID":
|
||||||
|
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||||
|
func validateNestedForUnknownMacros(value any, context string) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||||
|
}
|
||||||
|
// Check for unsubstituted env macros
|
||||||
|
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range envMatches {
|
||||||
|
varName := match[1]
|
||||||
|
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Scalar types don't contain macros
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteMacroInValue replaces references to one macro inside value and
// returns the rewritten value. It is applied once per macro so the caller
// controls the substitution order.
//
//   - A string that is exactly ${macroName} is replaced by macroValue itself,
//     keeping its Go type.
//   - Any other string has every ${macroName} occurrence replaced by the
//     %v-formatted macroValue.
//   - map[string]any and []any are rebuilt with their elements processed
//     recursively; the inputs are not modified.
//   - Every other value is returned unchanged.
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
	switch typed := value.(type) {
	case string:
		placeholder := "${" + macroName + "}"
		if typed == placeholder {
			// Whole-string reference: hand back the typed value.
			return macroValue, nil
		}
		return strings.ReplaceAll(typed, placeholder, fmt.Sprintf("%v", macroValue)), nil

	case map[string]any:
		rebuilt := make(map[string]any)
		for k, child := range typed {
			replaced, err := substituteMacroInValue(child, macroName, macroValue)
			if err != nil {
				return nil, err
			}
			rebuilt[k] = replaced
		}
		return rebuilt, nil

	case []any:
		rebuilt := make([]any, len(typed))
		for pos, child := range typed {
			replaced, err := substituteMacroInValue(child, macroName, macroValue)
			if err != nil {
				return nil, err
			}
			rebuilt[pos] = replaced
		}
		return rebuilt, nil
	}

	return value, nil
}
|
||||||
|
|
||||||
|
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||||
|
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||||
|
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||||
|
// (which strips comments) and only checking the comment-free version for macros.
|
||||||
|
func substituteEnvMacros(s string) (string, error) {
|
||||||
|
// Unmarshal and remarshal to strip YAML comments
|
||||||
|
var raw any
|
||||||
|
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||||
|
// If YAML is invalid, fall back to scanning the original string
|
||||||
|
// so the user gets the env var error rather than a confusing YAML parse error
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
clean, err := yaml.Marshal(raw)
|
||||||
|
if err != nil {
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
return substituteEnvMacrosInString(s, string(clean))
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||||
|
// them in target. This separation allows scanning comment-free YAML while
|
||||||
|
// substituting in the original string.
|
||||||
|
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||||
|
result := target
|
||||||
|
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
fullMatch := match[0] // ${env.VAR_NAME}
|
||||||
|
varName := match[1] // VAR_NAME
|
||||||
|
|
||||||
|
value, exists := os.LookupEnv(varName)
|
||||||
|
if !exists {
|
||||||
|
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize the value for safe YAML substitution
|
||||||
|
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = strings.ReplaceAll(result, fullMatch, value)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeEnvValueForYAML prepares an environment variable's value for textual
// insertion into YAML. Values containing a newline, carriage return or NUL
// byte are rejected with an error naming varName. Otherwise each backslash is
// doubled and each double quote is prefixed with a backslash, so the value is
// safe inside a double-quoted YAML string (in unquoted positions the escapes
// appear literally).
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
	if strings.IndexAny(value, "\n\r\x00") >= 0 {
		return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
	}

	// One pass over the input; equivalent to escaping backslashes first and
	// quotes second.
	escaper := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	return escaper.Replace(value), nil
}
|
||||||
@@ -0,0 +1,300 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// identityMapPaths is the set of dotted paths whose direct children are
// identity-keyed maps. A child key present in two sources is a hard error;
// such keys name discrete entities (a model, a group, a peer, etc.) and a
// duplicate means the user has split one entity across files by mistake.
//
// mergeNodes looks up each shared key's dotted path here and, on a hit,
// hands the pair to mergeIdentityMap instead of mergeValue.
var identityMapPaths = map[string]bool{
	"models":   true,
	"groups":   true,
	"profiles": true,
	"peers":    true,
	"matrix":   true,
	// The same group/matrix maps when nested under the routing block.
	"routing.router.settings.groups": true,
	"routing.router.settings.matrix": true,
}
|
||||||
|
|
||||||
|
// LoadConfigSources loads and merges configuration from -config (optional)
|
||||||
|
// and -config-dir (optional). At least one must be provided. The -config file
|
||||||
|
// is loaded first; *.yml/*.yaml files directly under -config-dir are then
|
||||||
|
// merged in sorted filename order. The merged document is passed through the
|
||||||
|
// existing LoadConfigFromReader pipeline unchanged.
|
||||||
|
func LoadConfigSources(configPath, configDir string) (Config, error) {
|
||||||
|
if configPath == "" && configDir == "" {
|
||||||
|
return Config{}, fmt.Errorf("at least one of -config or -config-dir must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
var sourcePaths []string
|
||||||
|
|
||||||
|
if configPath != "" {
|
||||||
|
sourcePaths = append(sourcePaths, configPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if configDir != "" {
|
||||||
|
dirFiles, err := listYAMLFiles(configDir)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("-config-dir %s: %w", configDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if configPath != "" {
|
||||||
|
absConfig, err := filepath.Abs(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to resolve -config path: %w", err)
|
||||||
|
}
|
||||||
|
for _, f := range dirFiles {
|
||||||
|
absF, err := filepath.Abs(f)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to resolve config dir file %s: %w", f, err)
|
||||||
|
}
|
||||||
|
if absConfig == absF {
|
||||||
|
return Config{}, fmt.Errorf("-config path %s is also present in -config-dir %s; remove it from one", configPath, configDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sourcePaths = append(sourcePaths, dirFiles...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sourcePaths) == 0 {
|
||||||
|
return Config{}, fmt.Errorf("no configuration sources found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var merged *yaml.Node
|
||||||
|
for _, p := range sourcePaths {
|
||||||
|
node, err := parseSource(p)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
if node == nil {
|
||||||
|
continue // empty file
|
||||||
|
}
|
||||||
|
if merged == nil {
|
||||||
|
merged = node
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := mergeNodes(merged, node, "", p); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if merged == nil {
|
||||||
|
// All sources were empty; run the pipeline on empty input so defaults
|
||||||
|
// and validation still apply (e.g. startPort, performance defaults).
|
||||||
|
return LoadConfigFromReader(strings.NewReader(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := yaml.Marshal(merged)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to marshal merged config: %w", err)
|
||||||
|
}
|
||||||
|
return LoadConfigFromReader(strings.NewReader(string(out)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// listYAMLFiles returns the top-level *.yml and *.yaml files in dir, sorted by
|
||||||
|
// filename for deterministic merge order. Subdirectories are not traversed.
|
||||||
|
func listYAMLFiles(dir string) ([]string, error) {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var files []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := e.Name()
|
||||||
|
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
files = append(files, filepath.Join(dir, name))
|
||||||
|
}
|
||||||
|
sort.Strings(files)
|
||||||
|
return files, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSource reads and parses one YAML config file into a root mapping node.
|
||||||
|
// Returns a nil node (no error) when the file is empty or contains only
|
||||||
|
// comments.
|
||||||
|
//
|
||||||
|
// Env macros (${env.VAR}) are substituted at the string level before YAML
|
||||||
|
// parsing so that flow-style constructs like [${env.API_KEY}] parse
|
||||||
|
// correctly — the brace would otherwise be interpreted as a flow mapping.
|
||||||
|
func parseSource(path string) (*yaml.Node, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
yamlStr, err := substituteEnvMacros(string(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
var doc yaml.Node
|
||||||
|
if err := yaml.Unmarshal([]byte(yamlStr), &doc); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
// yaml.Unmarshal into a yaml.Node yields a DocumentNode whose Content[0]
|
||||||
|
// is the actual root. Unwrap it so callers see the real top-level node.
|
||||||
|
root := &doc
|
||||||
|
if root.Kind == yaml.DocumentNode && len(root.Content) > 0 {
|
||||||
|
root = root.Content[0]
|
||||||
|
}
|
||||||
|
if root.Kind == 0 || root.Content == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if root.Kind != yaml.MappingNode {
|
||||||
|
return nil, fmt.Errorf("config %s: top-level YAML must be a mapping", path)
|
||||||
|
}
|
||||||
|
return root, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeNodes merges src into dst (both MappingNodes) in place. Keys present in
|
||||||
|
// only one side are kept; shared keys are merged recursively under the rules
|
||||||
|
// in mergeValue. srcPath is included in error messages to identify the file
|
||||||
|
// that introduced the conflict.
|
||||||
|
func mergeNodes(dst, src *yaml.Node, path, srcPath string) error {
|
||||||
|
srcIdx := indexMapping(src)
|
||||||
|
|
||||||
|
// First pass: merge shared keys in place.
|
||||||
|
for i := 0; i+1 < len(dst.Content); i += 2 {
|
||||||
|
keyNode := dst.Content[i]
|
||||||
|
dstVal := dst.Content[i+1]
|
||||||
|
key := keyNode.Value
|
||||||
|
|
||||||
|
srcVal, ok := srcIdx[key]
|
||||||
|
if !ok {
|
||||||
|
continue // dst-only key, keep as-is
|
||||||
|
}
|
||||||
|
|
||||||
|
childPath := joinPath(path, key)
|
||||||
|
|
||||||
|
if identityMapPaths[childPath] {
|
||||||
|
// Identity-keyed map: each child key names a discrete entity
|
||||||
|
// (a model, group, peer, ...). A shared child key is a hard
|
||||||
|
// error; src-only children are appended in the second pass.
|
||||||
|
if err := mergeIdentityMap(dstVal, srcVal, childPath, key, srcPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mergeValue(dstVal, srcVal, childPath, srcPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: append src-only keys.
|
||||||
|
dstIdx := indexMapping(dst)
|
||||||
|
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||||
|
keyNode := src.Content[i]
|
||||||
|
srcVal := src.Content[i+1]
|
||||||
|
key := keyNode.Value
|
||||||
|
|
||||||
|
if _, ok := dstIdx[key]; ok {
|
||||||
|
continue // already merged above
|
||||||
|
}
|
||||||
|
keyCopy := *keyNode
|
||||||
|
valCopy := *srcVal
|
||||||
|
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeIdentityMap merges two identity-keyed mapping nodes (e.g. `models`,
|
||||||
|
// `groups`, `peers`). Any child key present in both sides is a duplicate
|
||||||
|
// entity and produces an error naming the conflicting key and source file.
|
||||||
|
// src-only keys are appended to dst.
|
||||||
|
func mergeIdentityMap(dst, src *yaml.Node, path, mapName, srcPath string) error {
|
||||||
|
if dst.Kind != yaml.MappingNode || src.Kind != yaml.MappingNode {
|
||||||
|
return fmt.Errorf("conflict at %q: expected a mapping, introduced by %s", path, srcPath)
|
||||||
|
}
|
||||||
|
dstIdx := indexMapping(dst)
|
||||||
|
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||||
|
keyNode := src.Content[i]
|
||||||
|
srcVal := src.Content[i+1]
|
||||||
|
key := keyNode.Value
|
||||||
|
if _, dup := dstIdx[key]; dup {
|
||||||
|
return fmt.Errorf("duplicate %s %q found in %s (already defined in another config source)", mapName, key, srcPath)
|
||||||
|
}
|
||||||
|
keyCopy := *keyNode
|
||||||
|
valCopy := *srcVal
|
||||||
|
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeValue merges srcVal into dstVal (both pointing into the parent's
|
||||||
|
// Content slice). Mapping+Mapping recurses; Sequence+Sequence concatenates;
|
||||||
|
// Scalar+Scalar errors on value mismatch; null on either side yields to the
|
||||||
|
// non-null side.
|
||||||
|
func mergeValue(dstVal, srcVal *yaml.Node, path, srcPath string) error {
|
||||||
|
switch {
|
||||||
|
case dstVal.Kind == yaml.MappingNode && srcVal.Kind == yaml.MappingNode:
|
||||||
|
return mergeNodes(dstVal, srcVal, path, srcPath)
|
||||||
|
|
||||||
|
case dstVal.Kind == yaml.SequenceNode && srcVal.Kind == yaml.SequenceNode:
|
||||||
|
dstVal.Content = append(dstVal.Content, srcVal.Content...)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case dstVal.Kind == yaml.ScalarNode && srcVal.Kind == yaml.ScalarNode:
|
||||||
|
if isNullScalar(dstVal) {
|
||||||
|
*dstVal = *srcVal
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if isNullScalar(srcVal) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if dstVal.Value == srcVal.Value && dstVal.Tag == srcVal.Tag {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("conflict at %q: %s sets a different value than a previous source", path, srcPath)
|
||||||
|
|
||||||
|
case isNull(dstVal):
|
||||||
|
*dstVal = *srcVal
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case isNull(srcVal):
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("conflict at %q: incompatible YAML node kinds (kind %d vs %d) introduced by %s", path, dstVal.Kind, srcVal.Kind, srcPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNull reports whether n represents a YAML null (empty or !!null).
|
||||||
|
func isNull(n *yaml.Node) bool {
|
||||||
|
if n == nil || n.Kind == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return isNullScalar(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNullScalar(n *yaml.Node) bool {
|
||||||
|
return n.Kind == yaml.ScalarNode && (n.Tag == "!!null" || n.Tag == "") && n.Value == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// indexMapping builds a key -> value-node index for a mapping node.
|
||||||
|
func indexMapping(n *yaml.Node) map[string]*yaml.Node {
|
||||||
|
idx := make(map[string]*yaml.Node, len(n.Content)/2)
|
||||||
|
for i := 0; i+1 < len(n.Content); i += 2 {
|
||||||
|
idx[n.Content[i].Value] = n.Content[i+1]
|
||||||
|
}
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
// joinPath appends key to the dotted path parent. With an empty parent the
// key is returned on its own, so top-level paths have no leading dot.
func joinPath(parent, key string) string {
	if parent != "" {
		return parent + "." + key
	}
	return key
}
|
||||||
@@ -0,0 +1,304 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeYAML writes content to a file named name inside dir. Returns the full
|
||||||
|
// path of the written file.
|
||||||
|
func writeYAML(t *testing.T, dir, name, content string) string {
|
||||||
|
t.Helper()
|
||||||
|
p := filepath.Join(dir, name)
|
||||||
|
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0o755))
|
||||||
|
require.NoError(t, os.WriteFile(p, []byte(content), 0o644))
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelCfg builds a single-model YAML snippet meant to be appended after a
// `models:` line: an entry keyed by id with the given cmd and a proxy URL.
// The proxy uses a fixed port (http://localhost:9999) so tests don't depend
// on ${PORT} allocation.
// NOTE(review): the leading whitespace inside these literals determines the
// YAML nesting — confirm it against the checked-in file.
func modelCfg(id, cmd string) string {
	return " " + id + ":\n cmd: " + cmd + "\n proxy: \"http://localhost:9999\"\n"
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_NeitherProvided(t *testing.T) {
|
||||||
|
_, err := LoadConfigSources("", "")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "at least one of -config or -config-dir")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadConfigSources_ConfigOnly loads a single -config file (no
// -config-dir) that defines one model and one group, and checks that the
// model can be found by its ID afterwards.
func TestLoadConfigSources_ConfigOnly(t *testing.T) {
	dir := t.TempDir()
	cfgPath := writeYAML(t, dir, "config.yaml", `
models:
`+modelCfg("model1", "echo hi")+`
groups:
group1:
members: ["model1"]
`)
	cfg, err := LoadConfigSources(cfgPath, "")
	require.NoError(t, err)
	// FindConfig's second result is the resolved model ID.
	_, id, ok := cfg.FindConfig("model1")
	require.True(t, ok)
	assert.Equal(t, "model1", id)
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DirOnly(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("alpha", "echo a"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("beta", "echo b"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, want := range []string{"alpha", "beta"} {
|
||||||
|
_, _, ok := cfg.FindConfig(want)
|
||||||
|
assert.True(t, ok, "model %s should be present", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ConfigPlusDirAdditive(t *testing.T) {
|
||||||
|
// -config lives outside -config-dir; both contribute models additively.
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "config.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||||
|
cfgDir := t.TempDir()
|
||||||
|
writeYAML(t, cfgDir, "extra.yaml", "models:\n"+modelCfg("ext", "echo ext"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, want := range []string{"base", "ext"} {
|
||||||
|
_, _, ok := cfg.FindConfig(want)
|
||||||
|
assert.True(t, ok, "model %s should be present after merge", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadConfigSources_ConfigInDirOverlap verifies that a -config file that
|
||||||
|
// is also a member of -config-dir is rejected.
|
||||||
|
func TestLoadConfigSources_ConfigInDirOverlap(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||||
|
|
||||||
|
_, err := LoadConfigSources(cfgPath, dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "is also present in -config-dir")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DuplicateModelID(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("dup", "echo a"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("dup", "echo b"))
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `duplicate models "dup"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadConfigSources_DuplicateGroupID defines group g1 in two files of the
// -config-dir (each with a different model) and expects a duplicate-group
// error.
func TestLoadConfigSources_DuplicateGroupID(t *testing.T) {
	dir := t.TempDir()
	writeYAML(t, dir, "a.yaml", `
models:
`+modelCfg("m1", "echo m1")+"groups:\n g1:\n members: [m1]\n")
	writeYAML(t, dir, "b.yaml", `
models:
`+modelCfg("m2", "echo m2")+"groups:\n g1:\n members: [m2]\n")

	_, err := LoadConfigSources("", dir)
	require.Error(t, err)
	assert.Contains(t, err.Error(), `duplicate groups "g1"`)
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DuplicatePeer(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
peerA := "peers:\n remote:\n proxy: http://x:1\n models: [m1]\n"
|
||||||
|
peerB := "peers:\n remote:\n proxy: http://x:2\n models: [m2]\n"
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\n"+peerA)
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\n"+peerB)
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `duplicate peers "remote"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ScalarConflict(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 200\n")
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `conflict at "globalTTL"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ScalarSameValueNoConflict(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 100\n")
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 100, cfg.GlobalTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_MacrosConcatenate(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "macros:\n LOW: 1\nmodels:\n"+modelCfg("m1", "echo ${LOW}"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "macros:\n HIGH: 2\nmodels:\n"+modelCfg("m2", "echo ${HIGH}"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Both macros are available globally after merge.
|
||||||
|
low, ok := cfg.Macros.Get("LOW")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, 1, low)
|
||||||
|
high, ok := cfg.Macros.Get("HIGH")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, 2, high)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_APIKeysConcatenate(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\napiKeys: [key-a]\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\napiKeys: [key-b]\n")
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.ElementsMatch(t, []string{"key-a", "key-b"}, cfg.RequiredAPIKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_RoutingGroupsMerge(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m1", "echo m1")+`
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
settings:
|
||||||
|
groups:
|
||||||
|
groupA:
|
||||||
|
members: [m1]
|
||||||
|
`)
|
||||||
|
writeYAML(t, dir, "b.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m2", "echo m2")+`
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
settings:
|
||||||
|
groups:
|
||||||
|
groupB:
|
||||||
|
members: [m2]
|
||||||
|
`)
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
groups := cfg.Routing.Router.Settings.Groups
|
||||||
|
assert.Contains(t, groups, "groupA")
|
||||||
|
assert.Contains(t, groups, "groupB")
|
||||||
|
// default group added by pipeline for orphaned/leftover routing groups...
|
||||||
|
// here both groups reference distinct models
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EnvMacrosSubstituted(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
// Use ${PORT} in cmd so the pipeline allocates a port and substitutes it;
|
||||||
|
// verifies env/macro substitution runs on the merged document.
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: serve --port ${PORT}\n proxy: \"http://localhost:${PORT}\"\n")
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m := cfg.Models["m1"]
|
||||||
|
assert.NotContains(t, m.Cmd, "${PORT}", "PORT macro should have been substituted")
|
||||||
|
assert.NotContains(t, m.Proxy, "${PORT}", "PORT macro should have been substituted in proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EnvMacroInFlowStyleList(t *testing.T) {
|
||||||
|
// Regression: flow-style lists with ${env.*} must parse. Previously
|
||||||
|
// parseSource unmarshalled before env substitution, so the brace in
|
||||||
|
// [${env.API_KEY}] was misread as a flow mapping and parsing failed.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: echo hi\n proxy: \"http://localhost:9999\"\n")
|
||||||
|
writeYAML(t, dir, "keys.yaml", "apiKeys: [${env.TEST_API_KEY}]\nmodels:\n m2:\n cmd: echo hi\n proxy: \"http://localhost:9998\"\n")
|
||||||
|
|
||||||
|
t.Setenv("TEST_API_KEY", "secret123")
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, cfg.RequiredAPIKeys, "secret123")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_SortedOrderDeterministic(t *testing.T) {
|
||||||
|
// Two files defining distinct models, scanned in z..a order by filename.
|
||||||
|
// Determine merged result is the same regardless of how the FS returns them.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "z.yaml", "models:\n"+modelCfg("zmodel", "echo z"))
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("amodel", "echo a"))
|
||||||
|
|
||||||
|
const runs = 3
|
||||||
|
for i := 0; i < runs; i++ {
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// startPort-based allocation: first allocated model gets 5800.
|
||||||
|
// Sorted order means amodel gets 5800, zmodel gets 5801.
|
||||||
|
_, _, ok := cfg.FindConfig("amodel")
|
||||||
|
assert.True(t, ok)
|
||||||
|
_, _, ok = cfg.FindConfig("zmodel")
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EmptyDirWithConfig(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgDir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, cfg.Models, "m1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EmptyDirOnly(t *testing.T) {
|
||||||
|
// An empty -config-dir with no -config is an error: there is nothing to
|
||||||
|
// load and silently producing an empty config would mask the misconfig.
|
||||||
|
cfgDir := t.TempDir()
|
||||||
|
_, err := LoadConfigSources("", cfgDir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no configuration sources found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_AssertNoUnknownMacrosAfterMerge(t *testing.T) {
|
||||||
|
// Macros defined in one file should not satisfy unknown-macro validation in
|
||||||
|
// another — they do, because merge concats global macros before validation
|
||||||
|
// runs. This test documents that a macro from file A is usable in file B.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "macros.yaml", "macros:\n SHARED: hello\nmodels:\n"+modelCfg("dummy", "echo dummy"))
|
||||||
|
writeYAML(t, dir, "use.yaml", "models:\n"+modelCfg("user", "echo ${SHARED}"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m := cfg.Models["user"]
|
||||||
|
assert.Contains(t, m.Cmd, "hello")
|
||||||
|
assert.NotContains(t, m.Cmd, "${SHARED}")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_KindMismatchErrors(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "startPort: 5800\nmodels:\n"+modelCfg("m1", "echo m1"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "startPort: [5800, 5801]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "incompatible YAML node kinds")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_NullYieldsToValue(t *testing.T) {
|
||||||
|
// File A: routing.router block absent (null on root for routing);
|
||||||
|
// file B: defines routing.router.settings.groups. Merge should keep B's.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "routing:\n router:\n settings:\n groups:\n g1:\n members: [m1]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultUpstreamIgnorePathsPattern is the regular expression used for
// upstream.ignorePaths when the config leaves that section empty or out
// entirely. It matches paths ending in a common static-asset suffix
// (.js, .json, .css, .png, .gif, .jpg, .jpeg, .ico, .txt) so that requests for
// such files do not trigger a model swap.
const DefaultUpstreamIgnorePathsPattern = `.*\.(js|json|css|png|gif|jpg|jpeg|ico|txt)$`

// DefaultUpstreamIgnorePaths compiles DefaultUpstreamIgnorePathsPattern and
// returns it as a one-element slice. Every call builds a new slice and a new
// compiled expression, so callers may modify the result without affecting
// other configs.
func DefaultUpstreamIgnorePaths() []*regexp.Regexp {
	compiled := regexp.MustCompile(DefaultUpstreamIgnorePathsPattern)
	defaults := make([]*regexp.Regexp, 0, 1)
	defaults = append(defaults, compiled)
	return defaults
}
|
||||||
|
|
||||||
|
// UpstreamConfig controls behaviour of the /upstream passthrough endpoint.
|
||||||
|
type UpstreamConfig struct {
|
||||||
|
// IgnorePaths is a slice of compiled regular expressions. Any request to
|
||||||
|
// /upstream/<model>/<path> whose remaining path matches any of these
|
||||||
|
// expressions will be ignored and not trigger a swap. When the config
|
||||||
|
// does not specify any patterns, DefaultUpstreamIgnorePaths is applied.
|
||||||
|
IgnorePaths []*regexp.Regexp `yaml:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawUpstreamConfig is the intermediate form used to unmarshal the YAML into
|
||||||
|
// plain strings, which are then compiled into *regexp.Regexp.
|
||||||
|
type rawUpstreamConfig struct {
|
||||||
|
IgnorePaths []string `yaml:"ignorePaths"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalYAML compiles each ignorePaths entry into a *regexp.Regexp. If any
|
||||||
|
// entry fails to compile, an error is returned.
|
||||||
|
func (u *UpstreamConfig) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
var raw rawUpstreamConfig
|
||||||
|
if err := value.Decode(&raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
patterns := make([]*regexp.Regexp, 0, len(raw.IgnorePaths))
|
||||||
|
for _, p := range raw.IgnorePaths {
|
||||||
|
re, err := regexp.Compile(p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("upstream.ignorePaths: invalid regular expression %q: %w", p, err)
|
||||||
|
}
|
||||||
|
patterns = append(patterns, re)
|
||||||
|
}
|
||||||
|
u.IgnorePaths = patterns
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const upstreamConfigHeader = `
|
||||||
|
models:
|
||||||
|
model1:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8080"
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_DefaultWhenAbsent(t *testing.T) {
|
||||||
|
// When upstream is not specified at all, the default pattern is applied.
|
||||||
|
content := upstreamConfigHeader
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||||
|
|
||||||
|
def := cfg.Upstream.IgnorePaths[0]
|
||||||
|
assert.IsType(t, ®exp.Regexp{}, def)
|
||||||
|
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, def.String())
|
||||||
|
|
||||||
|
// The default matches common static-asset suffixes.
|
||||||
|
assert.True(t, def.MatchString("/foo.js"))
|
||||||
|
assert.True(t, def.MatchString("/bar/baz.json"))
|
||||||
|
assert.True(t, def.MatchString("/static/img.png"))
|
||||||
|
assert.True(t, def.MatchString("/notes.txt"))
|
||||||
|
assert.True(t, def.MatchString("/favicon.ico"))
|
||||||
|
// And does not match inference API paths.
|
||||||
|
assert.False(t, def.MatchString("/v1/chat/completions"))
|
||||||
|
assert.False(t, def.MatchString("/v1/models"))
|
||||||
|
assert.False(t, def.MatchString("/health"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_DefaultWhenSectionEmpty(t *testing.T) {
|
||||||
|
// When upstream is present but ignorePaths is omitted, the default is still
|
||||||
|
// applied.
|
||||||
|
content := `upstream: {}` + "\n" + upstreamConfigHeader
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, cfg.Upstream.IgnorePaths, 1)
|
||||||
|
assert.Equal(t, DefaultUpstreamIgnorePathsPattern, cfg.Upstream.IgnorePaths[0].String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_Compiles(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
upstream:
|
||||||
|
ignorePaths:
|
||||||
|
- ".*\\.(js|json|css|png|gif|jpg|jpeg|txt)$"
|
||||||
|
- "^/static/.*"
|
||||||
|
` + upstreamConfigHeader
|
||||||
|
|
||||||
|
cfg, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, cfg.Upstream.IgnorePaths, 2)
|
||||||
|
|
||||||
|
// Verify the patterns are compiled into *regexp.Regexp and match as expected.
|
||||||
|
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/foo.js"))
|
||||||
|
assert.True(t, cfg.Upstream.IgnorePaths[0].MatchString("/bar/baz.json"))
|
||||||
|
assert.False(t, cfg.Upstream.IgnorePaths[0].MatchString("/v1/chat/completions"))
|
||||||
|
assert.True(t, cfg.Upstream.IgnorePaths[1].MatchString("/static/foo.png"))
|
||||||
|
assert.False(t, cfg.Upstream.IgnorePaths[1].MatchString("/v1/chat/completions"))
|
||||||
|
|
||||||
|
// Confirm the type is *regexp.Regexp to satisfy the API contract.
|
||||||
|
for _, re := range cfg.Upstream.IgnorePaths {
|
||||||
|
assert.IsType(t, ®exp.Regexp{}, re)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_UpstreamIgnorePaths_InvalidRegexReturnsError(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
upstream:
|
||||||
|
ignorePaths:
|
||||||
|
- "[invalid("
|
||||||
|
` + upstreamConfigHeader
|
||||||
|
|
||||||
|
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "upstream.ignorePaths")
|
||||||
|
assert.Contains(t, err.Error(), "invalid regular expression")
|
||||||
|
}
|
||||||
@@ -92,9 +92,14 @@ type Effects interface {
|
|||||||
StopProcesses(timeout time.Duration, ids []string)
|
StopProcesses(timeout time.Duration, ids []string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured
|
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured from
|
||||||
// from conf and bound to the given planner and effects. Currently only "fifo"
|
// conf and bound to the given planner and effects. Supported values are "fifo"
|
||||||
// (the default) is supported.
|
// (throughput-oriented, batches same-model requests) and "serial" (strict
|
||||||
|
// one-model-at-a-time, exact arrival order).
|
||||||
|
//
|
||||||
|
// The deployment default is applied by config loading (LoadConfig sets Use to
|
||||||
|
// "serial" when unset). The "" fallback here is the library default and remains
|
||||||
|
// "fifo" so callers that build a Config directly keep the original behavior.
|
||||||
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
|
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
|
||||||
use := conf.Routing.Scheduler.Use
|
use := conf.Routing.Scheduler.Use
|
||||||
if use == "" {
|
if use == "" {
|
||||||
@@ -103,6 +108,9 @@ func New(conf config.Config, name string, logger *logmon.Monitor, planner Swappe
|
|||||||
switch use {
|
switch use {
|
||||||
case "fifo":
|
case "fifo":
|
||||||
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
|
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
|
||||||
|
case "serial":
|
||||||
|
// Serial ignores the group planner: it always evicts every other model.
|
||||||
|
return NewSerial(name, logger, eff), nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
|
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,253 @@
|
|||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Serial is a strict one-model-at-a-time scheduler. Unlike FIFO it never reorders
|
||||||
|
// or batches: requests run in exact arrival order and at most one request runs at
|
||||||
|
// any instant. When the next request targets a model other than the one loaded,
|
||||||
|
// every other running model is evicted and the target is loaded before it runs,
|
||||||
|
// so a single model occupies memory at a time — at the cost of throughput.
|
||||||
|
//
|
||||||
|
// Example: A B C A is served as A B C A. The final A reloads its model even
|
||||||
|
// though it ran first, because B and C displaced it in between. (FIFO, by
|
||||||
|
// contrast, would batch the two A requests: A A B C.)
|
||||||
|
//
|
||||||
|
// Serial ignores group/eviction policy entirely: it always evicts every other
|
||||||
|
// running model, regardless of how groups are configured. That is what makes the
|
||||||
|
// single-model guarantee a property of the scheduler rather than of the config.
|
||||||
|
//
|
||||||
|
// Like FIFO, every method runs on the router's single run-loop goroutine, so no
|
||||||
|
// internal locking is needed.
|
||||||
|
type Serial struct {
|
||||||
|
name string
|
||||||
|
logger *logmon.Monitor
|
||||||
|
effects Effects
|
||||||
|
|
||||||
|
// queued holds requests in strict arrival order. It is never reordered.
|
||||||
|
queued []HandlerReq
|
||||||
|
|
||||||
|
// active is the one request currently being processed (loading or serving),
|
||||||
|
// or nil when idle. phase is meaningful only while active != nil.
|
||||||
|
active *HandlerReq
|
||||||
|
phase serialPhase
|
||||||
|
}
|
||||||
|
|
||||||
|
// serialPhase records where the active request is in its lifecycle.
type serialPhase int

const (
	// phaseIdle is the zero value: no request is active.
	phaseIdle serialPhase = iota
	// phaseSwapping: waiting for OnSwapDone for active.Model.
	phaseSwapping
	// phaseServing: waiting for OnServeDone for active.Model.
	phaseServing
)
|
||||||
|
|
||||||
|
// NewSerial builds a Serial scheduler. It takes no Swapper: eviction is always
|
||||||
|
// "stop every other running model", so the group planner is not consulted.
|
||||||
|
func NewSerial(name string, logger *logmon.Monitor, eff Effects) *Serial {
|
||||||
|
return &Serial{
|
||||||
|
name: name,
|
||||||
|
logger: logger,
|
||||||
|
effects: eff,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnRequest validates the model and appends the request to the tail of the queue,
|
||||||
|
// then tries to start the next job. Unknown models fail immediately.
|
||||||
|
func (s *Serial) OnRequest(req HandlerReq) {
|
||||||
|
if _, ok := s.effects.ModelState(req.Model); !ok {
|
||||||
|
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
|
||||||
|
s.effects.GrantError(req, ErrModelNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.queued = append(s.queued, req)
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
s.startNext()
|
||||||
|
}
|
||||||
|
|
||||||
|
// startNext begins processing the head of the queue when nothing is active. It
|
||||||
|
// fast-paths a request whose model is already the sole loaded-and-ready process;
|
||||||
|
// otherwise it launches a swap that evicts every other running model first. The
|
||||||
|
// loop skips over requests for models that vanished (e.g. a config reload) and
|
||||||
|
// requests whose caller disconnected before they could be served.
|
||||||
|
func (s *Serial) startNext() {
|
||||||
|
if s.active != nil {
|
||||||
|
return // a job is already loading or serving
|
||||||
|
}
|
||||||
|
for len(s.queued) > 0 {
|
||||||
|
req := s.queued[0]
|
||||||
|
s.queued = s.queued[1:]
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
|
||||||
|
state, ok := s.effects.ModelState(req.Model)
|
||||||
|
if !ok {
|
||||||
|
s.effects.GrantError(req, ErrModelNotFound)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
r := req
|
||||||
|
s.active = &r
|
||||||
|
|
||||||
|
evict := s.otherRunning(req.Model)
|
||||||
|
if state == process.StateReady && len(evict) == 0 {
|
||||||
|
// Already loaded and the only model running — serve immediately.
|
||||||
|
s.logger.Debugf("%s: serving model %s (already loaded)", s.name, req.Model)
|
||||||
|
if s.serve() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue // caller gone; pick the next request
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Debugf("%s: swapping to model %s, evicting %v", s.name, req.Model, evict)
|
||||||
|
s.phase = phaseSwapping
|
||||||
|
s.effects.StartSwap(req.Model, evict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// serve hands the active request its tracked handler. It returns true when the
|
||||||
|
// request is now serving (await OnServeDone); false when the caller had already
|
||||||
|
// disconnected, in which case active is cleared so the next job can start.
|
||||||
|
func (s *Serial) serve() bool {
|
||||||
|
if s.effects.GrantServe(*s.active, s.active.Model) {
|
||||||
|
s.phase = phaseServing
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
s.logger.Debugf("%s: caller for model %s gone before serve", s.name, s.active.Model)
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnSwapDone fires when the load for the active request completes. On success the
|
||||||
|
// request is served; on failure its caller receives the error and the queue
|
||||||
|
// advances. A SwapDone that does not match the active load (e.g. its request was
|
||||||
|
// unloaded or cancelled mid-load) is ignored.
|
||||||
|
func (s *Serial) OnSwapDone(ev SwapDone) {
|
||||||
|
if s.active == nil || s.phase != phaseSwapping || s.active.Model != ev.ModelID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ev.Err != nil {
|
||||||
|
s.logger.Debugf("%s: swap for model %s failed: %v", s.name, ev.ModelID, ev.Err)
|
||||||
|
s.effects.GrantError(*s.active, ev.Err)
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
s.startNext()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !s.serve() {
|
||||||
|
s.startNext() // caller vanished while the model loaded; move on
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnServeDone fires when the active request's handler returns. The slot is freed
|
||||||
|
// and the next queued request begins.
|
||||||
|
func (s *Serial) OnServeDone(ev ServeDoneEvent) {
|
||||||
|
if s.active == nil || s.phase != phaseServing {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
s.startNext()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnCancel removes a disconnected client's request from the queue. A request that
|
||||||
|
// is already active is left to finish: if it was loading, OnSwapDone's serve()
|
||||||
|
// will find the caller gone (GrantServe false) and advance; if it was serving,
|
||||||
|
// its handler returns normally and reaches OnServeDone.
|
||||||
|
func (s *Serial) OnCancel(req HandlerReq) {
|
||||||
|
if len(s.queued) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
kept := s.queued[:0]
|
||||||
|
removed := false
|
||||||
|
for _, q := range s.queued {
|
||||||
|
if q.Respond == req.Respond {
|
||||||
|
removed = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, q)
|
||||||
|
}
|
||||||
|
s.queued = kept
|
||||||
|
if removed {
|
||||||
|
s.logger.Debugf("%s: cancelled request for model %s pruned from queue", s.name, req.Model)
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnUnload reconciles state for an unload, stops the targeted processes, and
|
||||||
|
// advances the queue. It mirrors the FIFO contract: queued requests for unloaded
|
||||||
|
// models are failed; an active *loading* request for an unloaded model is failed
|
||||||
|
// (its swap goroutine is left to finish and its SwapDone is then ignored); an
|
||||||
|
// active *serving* request is left for its handler to end when StopProcesses
|
||||||
|
// kills the upstream. The Stop is synchronous so callers of Unload can rely on
|
||||||
|
// the processes being stopped on return.
|
||||||
|
func (s *Serial) OnUnload(targets []string, timeout time.Duration) {
|
||||||
|
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
|
||||||
|
|
||||||
|
targetSet := make(map[string]bool, len(targets))
|
||||||
|
for _, id := range targets {
|
||||||
|
targetSet[id] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.active != nil && s.phase == phaseSwapping && targetSet[s.active.Model] {
|
||||||
|
s.effects.GrantError(*s.active, unloadErr)
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.queued) > 0 {
|
||||||
|
kept := s.queued[:0]
|
||||||
|
for _, q := range s.queued {
|
||||||
|
if targetSet[q.Model] {
|
||||||
|
s.effects.GrantError(q, unloadErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, q)
|
||||||
|
}
|
||||||
|
s.queued = kept
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.effects.StopProcesses(timeout, targets)
|
||||||
|
|
||||||
|
// A still-serving active request advances via OnServeDone when its killed
|
||||||
|
// handler returns; only start the next job when nothing is active now.
|
||||||
|
if s.active == nil {
|
||||||
|
s.startNext()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnShutdown grants err to every request the scheduler still holds: an active
|
||||||
|
// loading request and all queued requests. A serving request is torn down with
|
||||||
|
// its process by the baseRouter.
|
||||||
|
func (s *Serial) OnShutdown(err error) {
|
||||||
|
if s.active != nil && s.phase == phaseSwapping {
|
||||||
|
s.effects.GrantError(*s.active, err)
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
}
|
||||||
|
for _, q := range s.queued {
|
||||||
|
s.effects.GrantError(q, err)
|
||||||
|
}
|
||||||
|
s.queued = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// otherRunning returns every running model except target, sorted for
|
||||||
|
// deterministic eviction.
|
||||||
|
func (s *Serial) otherRunning(target string) []string {
|
||||||
|
var out []string
|
||||||
|
for id := range s.effects.RunningModels() {
|
||||||
|
if id != target {
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(out)
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -0,0 +1,391 @@
|
|||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Serial methods all run on the router's single run-loop goroutine, so these
|
||||||
|
// tests drive them directly and synchronously, reusing fakeEffects and the
|
||||||
|
// req/reqCh helpers from fifo_test.go. A load completes via OnSwapDone and a
|
||||||
|
// served request finishes via OnServeDone — the events the run loop delivers.
|
||||||
|
|
||||||
|
func newSerial(eff Effects) *Serial {
|
||||||
|
return NewSerial("test", logmon.NewWriter(io.Discard), eff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lastStart returns the most recent StartSwap record.
|
||||||
|
func lastStart(t *testing.T, eff *fakeEffects) startRec {
|
||||||
|
t.Helper()
|
||||||
|
if len(eff.starts) == 0 {
|
||||||
|
t.Fatal("no StartSwap recorded")
|
||||||
|
}
|
||||||
|
return eff.starts[len(eff.starts)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// sameSet reports whether a and b hold the same strings with the same
// multiplicities, ignoring order.
func sameSet(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}

	// Count a's elements, then consume them with b's. Because the lengths are
	// equal, a count dropping below zero is the only way the two can differ.
	remaining := make(map[string]int, len(a))
	for _, s := range a {
		remaining[s]++
	}
	for _, s := range b {
		remaining[s]--
		if remaining[s] < 0 {
			return false
		}
	}
	return true
}
|
||||||
|
|
||||||
|
// servedOrder returns the model IDs of every successful serve grant in order.
|
||||||
|
func servedOrder(eff *fakeEffects) []string {
|
||||||
|
var out []string
|
||||||
|
for _, g := range eff.grants {
|
||||||
|
if g.err == nil && g.serve {
|
||||||
|
out = append(out, g.model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_FastPath_AlreadyLoaded(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
|
||||||
|
if got := len(eff.starts); got != 0 {
|
||||||
|
t.Errorf("StartSwap calls=%d want 0 (already loaded, no swap)", got)
|
||||||
|
}
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Errorf("served(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_ColdStart_LoadsThenServes(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
if got := eff.startsFor("a"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := eff.served("a"); got != 0 {
|
||||||
|
t.Errorf("served(a)=%d want 0 before load completes", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Errorf("served(a)=%d want 1 after load", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_UnknownModel(t *testing.T) {
|
||||||
|
eff := newFakeEffects() // no states => unknown
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("ghost"))
|
||||||
|
|
||||||
|
if len(eff.starts) != 0 {
|
||||||
|
t.Errorf("StartSwap calls=%d want 0", len(eff.starts))
|
||||||
|
}
|
||||||
|
if eff.errored("ghost") != 1 {
|
||||||
|
t.Fatalf("errored(ghost)=%d want 1", eff.errored("ghost"))
|
||||||
|
}
|
||||||
|
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
|
||||||
|
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_EvictsEveryOtherModel(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["x"] = process.StateReady // already running
|
||||||
|
eff.states["y"] = process.StateReady // also running (e.g. left over)
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
|
||||||
|
st := lastStart(t, eff)
|
||||||
|
if st.model != "a" {
|
||||||
|
t.Fatalf("loading %s want a", st.model)
|
||||||
|
}
|
||||||
|
if !sameSet(st.evict, []string{"x", "y"}) {
|
||||||
|
t.Errorf("evict=%v want [x y] (serial evicts ALL other models)", st.evict)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_OneJobAtATime verifies a second request waits while the first is
|
||||||
|
// serving, and only starts after the first finishes.
|
||||||
|
func TestSerial_OneJobAtATime(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // served immediately
|
||||||
|
s.OnRequest(req("b")) // must wait — a is serving
|
||||||
|
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 0 while a is serving", got)
|
||||||
|
}
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Fatalf("served(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// a finishes -> b may now load (evicting a).
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 after a finished", got)
|
||||||
|
}
|
||||||
|
if st := lastStart(t, eff); !sameSet(st.evict, []string{"a"}) {
|
||||||
|
t.Errorf("b evict=%v want [a]", st.evict)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_SameModelConsecutive_NoReload verifies back-to-back requests for the
|
||||||
|
// already-loaded model run without a reload, one after another.
|
||||||
|
func TestSerial_SameModelConsecutive_NoReload(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // cold load
|
||||||
|
s.OnRequest(req("a")) // queued behind the first
|
||||||
|
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"}) // first serves
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Fatalf("served(a)=%d want 1 (one at a time)", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // first done -> second serves
|
||||||
|
if got := eff.served("a"); got != 2 {
|
||||||
|
t.Fatalf("served(a)=%d want 2", got)
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("a"); got != 1 {
|
||||||
|
t.Errorf("StartSwap(a)=%d want 1 (second request must not reload)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_StrictArrivalOrder is the core guarantee: qwen36, qwen35, sdxl,
|
||||||
|
// qwen36 execute in EXACTLY that order with evictions between each model switch,
|
||||||
|
// including reloading qwen36 at the end even though it ran first.
|
||||||
|
func TestSerial_StrictArrivalOrder(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
for _, m := range []string{"qwen36", "qwen35", "sdxl"} {
|
||||||
|
eff.states[m] = process.StateStopped
|
||||||
|
}
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
for _, m := range []string{"qwen36", "qwen35", "sdxl", "qwen36"} {
|
||||||
|
s.OnRequest(req(m))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only the first job starts loading; the rest wait their turn.
|
||||||
|
if len(eff.starts) != 1 || eff.starts[0].model != "qwen36" {
|
||||||
|
t.Fatalf("starts=%+v want only [qwen36] loading first", eff.starts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// step completes the current model's load+serve and returns control to the
|
||||||
|
// scheduler, which must start the next queued model.
|
||||||
|
step := func(model string, wantEvict []string) {
|
||||||
|
t.Helper()
|
||||||
|
st := lastStart(t, eff)
|
||||||
|
if st.model != model {
|
||||||
|
t.Fatalf("loading %q want %q", st.model, model)
|
||||||
|
}
|
||||||
|
if !sameSet(st.evict, wantEvict) {
|
||||||
|
t.Fatalf("loading %q evict=%v want %v", model, st.evict, wantEvict)
|
||||||
|
}
|
||||||
|
// Simulate the eviction + load actually happening.
|
||||||
|
for _, e := range st.evict {
|
||||||
|
eff.states[e] = process.StateStopped
|
||||||
|
}
|
||||||
|
eff.states[model] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: model})
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: model})
|
||||||
|
}
|
||||||
|
|
||||||
|
step("qwen36", nil) // cold load, nothing else running
|
||||||
|
step("qwen35", []string{"qwen36"}) // evict qwen36
|
||||||
|
step("sdxl", []string{"qwen35"}) // evict qwen35
|
||||||
|
step("qwen36", []string{"sdxl"}) // RELOAD qwen36, evict sdxl
|
||||||
|
|
||||||
|
want := []string{"qwen36", "qwen35", "sdxl", "qwen36"}
|
||||||
|
if got := servedOrder(eff); !sameOrder(got, want) {
|
||||||
|
t.Fatalf("serve order=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sameOrder reports whether a and b contain the same strings in the same
// positions. Slices of different length are never equal; two empty slices
// (nil or not) are equal.
func sameOrder(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, item := range a {
		if item != b[idx] {
			return false
		}
	}
	return true
}
|
||||||
|
|
||||||
|
func TestSerial_SwapError_FailsCallerAndAdvances(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("b")) // queued behind a
|
||||||
|
|
||||||
|
// a's load fails: its caller is errored and b proceeds.
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
|
||||||
|
if eff.errored("a") != 1 {
|
||||||
|
t.Fatalf("errored(a)=%d want 1", eff.errored("a"))
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 after a's load failed", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_GrantServeFalse_Advances verifies that when the active request's
|
||||||
|
// caller has disconnected by serve time, the queue advances to the next request.
|
||||||
|
func TestSerial_GrantServeFalse_Advances(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
eff.serveResult["a"] = false // a's caller is gone by grant time
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("b")) // queued
|
||||||
|
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails -> advance to b
|
||||||
|
|
||||||
|
if got := eff.served("a"); got != 0 {
|
||||||
|
t.Errorf("served(a)=%d want 0 (caller gone)", got)
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 (advanced after gone caller)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_OnCancel_QueuedRequest(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(reqCh("a")) // starts loading a
|
||||||
|
cancelled := reqCh("b")
|
||||||
|
s.OnRequest(cancelled) // queued behind a
|
||||||
|
if len(s.queued) != 1 {
|
||||||
|
t.Fatalf("queued=%d want 1", len(s.queued))
|
||||||
|
}
|
||||||
|
|
||||||
|
s.OnCancel(cancelled)
|
||||||
|
if len(s.queued) != 0 {
|
||||||
|
t.Fatalf("queued=%d want 0 after cancel", len(s.queued))
|
||||||
|
}
|
||||||
|
|
||||||
|
// a completes; b is gone, so nothing starts for it.
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Errorf("StartSwap(b)=%d want 0 (cancelled before its turn)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_OnShutdown_FailsQueuedAndActiveLoad(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
eff.states["c"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // active (loading)
|
||||||
|
s.OnRequest(req("b")) // queued
|
||||||
|
s.OnRequest(req("c")) // queued
|
||||||
|
|
||||||
|
s.OnShutdown(errors.New("shutting down"))
|
||||||
|
|
||||||
|
if got := eff.errored(""); got != 3 {
|
||||||
|
t.Errorf("error grants=%d want 3 (active load + 2 queued)", got)
|
||||||
|
}
|
||||||
|
if len(s.queued) != 0 {
|
||||||
|
t.Errorf("queued=%d want 0 after shutdown", len(s.queued))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_OnUnload_WhileServing verifies that unloading the model that is
|
||||||
|
// actively serving does not strand the queue: OnUnload stops the process but
|
||||||
|
// leaves the active request to end via OnServeDone, which then advances.
|
||||||
|
func TestSerial_OnUnload_WhileServing(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // served immediately (a ready)
|
||||||
|
s.OnRequest(req("b")) // queued behind a
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Fatalf("served(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unload a while it is serving: the process is stopped, but the queue must
|
||||||
|
// not advance yet — the active serve is still outstanding.
|
||||||
|
s.OnUnload([]string{"a"}, time.Second)
|
||||||
|
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||||
|
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 0 before the serving request ends", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The killed handler returns -> OnServeDone advances to b.
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 after the serving request ended", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_OnUnload_DropsQueuedAndStops(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // active (loading a)
|
||||||
|
s.OnRequest(req("b")) // queued
|
||||||
|
|
||||||
|
// Unload a: its active load is failed and a is stopped.
|
||||||
|
s.OnUnload([]string{"a"}, time.Second)
|
||||||
|
|
||||||
|
if eff.errored("a") != 1 {
|
||||||
|
t.Errorf("errored(a)=%d want 1 (active load failed)", eff.errored("a"))
|
||||||
|
}
|
||||||
|
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||||
|
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||||
|
}
|
||||||
|
// b was queued and not unloaded; with a's load cancelled it now starts.
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Errorf("StartSwap(b)=%d want 1 after unload advanced the queue", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -9,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/event"
|
"github.com/mostlygeek/llama-swap/internal/event"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -340,6 +342,28 @@ func (s *Server) handleUpstream(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Pin the resolved model so the router skips body/query extraction.
|
// Pin the resolved model so the router skips body/query extraction.
|
||||||
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
|
*r = *r.WithContext(shared.SetContext(r.Context(), shared.ReqContextData{Model: searchName, ModelID: modelID, Metadata: make(map[string]string)}))
|
||||||
|
|
||||||
|
// If the path matches an upstream.ignorePaths entry and the model is
|
||||||
|
// not already loaded, refuse the request without triggering a swap. The
|
||||||
|
// server was not able to process the response because the model was not
|
||||||
|
// already loaded.
|
||||||
|
for _, re := range s.cfg.Upstream.IgnorePaths {
|
||||||
|
if !re.MatchString(remainingPath) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s.local.Handles(modelID) {
|
||||||
|
state, ok := s.local.RunningModels()[modelID]
|
||||||
|
if !ok || state != process.StateReady {
|
||||||
|
shared.SendResponse(w, r, http.StatusConflict,
|
||||||
|
fmt.Sprintf("model %s is not loaded; path matches upstream.ignorePaths", modelID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Either the model is already loaded (no swap would be triggered)
|
||||||
|
// or this is a peer model (peer proxying never swaps). Fall through
|
||||||
|
// to normal dispatch.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case s.local.Handles(modelID):
|
case s.local.Handles(modelID):
|
||||||
s.local.ServeHTTP(w, r)
|
s.local.ServeHTTP(w, r)
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mostlygeek/llama-swap/internal/config"
|
"github.com/mostlygeek/llama-swap/internal/config"
|
||||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
"github.com/mostlygeek/llama-swap/internal/shared"
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -156,6 +158,91 @@ func upstreamMetricsServer(response string) *Server {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_HandleUpstream_IgnorePaths(t *testing.T) {
|
||||||
|
// Compile a pattern that matches static asset suffixes.
|
||||||
|
pattern := regexp.MustCompile(`.*\.(js|json|css|png|gif|jpg|jpeg|txt)$`)
|
||||||
|
|
||||||
|
t.Run("matched path, model not loaded, returns 409", func(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
// running is nil/empty: model is not in RunningModels() => not loaded.
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusConflict {
|
||||||
|
t.Fatalf("status = %d, want %d (body=%q)", w.Code, http.StatusConflict, w.Body.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "not loaded") {
|
||||||
|
t.Errorf("body = %q, want it to contain 'not loaded'", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("matched path, model already loaded, serves normally", func(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
local.running = map[string]process.ProcessState{"m1": process.StateReady}
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||||
|
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-matched path, model not loaded, serves normally", func(t *testing.T) {
|
||||||
|
local := newStubRouter([]string{"m1"}, "upstream-body")
|
||||||
|
s := newTestServer(local, newStubRouter(nil, ""))
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/v1/chat/completions", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "upstream-body" {
|
||||||
|
t.Fatalf("status=%d body=%q, want 200 'upstream-body'", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("matched path, peer model, serves normally", func(t *testing.T) {
|
||||||
|
// Peer routers do not appear via RunningModels on the local router;
|
||||||
|
// they should fall through to normal dispatch without 409.
|
||||||
|
local := newStubRouter(nil, "")
|
||||||
|
peer := newStubRouter([]string{"m1"}, "peer-body")
|
||||||
|
s := newTestServer(local, peer)
|
||||||
|
s.cfg = config.Config{
|
||||||
|
Models: map[string]config.ModelConfig{"m1": {}},
|
||||||
|
Upstream: config.UpstreamConfig{
|
||||||
|
IgnorePaths: []*regexp.Regexp{pattern},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/upstream/m1/foo.js", nil))
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK || w.Body.String() != "peer-body" {
|
||||||
|
t.Fatalf("status=%d body=%q, want 200 'peer-body'", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) {
|
func TestServer_HandleUpstream_MetricsRecordsSupportedPath(t *testing.T) {
|
||||||
resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}`
|
resp := `{"usage":{"prompt_tokens":3,"completion_tokens":5}}`
|
||||||
s := upstreamMetricsServer(resp)
|
s := upstreamMetricsServer(resp)
|
||||||
|
|||||||
@@ -105,7 +105,9 @@ func (s *Server) handleAPIMetrics(w http.ResponseWriter, r *http.Request) {
|
|||||||
// filtered to samples after the ?after=<RFC3339> timestamp.
|
// filtered to samples after the ?after=<RFC3339> timestamp.
|
||||||
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
||||||
if s.perf == nil {
|
if s.perf == nil {
|
||||||
shared.SendResponse(w, r, http.StatusServiceUnavailable, "performance monitor not available")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
json.NewEncoder(w).Encode(map[string]bool{"enabled": false})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,6 +138,7 @@ func (s *Server) handleAPIPerformance(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"enabled": true,
|
||||||
"sys_stats": sysStats,
|
"sys_stats": sysStats,
|
||||||
"gpu_stats": gpuStats,
|
"gpu_stats": gpuStats,
|
||||||
})
|
})
|
||||||
|
|||||||
+115
-24
@@ -25,6 +25,8 @@ import (
|
|||||||
// TokenMetrics holds token usage and performance metrics.
|
// TokenMetrics holds token usage and performance metrics.
|
||||||
type TokenMetrics struct {
|
type TokenMetrics struct {
|
||||||
CachedTokens int `json:"cache_tokens"`
|
CachedTokens int `json:"cache_tokens"`
|
||||||
|
DraftTokens int `json:"draft_tokens"`
|
||||||
|
DraftAccTokens int `json:"draft_acc_tokens"`
|
||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
PromptPerSecond float64 `json:"prompt_per_second"`
|
||||||
@@ -42,6 +44,7 @@ type ActivityLogEntry struct {
|
|||||||
Tokens TokenMetrics `json:"tokens"`
|
Tokens TokenMetrics `json:"tokens"`
|
||||||
DurationMs int `json:"duration_ms"`
|
DurationMs int `json:"duration_ms"`
|
||||||
HasCapture bool `json:"has_capture"`
|
HasCapture bool `json:"has_capture"`
|
||||||
|
ErrorMsg string `json:"error_msg,omitempty"`
|
||||||
Metadata map[string]string `json:"metadata,omitempty"`
|
Metadata map[string]string `json:"metadata,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,9 +126,11 @@ func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// record parses a completed response body and stores/emits an activity entry.
|
// record parses a completed response body and stores/emits an activity entry.
|
||||||
// When captures are enabled, a zstd+CBOR capture is stored for successful
|
// Successful requests store a zstd+CBOR capture (when enabled) with cf
|
||||||
// requests, with cf controlling which request/response parts are retained.
|
// controlling which parts are retained. Failed (non-200) requests capture the
|
||||||
// reqBody and reqHeaders are the request data buffered before dispatch.
|
// request only and set ErrorMsg to a description of the failure, so the error
|
||||||
|
// can be inspected without storing unreadable raw response bytes. reqBody and
|
||||||
|
// reqHeaders are the request data buffered before dispatch.
|
||||||
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string) {
|
||||||
tm := ActivityLogEntry{
|
tm := ActivityLogEntry{
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
@@ -150,7 +155,13 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
|
|
||||||
if recorder.Status() != http.StatusOK {
|
if recorder.Status() != http.StatusOK {
|
||||||
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
mp.logger.Warnf("non-200 response, recording partial metrics: status=%d, path=%s", recorder.Status(), r.URL.Path)
|
||||||
queueAndEmit()
|
decoded, decErr := mp.decodeResponseBody(recorder, r.URL.Path)
|
||||||
|
tm.ErrorMsg = failedErrorMessage(recorder.Status(), decoded, decErr)
|
||||||
|
tm.ID = mp.queueMetrics(tm)
|
||||||
|
// Capture the request only; the failure is surfaced via ErrorMsg
|
||||||
|
// rather than storing the (possibly undisplayable) response body.
|
||||||
|
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf&^captureRespBody, reqBody, reqHeaders, nil)
|
||||||
|
mp.emitMetric(tm)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +176,7 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
decoded, err := decompressBody(body, encoding)
|
decoded, err := decompressBody(body, encoding)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
mp.logger.Warnf("metrics: decompression failed: %v, path=%s, recording minimal metrics", err, r.URL.Path)
|
||||||
|
tm.ErrorMsg = fmt.Sprintf("response decompression failed: %v", err)
|
||||||
queueAndEmit()
|
queueAndEmit()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -203,28 +215,99 @@ func (mp *metricsMonitor) record(modelID string, r *http.Request, recorder *resp
|
|||||||
}
|
}
|
||||||
|
|
||||||
tm.ID = mp.queueMetrics(tm)
|
tm.ID = mp.queueMetrics(tm)
|
||||||
if mp.enableCaptures {
|
tm.HasCapture = mp.storeCapture(tm.ID, r, recorder, cf, reqBody, reqHeaders, body)
|
||||||
capture := ReqRespCapture{
|
mp.emitMetric(tm)
|
||||||
ID: tm.ID,
|
}
|
||||||
ReqPath: r.URL.Path,
|
|
||||||
ReqHeaders: reqHeaders,
|
// storeCapture assembles a ReqRespCapture for id, honoring the captureFields
|
||||||
}
|
// mask, and stores it when captures are enabled. body is the response body to
|
||||||
if cf&captureReqBody != 0 {
|
// capture (already decompressed by the caller); pass nil to omit it. Returns
|
||||||
capture.ReqBody = reqBody
|
// true if a capture was stored.
|
||||||
}
|
func (mp *metricsMonitor) storeCapture(id int, r *http.Request, recorder *responseBodyCopier, cf captureFields, reqBody []byte, reqHeaders map[string]string, body []byte) bool {
|
||||||
if cf&captureRespHeaders != 0 {
|
if !mp.enableCaptures {
|
||||||
capture.RespHeaders = headerMap(recorder.Header())
|
return false
|
||||||
redactHeaders(capture.RespHeaders)
|
}
|
||||||
delete(capture.RespHeaders, "Content-Encoding")
|
capture := ReqRespCapture{
|
||||||
}
|
ID: id,
|
||||||
if cf&captureRespBody != 0 {
|
ReqPath: r.URL.Path,
|
||||||
capture.RespBody = body
|
ReqHeaders: reqHeaders,
|
||||||
}
|
}
|
||||||
if mp.addCapture(capture) {
|
if cf&captureReqBody != 0 {
|
||||||
tm.HasCapture = true
|
capture.ReqBody = reqBody
|
||||||
|
}
|
||||||
|
if cf&captureRespHeaders != 0 {
|
||||||
|
capture.RespHeaders = headerMap(recorder.Header())
|
||||||
|
redactHeaders(capture.RespHeaders)
|
||||||
|
delete(capture.RespHeaders, "Content-Encoding")
|
||||||
|
}
|
||||||
|
if cf&captureRespBody != 0 {
|
||||||
|
capture.RespBody = body
|
||||||
|
}
|
||||||
|
return mp.addCapture(capture)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeResponseBody returns the buffered response body, decompressing it when
|
||||||
|
// the upstream set a Content-Encoding we recognize. On decompression failure it
|
||||||
|
// logs a warning and returns an error so the caller can record a description
|
||||||
|
// (via ErrorMsg) instead of storing unreadable raw bytes.
|
||||||
|
func (mp *metricsMonitor) decodeResponseBody(recorder *responseBodyCopier, path string) ([]byte, error) {
|
||||||
|
body := recorder.body.Bytes()
|
||||||
|
if len(body) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
encoding := recorder.Header().Get("Content-Encoding")
|
||||||
|
if encoding == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
decoded, err := decompressBody(body, encoding)
|
||||||
|
if err != nil {
|
||||||
|
mp.logger.Warnf("metrics: response decompression failed: %v, path=%s", err, path)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorMessagePaths lists JSON paths where a human-readable error message can
|
||||||
|
// live across OpenAI- and llama.cpp-style error responses.
|
||||||
|
var errorMessagePaths = []string{"error.message", "error", "message", "detail"}
|
||||||
|
|
||||||
|
// extractErrorMessage pulls a human-readable error string from a JSON error
|
||||||
|
// response. Returns "" if no message is found or the body is not valid JSON.
|
||||||
|
func extractErrorMessage(body []byte) string {
|
||||||
|
if !gjson.ValidBytes(body) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parsed := gjson.ParseBytes(body)
|
||||||
|
for _, path := range errorMessagePaths {
|
||||||
|
v := parsed.Get(path)
|
||||||
|
if v.Exists() && v.Type == gjson.String {
|
||||||
|
if s := strings.TrimSpace(v.String()); s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mp.emitMetric(tm)
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// failedErrorMessage builds a human-readable description for a non-200 response.
|
||||||
|
// It prefers an error message parsed from the (decompressed) body and falls back
|
||||||
|
// to the HTTP status text. A non-nil decErr indicates the body could not be
|
||||||
|
// decoded, in which case the decode error is described instead.
|
||||||
|
func failedErrorMessage(status int, body []byte, decErr error) string {
|
||||||
|
const maxLen = 500
|
||||||
|
if decErr != nil {
|
||||||
|
return fmt.Sprintf("response decode failed: %v", decErr)
|
||||||
|
}
|
||||||
|
if msg := extractErrorMessage(body); msg != "" {
|
||||||
|
if len(msg) > maxLen {
|
||||||
|
msg = msg[:maxLen] + "..."
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
if text := http.StatusText(status); text != "" {
|
||||||
|
return fmt.Sprintf("%d %s", status, text)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("HTTP %d", status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// usagePaths lists the JSON paths where a per-event usage object can live.
|
// usagePaths lists the JSON paths where a per-event usage object can live.
|
||||||
@@ -345,6 +428,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
|||||||
durationMs := wallDurationMs
|
durationMs := wallDurationMs
|
||||||
tokensPerSecond := -1.0
|
tokensPerSecond := -1.0
|
||||||
promptPerSecond := -1.0
|
promptPerSecond := -1.0
|
||||||
|
draftTokens := -1
|
||||||
|
draftAccTokens := -1
|
||||||
|
|
||||||
if timings.Exists() {
|
if timings.Exists() {
|
||||||
inputTokens = timings.Get("prompt_n").Int()
|
inputTokens = timings.Get("prompt_n").Int()
|
||||||
@@ -358,6 +443,10 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
|||||||
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
|
||||||
cachedTokens = cachedValue.Int()
|
cachedTokens = cachedValue.Int()
|
||||||
}
|
}
|
||||||
|
if timings.Get("draft_n").Exists() && timings.Get("draft_n_accepted").Exists() {
|
||||||
|
draftTokens = int(timings.Get("draft_n").Int())
|
||||||
|
draftAccTokens = int(timings.Get("draft_n_accepted").Int())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ActivityLogEntry{
|
return ActivityLogEntry{
|
||||||
@@ -365,6 +454,8 @@ func buildMetrics(modelID string, start time.Time, inputTokens, outputTokens, ca
|
|||||||
Model: modelID,
|
Model: modelID,
|
||||||
Tokens: TokenMetrics{
|
Tokens: TokenMetrics{
|
||||||
CachedTokens: int(cachedTokens),
|
CachedTokens: int(cachedTokens),
|
||||||
|
DraftTokens: draftTokens,
|
||||||
|
DraftAccTokens: draftAccTokens,
|
||||||
InputTokens: int(inputTokens),
|
InputTokens: int(inputTokens),
|
||||||
OutputTokens: int(outputTokens),
|
OutputTokens: int(outputTokens),
|
||||||
PromptPerSecond: promptPerSecond,
|
PromptPerSecond: promptPerSecond,
|
||||||
|
|||||||
@@ -90,6 +90,172 @@ func TestMetricsMonitor_RecordMetadata(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordFailedRequestCapture(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
reqHeaders := map[string]string{"content-type": "application/json"}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.Header().Set("Content-Type", "application/json")
|
||||||
|
copier.WriteHeader(http.StatusBadGateway)
|
||||||
|
copier.Write([]byte(`{"error":{"message":"model unavailable"}}`))
|
||||||
|
|
||||||
|
reqBody := []byte(`{"model":"m","messages":[]}`)
|
||||||
|
mm.record("m", r, copier, captureAll, reqBody, reqHeaders)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
entry := entries[0]
|
||||||
|
if entry.RespStatusCode != http.StatusBadGateway {
|
||||||
|
t.Errorf("status = %d, want %d", entry.RespStatusCode, http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
if entry.ErrorMsg != "model unavailable" {
|
||||||
|
t.Errorf("error_msg = %q, want extracted message", entry.ErrorMsg)
|
||||||
|
}
|
||||||
|
if !entry.HasCapture {
|
||||||
|
t.Fatal("failed request should capture the request so it can be inspected")
|
||||||
|
}
|
||||||
|
|
||||||
|
got := mm.getCaptureByID(entry.ID)
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("capture not found")
|
||||||
|
}
|
||||||
|
if string(got.ReqBody) != `{"model":"m","messages":[]}` {
|
||||||
|
t.Errorf("req body = %q", got.ReqBody)
|
||||||
|
}
|
||||||
|
if len(got.RespBody) != 0 {
|
||||||
|
t.Errorf("resp body stored for failed request (len=%d); want none", len(got.RespBody))
|
||||||
|
}
|
||||||
|
if got.RespHeaders["Content-Type"] != "application/json" {
|
||||||
|
t.Errorf("resp Content-Type = %q", got.RespHeaders["Content-Type"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordFailedRequestStatusFallback(t *testing.T) {
|
||||||
|
// Non-JSON error body: ErrorMsg falls back to the HTTP status text.
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.WriteHeader(http.StatusBadGateway)
|
||||||
|
copier.Write([]byte("<html>upstream down</html>"))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, captureAll, nil, nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].ErrorMsg != "502 Bad Gateway" {
|
||||||
|
t.Errorf("error_msg = %q, want status text", entries[0].ErrorMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordFailedRequestCaptureDisabled(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 0) // captures disabled
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.WriteHeader(http.StatusInternalServerError)
|
||||||
|
copier.Write([]byte(`{"error":"boom"}`))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].HasCapture {
|
||||||
|
t.Fatal("captures disabled, HasCapture should be false")
|
||||||
|
}
|
||||||
|
// ErrorMsg is independent of whether captures are enabled.
|
||||||
|
if entries[0].ErrorMsg != "boom" {
|
||||||
|
t.Errorf("error_msg = %q, want boom", entries[0].ErrorMsg)
|
||||||
|
}
|
||||||
|
if mm.getCaptureByID(entries[0].ID) != nil {
|
||||||
|
t.Fatal("no capture should be stored when disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_RecordDecompressionFailureSetsErrorMsg(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.Header().Set("Content-Encoding", "gzip")
|
||||||
|
copier.WriteHeader(http.StatusOK)
|
||||||
|
copier.Write([]byte("not-really-gzip"))
|
||||||
|
|
||||||
|
mm.record("m", r, copier, captureAll, []byte("req"), nil)
|
||||||
|
|
||||||
|
entries := mm.getMetrics()
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("want 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].ErrorMsg == "" {
|
||||||
|
t.Fatal("expected ErrorMsg for decompression failure")
|
||||||
|
}
|
||||||
|
// Raw bytes must not be stored when the body could not be decoded.
|
||||||
|
if entries[0].HasCapture {
|
||||||
|
t.Fatal("decompression failure should not store a capture")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_DecodeResponseBody(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(logmon.NewWriter(io.Discard), 10, 5)
|
||||||
|
|
||||||
|
// No Content-Encoding: body returned unchanged.
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
copier := newBodyCopier(w)
|
||||||
|
copier.Write([]byte("plain"))
|
||||||
|
got, err := mm.decodeResponseBody(copier, "/p")
|
||||||
|
if err != nil || string(got) != "plain" {
|
||||||
|
t.Fatalf("plain body = %q, err = %v", got, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bogus gzip payload: returns an error and no body (no raw bytes kept).
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
copier2 := newBodyCopier(w2)
|
||||||
|
copier2.Header().Set("Content-Encoding", "gzip")
|
||||||
|
copier2.Write([]byte("not-really-gzip"))
|
||||||
|
got, err = mm.decodeResponseBody(copier2, "/p")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected decompression error")
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Errorf("expected nil body on failure, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_ExtractErrorMessage(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"openai object", `{"error":{"message":"rate limited"}}`, "rate limited"},
|
||||||
|
{"string error", `{"error":"bad request"}`, "bad request"},
|
||||||
|
{"message field", `{"message":"nope"}`, "nope"},
|
||||||
|
{"detail field", `{"detail":"oops"}`, "oops"},
|
||||||
|
{"object error ignored", `{"error":{"code":42}}`, ""},
|
||||||
|
{"no error", `{"usage":{}}`, ""},
|
||||||
|
{"invalid json", `not-json`, ""},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := extractErrorMessage([]byte(tc.body)); got != tc.want {
|
||||||
|
t.Errorf("extractErrorMessage = %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
func TestServer_ParseMetrics_Infill(t *testing.T) {
|
||||||
// /infill responses are arrays; timings live in the last element.
|
// /infill responses are arrays; timings live in the last element.
|
||||||
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
body := `[{"content":"a"},{"content":"b","timings":{"prompt_n":5,"predicted_n":9,"prompt_ms":10,"predicted_ms":20}}]`
|
||||||
|
|||||||
@@ -0,0 +1,137 @@
|
|||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DirWatcher watches a directory by polling and reports changes to the set of
// *.yml / *.yaml files inside it: a file being added or removed, or a file's
// modification time or size changing. Polling (rather than inotify) keeps it
// usable on Docker bind-mounts and k8s ConfigMap projections, the same
// rationale as the single-file Watcher.
//
// The first scan only establishes the baseline; it never invokes OnChange.
type DirWatcher struct {
	// Path is the directory that is scanned on every poll.
	Path string

	// Interval is the time between polls. Run substitutes DefaultInterval
	// when this is zero or negative.
	Interval time.Duration

	// OnChange is called when a change is detected. Run skips the call when
	// it is nil.
	OnChange func()
}
|
||||||
|
|
||||||
|
// dirSnapshot is an ordered map of file name -> file state. The ordering is
|
||||||
|
// derived from sorted filenames so two snapshots compare deterministically
|
||||||
|
// regardless of readdir order. exists reflects whether the directory was
|
||||||
|
// readable at scan time; a missing directory yields exists=false.
|
||||||
|
type dirSnapshot struct {
|
||||||
|
exists bool
|
||||||
|
names []string
|
||||||
|
states map[string]snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDirSnapshot() dirSnapshot {
|
||||||
|
return dirSnapshot{states: make(map[string]snapshot)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal reports whether two snapshots describe the same file set and per-file
|
||||||
|
// state. A missing directory (exists=false) is treated as equal to any other
|
||||||
|
// missing directory regardless of cached names.
|
||||||
|
func (s dirSnapshot) equal(other dirSnapshot) bool {
|
||||||
|
if !s.exists && !other.exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if s.exists != other.exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(s.names) != len(other.names) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, n := range s.names {
|
||||||
|
if other.names[i] != n {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, n := range s.names {
|
||||||
|
a, b := s.states[n], other.states[n]
|
||||||
|
if a.exists != b.exists || a.size != b.size || !a.modTime.Equal(b.modTime) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||||
|
// OnChange whenever the directory's YAML file set changes.
|
||||||
|
//
|
||||||
|
// Policy mirrors the single-file Watcher: disappearance (directory missing or
|
||||||
|
// empty) is treated as a transient rename-style write and stays quiet; the
|
||||||
|
// transition back to present-with-content fires OnChange.
|
||||||
|
func (w *DirWatcher) Run(ctx context.Context) {
|
||||||
|
interval := w.Interval
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = DefaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
prev := scanDir(w.Path)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
cur := scanDir(w.Path)
|
||||||
|
// Suppress transitions involving an empty or missing directory —
|
||||||
|
// these are treated as transient rename-style writes, mirroring
|
||||||
|
// the single-file Watcher. Only present-with-content →
|
||||||
|
// present-with-content (changed) or no-content →
|
||||||
|
// present-with-content fires OnChange.
|
||||||
|
prevHasContent := prev.exists && len(prev.names) > 0
|
||||||
|
curHasContent := cur.exists && len(cur.names) > 0
|
||||||
|
if curHasContent && (!prevHasContent || !prev.equal(cur)) && w.OnChange != nil {
|
||||||
|
w.OnChange()
|
||||||
|
}
|
||||||
|
prev = cur
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanDir returns a snapshot of the *.yml/*.yaml files in dir. If the
|
||||||
|
// directory cannot be read (missing, permission denied) the snapshot reports
|
||||||
|
// exists=false; the next successful scan will detect the recovery and fire
|
||||||
|
// OnChange.
|
||||||
|
func scanDir(dir string) dirSnapshot {
|
||||||
|
snap := newDirSnapshot()
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return snap // exists=false
|
||||||
|
}
|
||||||
|
snap.exists = true
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := e.Name()
|
||||||
|
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fi, err := os.Stat(filepath.Join(dir, name))
|
||||||
|
if err != nil {
|
||||||
|
// File disappeared between ReadDir and Stat; skip it — the
|
||||||
|
// next poll will observe the removal cleanly.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snap.names = append(snap.names, name)
|
||||||
|
snap.states[name] = snapshot{
|
||||||
|
exists: true,
|
||||||
|
modTime: fi.ModTime(),
|
||||||
|
size: fi.Size(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(snap.names)
|
||||||
|
return snap
|
||||||
|
}
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startDirWatcher launches w.Run in a goroutine and returns a function that
|
||||||
|
// cancels the context and waits for Run to return.
|
||||||
|
func startDirWatcher(t *testing.T, w *DirWatcher) func() {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.Run(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("DirWatcher did not stop within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeYAMLInDir(t *testing.T, dir, name, content string) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_NoFireOnBaseline(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 5)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_DetectsFileAdd(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is added")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_DetectsFileRemoval(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.Remove(filepath.Join(dir, "b.yaml")))
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_DetectsModTimeChange(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||||
|
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base, base))
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_IgnoresNonYAMLFiles(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Adding a .txt file must not fire.
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0o644))
|
||||||
|
time.Sleep(testInterval * 4)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "non-YAML files must be ignored")
|
||||||
|
|
||||||
|
// Adding a .yml file must fire.
|
||||||
|
writeYAMLInDir(t, dir, "b.yml", "b")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire for *.yml files")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_MissingDirRecovers(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Remove the directory. No fire expected on disappearance alone.
|
||||||
|
require.NoError(t, os.RemoveAll(dir))
|
||||||
|
time.Sleep(testInterval * 3)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "directory removal alone must not fire")
|
||||||
|
|
||||||
|
// Recreate the directory and a YAML file; the recovery should fire.
|
||||||
|
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||||
|
writeYAMLInDir(t, dir, "recovered.yaml", "r")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when dir returns with content")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_EmptyDirSuppressedThenRecovers(t *testing.T) {
|
||||||
|
// Present-with-content → empty (all YAML removed, dir still exists)
|
||||||
|
// must stay quiet — treated as transient per the documented policy.
|
||||||
|
// The transition back to content fires.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Remove the only YAML file. Dir still exists but is empty of YAML.
|
||||||
|
require.NoError(t, os.Remove(filepath.Join(dir, "a.yaml")))
|
||||||
|
time.Sleep(testInterval * 4)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "emptying the directory must not fire")
|
||||||
|
|
||||||
|
// Add a YAML file back; transition to present-with-content fires.
|
||||||
|
writeYAMLInDir(t, dir, "c.yaml", "c")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when content returns")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
w := &DirWatcher{Path: dir, Interval: testInterval}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() { w.Run(ctx); close(done) }()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Run did not return within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
+37
-19
@@ -55,7 +55,8 @@ var logTimeFormats = map[string]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flagConfig := flag.String("config", "", "path to config file (required)")
|
flagConfig := flag.String("config", "", "path to config file")
|
||||||
|
flagConfigDir := flag.String("config-dir", "", "directory of *.yml/*.yaml config files (additive to -config)")
|
||||||
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
|
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
|
||||||
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||||
flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
|
flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||||
@@ -68,8 +69,8 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
if *flagConfig == "" {
|
if *flagConfig == "" && *flagConfigDir == "" {
|
||||||
slog.Error("-config is required")
|
slog.Error("at least one of -config or -config-dir must be provided")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,10 +89,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
configPath := *flagConfig
|
cfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||||
cfg, err := config.LoadConfig(configPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to load config", "path", configPath, "error", err)
|
slog.Error("failed to load config", "config", *flagConfig, "config-dir", *flagConfigDir, "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,7 +187,7 @@ func main() {
|
|||||||
|
|
||||||
proxyLog.Info("reloading configuration")
|
proxyLog.Info("reloading configuration")
|
||||||
|
|
||||||
newCfg, err := config.LoadConfig(configPath)
|
newCfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
proxyLog.Warnf("failed to reload config: %v", err)
|
proxyLog.Warnf("failed to reload config: %v", err)
|
||||||
return
|
return
|
||||||
@@ -230,19 +230,37 @@ func main() {
|
|||||||
defer watcherCancel()
|
defer watcherCancel()
|
||||||
|
|
||||||
if *flagWatchConfig {
|
if *flagWatchConfig {
|
||||||
absConfigPath, err := filepath.Abs(configPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("watch-config: failed to resolve config path", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
||||||
go func() {
|
|
||||||
(&configwatcher.Watcher{
|
if *flagConfig != "" {
|
||||||
Path: absConfigPath,
|
absConfigPath, err := filepath.Abs(*flagConfig)
|
||||||
Interval: configwatcher.DefaultInterval,
|
if err != nil {
|
||||||
OnChange: reload,
|
slog.Error("watch-config: failed to resolve config path", "error", err)
|
||||||
}).Run(watcherCtx)
|
os.Exit(1)
|
||||||
}()
|
}
|
||||||
|
go func() {
|
||||||
|
(&configwatcher.Watcher{
|
||||||
|
Path: absConfigPath,
|
||||||
|
Interval: configwatcher.DefaultInterval,
|
||||||
|
OnChange: reload,
|
||||||
|
}).Run(watcherCtx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if *flagConfigDir != "" {
|
||||||
|
absConfigDir, err := filepath.Abs(*flagConfigDir)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("watch-config: failed to resolve config-dir path", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
(&configwatcher.DirWatcher{
|
||||||
|
Path: absConfigDir,
|
||||||
|
Interval: configwatcher.DefaultInterval,
|
||||||
|
OnChange: reload,
|
||||||
|
}).Run(watcherCtx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
import Performance from "./routes/Performance.svelte";
|
import Performance from "./routes/Performance.svelte";
|
||||||
import Playground from "./routes/Playground.svelte";
|
import Playground from "./routes/Playground.svelte";
|
||||||
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
import PlaygroundStub from "./routes/PlaygroundStub.svelte";
|
||||||
import { enableAPIEvents } from "./stores/api";
|
import { enableAPIEvents, checkPerformanceEnabled } from "./stores/api";
|
||||||
import { initScreenWidth, initSystemThemeListener, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
import { initScreenWidth, initSystemThemeListener, isDarkMode, appTitle, connectionState } from "./stores/theme";
|
||||||
import { currentRoute } from "./stores/route";
|
import { currentRoute } from "./stores/route";
|
||||||
|
|
||||||
@@ -39,6 +39,7 @@
|
|||||||
const cleanupScreenWidth = initScreenWidth();
|
const cleanupScreenWidth = initScreenWidth();
|
||||||
const cleanupSystemTheme = initSystemThemeListener();
|
const cleanupSystemTheme = initSystemThemeListener();
|
||||||
enableAPIEvents(true);
|
enableAPIEvents(true);
|
||||||
|
checkPerformanceEnabled();
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
cleanupScreenWidth();
|
cleanupScreenWidth();
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import { screenWidth, toggleTheme, themeMode, appTitle, isNarrow } from "../stores/theme";
|
import { screenWidth, toggleTheme, themeMode, appTitle, isNarrow } from "../stores/theme";
|
||||||
import { currentRoute } from "../stores/route";
|
import { currentRoute } from "../stores/route";
|
||||||
import { playgroundActivity } from "../stores/playgroundActivity";
|
import { playgroundActivity } from "../stores/playgroundActivity";
|
||||||
|
import { performanceEnabled } from "../stores/api";
|
||||||
import ConnectionStatus from "./ConnectionStatus.svelte";
|
import ConnectionStatus from "./ConnectionStatus.svelte";
|
||||||
|
|
||||||
function handleTitleChange(newTitle: string): void {
|
function handleTitleChange(newTitle: string): void {
|
||||||
@@ -84,16 +85,18 @@
|
|||||||
>
|
>
|
||||||
Logs
|
Logs
|
||||||
</a>
|
</a>
|
||||||
<a
|
{#if $performanceEnabled}
|
||||||
href="/performance"
|
<a
|
||||||
use:link
|
href="/performance"
|
||||||
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
use:link
|
||||||
class:font-semibold={isActive("/performance", $currentRoute)}
|
class="text-gray-600 hover:text-black dark:text-gray-300 dark:hover:text-gray-100 p-1 whitespace-nowrap"
|
||||||
class:underline={isActive("/performance", $currentRoute)}
|
class:font-semibold={isActive("/performance", $currentRoute)}
|
||||||
class:underline-offset-4={isActive("/performance", $currentRoute)}
|
class:underline={isActive("/performance", $currentRoute)}
|
||||||
>
|
class:underline-offset-4={isActive("/performance", $currentRoute)}
|
||||||
Performance
|
>
|
||||||
</a>
|
Performance
|
||||||
|
</a>
|
||||||
|
{/if}
|
||||||
<button onclick={toggleTheme} title="Toggle theme (current: {$themeMode})">
|
<button onclick={toggleTheme} title="Toggle theme (current: {$themeMode})">
|
||||||
{#if $themeMode === "system"}
|
{#if $themeMode === "system"}
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class="w-5 h-5">
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ export interface Model {
|
|||||||
|
|
||||||
export interface TokenMetrics {
|
export interface TokenMetrics {
|
||||||
cache_tokens: number;
|
cache_tokens: number;
|
||||||
|
draft_tokens: number;
|
||||||
|
draft_acc_tokens: number;
|
||||||
input_tokens: number;
|
input_tokens: number;
|
||||||
output_tokens: number;
|
output_tokens: number;
|
||||||
prompt_per_second: number;
|
prompt_per_second: number;
|
||||||
@@ -41,6 +43,7 @@ export interface ActivityLogEntry {
|
|||||||
tokens: TokenMetrics;
|
tokens: TokenMetrics;
|
||||||
duration_ms: number;
|
duration_ms: number;
|
||||||
has_capture: boolean;
|
has_capture: boolean;
|
||||||
|
error_msg?: string;
|
||||||
metadata?: Record<string, string>;
|
metadata?: Record<string, string>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,11 +21,12 @@
|
|||||||
{ key: "time", label: "Time", defaultVisible: true },
|
{ key: "time", label: "Time", defaultVisible: true },
|
||||||
{ key: "model", label: "Model", defaultVisible: true },
|
{ key: "model", label: "Model", defaultVisible: true },
|
||||||
{ key: "req_path", label: "Path", defaultVisible: false },
|
{ key: "req_path", label: "Path", defaultVisible: false },
|
||||||
{ key: "resp_status_code", label: "Status", defaultVisible: false },
|
{ key: "resp_status_code", label: "Status", defaultVisible: true },
|
||||||
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
|
{ key: "resp_content_type", label: "Content-Type", defaultVisible: false },
|
||||||
{ key: "cached", label: "Cached", defaultVisible: true },
|
{ key: "cached", label: "Cached", defaultVisible: true },
|
||||||
{ key: "prompt", label: "Prompt", defaultVisible: true },
|
{ key: "prompt", label: "Prompt", defaultVisible: true },
|
||||||
{ key: "generated", label: "Generated", defaultVisible: true },
|
{ key: "generated", label: "Generated", defaultVisible: true },
|
||||||
|
{ key: "drafted", label: "Drafted", defaultVisible: false },
|
||||||
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
|
{ key: "prompt_speed", label: "Prompt Speed", defaultVisible: true },
|
||||||
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
|
{ key: "gen_speed", label: "Gen Speed", defaultVisible: true },
|
||||||
{ key: "duration", label: "Duration", defaultVisible: true },
|
{ key: "duration", label: "Duration", defaultVisible: true },
|
||||||
@@ -158,6 +159,10 @@
|
|||||||
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
return speed < 0 ? "unknown" : speed.toFixed(2) + " t/s";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Formats speculative-decoding stats as an acceptance rate with raw counts,
 * e.g. `50.0% (5/10)`.
 *
 * @param drafted - number of drafted tokens
 * @param accepted - number of drafted tokens that were accepted
 * @returns the formatted rate, or `"-"` when `drafted` is not greater than zero
 */
function formatDrafted(drafted: number, accepted: number): string {
  // Anything that is not strictly positive (0, negatives, NaN) has no rate.
  if (!(drafted > 0)) {
    return "-";
  }
  const acceptancePct = ((accepted * 100) / drafted).toFixed(1);
  return `${acceptancePct}% (${accepted}/${drafted})`;
}
|
||||||
|
|
||||||
/**
 * Converts a duration in milliseconds to seconds with two decimals and an
 * `s` suffix, e.g. `1500` becomes `1.50s`.
 *
 * @param ms - duration in milliseconds
 * @returns the duration rendered in seconds
 */
function formatDuration(ms: number): string {
  const seconds = ms / 1000;
  return `${seconds.toFixed(2)}s`;
}
|
||||||
@@ -273,6 +278,8 @@
|
|||||||
Cached <Tooltip content="prompt tokens from cache" />
|
Cached <Tooltip content="prompt tokens from cache" />
|
||||||
{:else if key === "prompt"}
|
{:else if key === "prompt"}
|
||||||
Prompt <Tooltip content="new prompt tokens processed" />
|
Prompt <Tooltip content="new prompt tokens processed" />
|
||||||
|
{:else if key === "drafted"}
|
||||||
|
Drafted <Tooltip content="acceptance rate (accepted/drafted)" />
|
||||||
{:else}
|
{:else}
|
||||||
{columnLabelMap[key] ?? key}
|
{columnLabelMap[key] ?? key}
|
||||||
{/if}
|
{/if}
|
||||||
@@ -301,7 +308,13 @@
|
|||||||
{:else if key === "req_path"}
|
{:else if key === "req_path"}
|
||||||
{metric.req_path || "-"}
|
{metric.req_path || "-"}
|
||||||
{:else if key === "resp_status_code"}
|
{:else if key === "resp_status_code"}
|
||||||
{metric.resp_status_code || "-"}
|
{#if metric.error_msg}
|
||||||
|
<span class="text-red-500 dark:text-red-400 cursor-help" title={metric.error_msg}>
|
||||||
|
{metric.resp_status_code || "-"}
|
||||||
|
</span>
|
||||||
|
{:else}
|
||||||
|
{metric.resp_status_code || "-"}
|
||||||
|
{/if}
|
||||||
{:else if key === "resp_content_type"}
|
{:else if key === "resp_content_type"}
|
||||||
{metric.resp_content_type || "-"}
|
{metric.resp_content_type || "-"}
|
||||||
{:else if key === "cached"}
|
{:else if key === "cached"}
|
||||||
@@ -310,6 +323,8 @@
|
|||||||
{metric.tokens.input_tokens.toLocaleString()}
|
{metric.tokens.input_tokens.toLocaleString()}
|
||||||
{:else if key === "generated"}
|
{:else if key === "generated"}
|
||||||
{metric.tokens.output_tokens.toLocaleString()}
|
{metric.tokens.output_tokens.toLocaleString()}
|
||||||
|
{:else if key === "drafted"}
|
||||||
|
{formatDrafted(metric.tokens.draft_tokens, metric.tokens.draft_acc_tokens)}
|
||||||
{:else if key === "prompt_speed"}
|
{:else if key === "prompt_speed"}
|
||||||
{formatSpeed(metric.tokens.prompt_per_second)}
|
{formatSpeed(metric.tokens.prompt_per_second)}
|
||||||
{:else if key === "gen_speed"}
|
{:else if key === "gen_speed"}
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ export const proxyLogs = writable<string>("");
|
|||||||
export const upstreamLogs = writable<string>("");
|
export const upstreamLogs = writable<string>("");
|
||||||
export const metrics = writable<ActivityLogEntry[]>([]);
|
export const metrics = writable<ActivityLogEntry[]>([]);
|
||||||
export const inFlightRequests = writable<number>(0);
|
export const inFlightRequests = writable<number>(0);
|
||||||
|
export const performanceEnabled = writable<boolean>(false);
|
||||||
export const versionInfo = writable<VersionInfo>({
|
export const versionInfo = writable<VersionInfo>({
|
||||||
build_date: "unknown",
|
build_date: "unknown",
|
||||||
commit: "unknown",
|
commit: "unknown",
|
||||||
@@ -210,6 +211,20 @@ export async function getCapture(id: number): Promise<ReqRespCapture | null> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Query `/api/performance` and record the result in the `performanceEnabled`
 * store.
 *
 * The store is set to `false` when the response status is not OK, when the
 * request or JSON parsing throws, or when the payload has no truthy `enabled`
 * field. Errors are handled internally, so the returned promise does not
 * reject because of a failed request.
 */
export async function checkPerformanceEnabled(): Promise<void> {
  try {
    const response = await fetch("/api/performance");
    if (!response.ok) {
      performanceEnabled.set(false);
      return;
    }

    // The body is untyped JSON: treat it as possibly missing `enabled` and
    // coerce to a real boolean so the store never holds `undefined` or some
    // other non-boolean value.
    const data = (await response.json()) as { enabled?: unknown } | null;
    performanceEnabled.set(Boolean(data?.enabled));
  } catch {
    // Network failure or malformed JSON: report the feature as unavailable.
    performanceEnabled.set(false);
  }
}
|
||||||
|
|
||||||
export async function fetchPerformance(after?: string): Promise<PerformanceResponse | null> {
|
export async function fetchPerformance(after?: string): Promise<PerformanceResponse | null> {
|
||||||
try {
|
try {
|
||||||
const url = after ? `/api/performance?after=${encodeURIComponent(after)}` : "/api/performance";
|
const url = after ? `/api/performance?after=${encodeURIComponent(after)}` : "/api/performance";
|
||||||
|
|||||||
Reference in New Issue
Block a user