Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0292c90ca1 | |||
| 617c7dc6b9 | |||
| 542b79dacf | |||
| 0a25b3bd31 | |||
| 32bc781326 |
@@ -0,0 +1,76 @@
|
|||||||
|
name: Build CUDA image (fork)
|
||||||
|
|
||||||
|
# Builds this fork's llama-swap (serial scheduler + embedded UI) from source and
|
||||||
|
# layers it on a pinned llama.cpp CUDA server base, then pushes to the Gitea
|
||||||
|
# container registry, e.g. gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
|
||||||
|
#
|
||||||
|
# Requires repo secrets: REGISTRY_USER, REGISTRY_PASSWORD (push to the registry).
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
llama_swap_version:
|
||||||
|
description: "llama-swap version label (image tag prefix)"
|
||||||
|
required: false
|
||||||
|
default: "v230"
|
||||||
|
llamacpp_build:
|
||||||
|
description: "llama.cpp CUDA server build (base image tag suffix)"
|
||||||
|
required: false
|
||||||
|
default: "b9821"
|
||||||
|
# Building the build definition itself kicks off a fresh image.
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
paths:
|
||||||
|
- ".gitea/workflows/build-cuda-image.yml"
|
||||||
|
- "docker/fork-cuda.Containerfile"
|
||||||
|
|
||||||
|
env:
|
||||||
|
REGISTRY: gitea.stevedudenhoeffer.com
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Compute image metadata
|
||||||
|
id: meta
|
||||||
|
run: |
|
||||||
|
LS_VER="${{ inputs.llama_swap_version || 'v230' }}"
|
||||||
|
LCPP="${{ inputs.llamacpp_build || 'b9821' }}"
|
||||||
|
{
|
||||||
|
echo "image=${REGISTRY}/${{ github.repository }}"
|
||||||
|
echo "tag=${LS_VER}-cuda-${LCPP}"
|
||||||
|
echo "base_tag=server-cuda-${LCPP}"
|
||||||
|
echo "ls_version=${LS_VER}"
|
||||||
|
echo "build_date=$(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||||
|
} >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Log in to Gitea registry
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ${{ env.REGISTRY }}
|
||||||
|
username: ${{ secrets.REGISTRY_USER }}
|
||||||
|
password: ${{ secrets.REGISTRY_PASSWORD }}
|
||||||
|
|
||||||
|
- name: Build and push
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: docker/fork-cuda.Containerfile
|
||||||
|
push: true
|
||||||
|
provenance: false
|
||||||
|
build-args: |
|
||||||
|
BASE_TAG=${{ steps.meta.outputs.base_tag }}
|
||||||
|
LS_VERSION=${{ steps.meta.outputs.ls_version }}
|
||||||
|
GIT_HASH=${{ github.sha }}
|
||||||
|
BUILD_DATE=${{ steps.meta.outputs.build_date }}
|
||||||
|
tags: ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}
|
||||||
|
|
||||||
|
- name: Summary
|
||||||
|
run: |
|
||||||
|
echo "Pushed ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}" >> "$GITHUB_STEP_SUMMARY"
|
||||||
@@ -5,16 +5,14 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
|||||||
## Tech stack
|
## Tech stack
|
||||||
|
|
||||||
- golang
|
- golang
|
||||||
- typescript, vite and svelt5 for UI (located in ui/)
|
- typescript, vite and svelte 5 for UI (located in ui-svelte/)
|
||||||
|
|
||||||
## Workflow Tasks
|
## Workflow Tasks
|
||||||
|
|
||||||
- when summarizing changes only include details that require further action
|
- when summarizing changes only include details that require further action
|
||||||
- just say "Done." when there is no further action
|
|
||||||
- use the github CLI `gh` to create pull requests and work with github
|
|
||||||
- Rules for creating pull requests:
|
- Rules for creating pull requests:
|
||||||
- keep them short and focused on changes.
|
- keep them short and focused on changes
|
||||||
- never include a test plan
|
- skip the test plan
|
||||||
- write the summary using the same style rules as commit message
|
- write the summary using the same style rules as commit message
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
@@ -30,7 +28,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
|||||||
### Commit message example format:
|
### Commit message example format:
|
||||||
|
|
||||||
```
|
```
|
||||||
proxy: add new feature
|
internal/server: add new feature
|
||||||
|
|
||||||
Add new feature that implements functionality X and Y.
|
Add new feature that implements functionality X and Y.
|
||||||
|
|
||||||
|
|||||||
+3
-2
@@ -601,10 +601,11 @@
|
|||||||
"use": {
|
"use": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": [
|
"enum": [
|
||||||
|
"serial",
|
||||||
"fifo"
|
"fifo"
|
||||||
],
|
],
|
||||||
"default": "fifo",
|
"default": "serial",
|
||||||
"description": "Scheduler to use. Only 'fifo' is currently supported."
|
"description": "Scheduler to use. 'serial' (default on this fork): strict one-model-at-a-time, requests run in exact arrival order, switching models evicts every other model first. 'fifo': throughput-oriented, batches same-model requests and allows parallel/co-resident models."
|
||||||
},
|
},
|
||||||
"settings": {
|
"settings": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|||||||
+13
-3
@@ -556,11 +556,21 @@ routing:
|
|||||||
# expands to: [L]
|
# expands to: [L]
|
||||||
full: "L"
|
full: "L"
|
||||||
|
|
||||||
# scheduler: how queued requests are ordered.
|
# scheduler: how queued requests are ordered and run.
|
||||||
# The default and only valid scheduler is "fifo"
|
# - optional, default on this fork: "serial"
|
||||||
|
# - valid values:
|
||||||
|
# - "serial": strict one-model-at-a-time. Requests run in exact arrival
|
||||||
|
# order; only one request runs at a time; switching to a different model
|
||||||
|
# evicts every other running model first so a single model occupies memory
|
||||||
|
# at a time. This ignores group/matrix co-residency entirely. The "fifo"
|
||||||
|
# settings below (priority) do not apply.
|
||||||
|
# - "fifo": throughput-oriented. Same-model requests are batched to reduce
|
||||||
|
# swaps and a model serves up to its concurrencyLimit in parallel; models
|
||||||
|
# in non-exclusive groups can run concurrently. Requests may be reordered.
|
||||||
scheduler:
|
scheduler:
|
||||||
use: fifo
|
use: serial
|
||||||
settings:
|
settings:
|
||||||
|
# fifo settings only apply when use: fifo
|
||||||
fifo:
|
fifo:
|
||||||
# priority: a dictionary of model ID -> priority
|
# priority: a dictionary of model ID -> priority
|
||||||
# - optional, default: empty dictionary
|
# - optional, default: empty dictionary
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
# Build a CUDA llama-swap image FROM THIS FORK's source (includes the serial
|
||||||
|
# scheduler) and layer it on a pinned llama.cpp CUDA server base. Produces e.g.:
|
||||||
|
# gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
|
||||||
|
#
|
||||||
|
# BASE_TAG selects the llama.cpp CUDA runtime + llama-server build, e.g.
|
||||||
|
# "server-cuda-b9821". The llama-swap binary (with the embedded Svelte UI) is
|
||||||
|
# compiled from the repo at build time, so no GitHub release is required.
|
||||||
|
#
|
||||||
|
# Build context is the repo root:
|
||||||
|
# docker build -f docker/fork-cuda.Containerfile \
|
||||||
|
# --build-arg BASE_TAG=server-cuda-b9821 -t llama-swap:v230-cuda-b9821 .
|
||||||
|
|
||||||
|
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||||
|
ARG BASE_TAG=server-cuda-b9821
|
||||||
|
|
||||||
|
# ---- Stage 1: build the Svelte UI (embedded into the binary) ----
|
||||||
|
FROM node:22-bookworm-slim AS ui
|
||||||
|
WORKDIR /src/ui-svelte
|
||||||
|
# Install deps first for layer caching. .npmrc carries legacy-peer-deps=true,
|
||||||
|
# which the project relies on (tailwind/vite peer ranges), so copy it before
|
||||||
|
# npm ci or the strict resolver fails with ERESOLVE.
|
||||||
|
COPY ui-svelte/package.json ui-svelte/package-lock.json ui-svelte/.npmrc ./
|
||||||
|
RUN npm ci
|
||||||
|
COPY ui-svelte/ ./
|
||||||
|
# `npm run build` is `vite build --emptyOutDir`; vite.config.ts writes to
|
||||||
|
# ../internal/server/ui_dist, which //go:embed picks up in the next stage.
|
||||||
|
RUN mkdir -p /src/internal/server && npm run build
|
||||||
|
|
||||||
|
# ---- Stage 2: build the llama-swap binary with the embedded UI ----
|
||||||
|
FROM golang:1.26-bookworm AS build
|
||||||
|
WORKDIR /src
|
||||||
|
# Cache modules independently of source churn.
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
COPY . .
|
||||||
|
# Overlay the freshly built UI so //go:embed ui_dist ships the real assets
|
||||||
|
# instead of the committed placeholder.
|
||||||
|
COPY --from=ui /src/internal/server/ui_dist/ ./internal/server/ui_dist/
|
||||||
|
ARG LS_VERSION=v230
|
||||||
|
ARG GIT_HASH=unknown
|
||||||
|
ARG BUILD_DATE=unknown
|
||||||
|
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||||
|
-ldflags="-X main.version=${LS_VERSION} -X main.commit=${GIT_HASH} -X main.date=${BUILD_DATE}" \
|
||||||
|
-o /out/llama-swap .
|
||||||
|
|
||||||
|
# ---- Stage 3: runtime image on the pinned llama.cpp CUDA base ----
|
||||||
|
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||||
|
|
||||||
|
# Run as root by default to match the upstream `vNNN-cuda-bNNNN` (non-suffixed)
|
||||||
|
# image that ragnaros pulls today: it needs root to reach the mounted docker
|
||||||
|
# socket for container-backed models (sd-server). Override UID/GID at build time
|
||||||
|
# for a non-root variant.
|
||||||
|
ARG UID=0
|
||||||
|
ARG GID=0
|
||||||
|
ARG USER_HOME=/root
|
||||||
|
ENV HOME=$USER_HOME
|
||||||
|
|
||||||
|
RUN set -eux; \
|
||||||
|
if [ "$UID" -ne 0 ]; then \
|
||||||
|
if [ "$GID" -ne 0 ]; then groupadd --system --gid "$GID" app; fi; \
|
||||||
|
useradd --system --uid "$UID" --gid "$GID" --home "$USER_HOME" app; \
|
||||||
|
fi; \
|
||||||
|
mkdir --parents "$HOME" /app; \
|
||||||
|
chown --recursive "$UID:$GID" "$HOME" /app
|
||||||
|
|
||||||
|
COPY --from=build --chown=$UID:$GID /out/llama-swap /app/llama-swap
|
||||||
|
COPY --chown=$UID:$GID docker/config.example.yaml /app/config.yaml
|
||||||
|
|
||||||
|
USER $UID:$GID
|
||||||
|
WORKDIR /app
|
||||||
|
ENV PATH="/app:${PATH}"
|
||||||
|
|
||||||
|
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||||
|
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/billziss-gh/golib/shlex"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Handle trailing backslashes by replacing with space
|
||||||
|
if strings.HasSuffix(trimmed, "\\") {
|
||||||
|
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||||
|
} else {
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// put it back together
|
||||||
|
cmdStr = strings.Join(cleanedLines, "\n")
|
||||||
|
|
||||||
|
// Split the command into arguments
|
||||||
|
var args []string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
args = shlex.Windows.Split(cmdStr)
|
||||||
|
} else {
|
||||||
|
args = shlex.Posix.Split(cmdStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the command is not empty
|
||||||
|
if len(args) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty command")
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripComments(cmdStr string) string {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
return strings.Join(cleanedLines, "\n")
|
||||||
|
}
|
||||||
@@ -2,16 +2,9 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
|
||||||
"runtime"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/billziss-gh/golib/shlex"
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -85,12 +78,6 @@ type GroupConfig struct {
|
|||||||
Members []string `yaml:"members"`
|
Members []string `yaml:"members"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
|
||||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
|
||||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
|
||||||
)
|
|
||||||
|
|
||||||
// set default values for GroupConfig
|
// set default values for GroupConfig
|
||||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
type rawGroupConfig GroupConfig
|
type rawGroupConfig GroupConfig
|
||||||
@@ -224,430 +211,6 @@ func LoadConfig(path string) (Config, error) {
|
|||||||
return LoadConfigFromReader(file)
|
return LoadConfigFromReader(file)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
yamlStr := string(data)
|
|
||||||
|
|
||||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
|
||||||
// This is safe because env values are simple strings without YAML formatting
|
|
||||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal into full Config with defaults
|
|
||||||
config := Config{
|
|
||||||
HealthCheckTimeout: 120,
|
|
||||||
StartPort: 5800,
|
|
||||||
LogLevel: "info",
|
|
||||||
LogTimeFormat: "",
|
|
||||||
LogToStdout: LogToStdoutProxy,
|
|
||||||
MetricsMaxInMemory: 1000,
|
|
||||||
CaptureBuffer: 5,
|
|
||||||
GlobalTTL: 0,
|
|
||||||
}
|
|
||||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.HealthCheckTimeout < 15 {
|
|
||||||
config.HealthCheckTimeout = 15
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply defaults for performance config when section is missing
|
|
||||||
if config.Performance.Every == 0 {
|
|
||||||
config.Performance.Every = 5 * time.Second
|
|
||||||
}
|
|
||||||
if err = config.Performance.Validate(); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("performance: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.StartPort < 1 {
|
|
||||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.GlobalTTL < 0 {
|
|
||||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply default for upstream.ignorePaths when not specified. The default
|
|
||||||
// matches common static-asset suffixes so they do not trigger a swap.
|
|
||||||
if len(config.Upstream.IgnorePaths) == 0 {
|
|
||||||
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
|
||||||
}
|
|
||||||
|
|
||||||
switch config.LogToStdout {
|
|
||||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
|
||||||
default:
|
|
||||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate the aliases map
|
|
||||||
config.aliases = make(map[string]string)
|
|
||||||
for modelName, modelConfig := range config.Models {
|
|
||||||
for _, alias := range modelConfig.Aliases {
|
|
||||||
if _, found := config.aliases[alias]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
|
||||||
}
|
|
||||||
config.aliases[alias] = modelName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate global macros
|
|
||||||
for _, macro := range config.Macros {
|
|
||||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get and sort all model IDs for consistent port assignment
|
|
||||||
modelIds := make([]string, 0, len(config.Models))
|
|
||||||
for modelId := range config.Models {
|
|
||||||
modelIds = append(modelIds, modelId)
|
|
||||||
}
|
|
||||||
sort.Strings(modelIds)
|
|
||||||
|
|
||||||
nextPort := config.StartPort
|
|
||||||
for _, modelId := range modelIds {
|
|
||||||
modelConfig := config.Models[modelId]
|
|
||||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
|
||||||
|
|
||||||
// Strip comments from command fields
|
|
||||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
|
||||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
|
||||||
|
|
||||||
// set model TTL to globalTTL it is the default value
|
|
||||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
|
||||||
modelConfig.UnloadAfter = config.GlobalTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelConfig.UnloadAfter < 0 {
|
|
||||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate model macros
|
|
||||||
for _, macro := range modelConfig.Macros {
|
|
||||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
|
||||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
|
||||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
|
||||||
mergedMacros = append(mergedMacros, config.Macros...)
|
|
||||||
|
|
||||||
// Add model macros (override globals with same name)
|
|
||||||
for _, entry := range modelConfig.Macros {
|
|
||||||
found := false
|
|
||||||
for i, existing := range mergedMacros {
|
|
||||||
if existing.Name == entry.Name {
|
|
||||||
mergedMacros[i] = entry
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
mergedMacros = append(mergedMacros, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Substitute remaining macros in model fields (LIFO order)
|
|
||||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
|
||||||
entry := mergedMacros[i]
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
|
||||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
|
||||||
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
|
||||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
|
||||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
|
||||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
|
||||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
|
||||||
|
|
||||||
// Substitute macros in SetParamsByID keys and values
|
|
||||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
|
||||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
|
||||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
|
||||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
|
||||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
newParamMap, ok := newValAny.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
|
||||||
}
|
|
||||||
newSetParamsByID[newKey] = newParamMap
|
|
||||||
}
|
|
||||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Substitute in metadata (type-preserving)
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
modelConfig.Metadata = result.(map[string]any)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle PORT macro - only allocate if cmd uses it
|
|
||||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
|
||||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
|
||||||
if cmdHasPort || proxyHasPort {
|
|
||||||
if !cmdHasPort && proxyHasPort {
|
|
||||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
|
||||||
}
|
|
||||||
|
|
||||||
macroSlug := "${PORT}"
|
|
||||||
macroStr := fmt.Sprintf("%v", nextPort)
|
|
||||||
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
|
||||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
|
||||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
|
||||||
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
modelConfig.Metadata = result.(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
nextPort++
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate no unknown macros remain
|
|
||||||
fieldMap := map[string]string{
|
|
||||||
"cmd": modelConfig.Cmd,
|
|
||||||
"cmdStop": modelConfig.CmdStop,
|
|
||||||
"proxy": modelConfig.Proxy,
|
|
||||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
|
||||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
|
||||||
"name": modelConfig.Name,
|
|
||||||
"description": modelConfig.Description,
|
|
||||||
}
|
|
||||||
|
|
||||||
for fieldName, fieldValue := range fieldMap {
|
|
||||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
macroName := match[1]
|
|
||||||
if macroName == "PID" && fieldName == "cmdStop" {
|
|
||||||
continue // replaced at runtime
|
|
||||||
}
|
|
||||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
|
||||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
|
||||||
}
|
|
||||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate SetParamsByID keys and values
|
|
||||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
|
||||||
}
|
|
||||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
|
||||||
for key := range modelConfig.Filters.SetParamsByID {
|
|
||||||
if key == modelId {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, exists := config.Models[key]; exists {
|
|
||||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
|
||||||
}
|
|
||||||
if existingModel, exists := config.aliases[key]; exists {
|
|
||||||
if existingModel != modelId {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
|
||||||
}
|
|
||||||
continue // already registered as explicit alias for this model
|
|
||||||
}
|
|
||||||
config.aliases[key] = modelId
|
|
||||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelConfig.SendLoadingState == nil {
|
|
||||||
v := config.SendLoadingState
|
|
||||||
modelConfig.SendLoadingState = &v
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Models[modelId] = modelConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
|
||||||
// the new `routing.router` block are mutually exclusive: a config may use
|
|
||||||
// either style, never both.
|
|
||||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
|
||||||
rtr := config.Routing.Router
|
|
||||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
|
||||||
|
|
||||||
if hasTopLevel && hasRouting {
|
|
||||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasTopLevel {
|
|
||||||
// Both groups and matrix may be defined under routing.router.settings;
|
|
||||||
// routing.router.use selects which one is active, so there is no conflict.
|
|
||||||
rs := config.Routing.Router.Settings
|
|
||||||
switch config.Routing.Router.Use {
|
|
||||||
case "matrix":
|
|
||||||
if rs.Matrix == nil {
|
|
||||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
|
||||||
}
|
|
||||||
config.Matrix = rs.Matrix
|
|
||||||
case "group", "":
|
|
||||||
config.Groups = rs.Groups
|
|
||||||
default:
|
|
||||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// groups XOR matrix
|
|
||||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Matrix != nil {
|
|
||||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
|
||||||
}
|
|
||||||
config.Matrix.ExpandedSets = expandedSets
|
|
||||||
} else {
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
|
||||||
|
|
||||||
// Validate group members
|
|
||||||
memberUsage := make(map[string]string)
|
|
||||||
for groupID, groupConfig := range config.Groups {
|
|
||||||
prevSet := make(map[string]bool)
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
if _, found := prevSet[member]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
|
||||||
}
|
|
||||||
prevSet[member] = true
|
|
||||||
|
|
||||||
if existingGroup, exists := memberUsage[member]; exists {
|
|
||||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
|
||||||
}
|
|
||||||
memberUsage[member] = groupID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
|
||||||
// and new-style configs converge here. The Matrix pointer is shared so
|
|
||||||
// ExpandedSets stays in one place.
|
|
||||||
if config.Matrix != nil {
|
|
||||||
config.Routing.Router.Use = "matrix"
|
|
||||||
} else {
|
|
||||||
config.Routing.Router.Use = "group"
|
|
||||||
}
|
|
||||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
|
||||||
config.Routing.Router.Settings.Groups = config.Groups
|
|
||||||
|
|
||||||
if config.Routing.Scheduler.Use == "" {
|
|
||||||
config.Routing.Scheduler.Use = "fifo"
|
|
||||||
}
|
|
||||||
if config.Routing.Scheduler.Use != "fifo" {
|
|
||||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
|
|
||||||
}
|
|
||||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
|
||||||
if _, found := config.RealModelName(modelID); !found {
|
|
||||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up hooks preload
|
|
||||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
|
||||||
var toPreload []string
|
|
||||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
|
||||||
modelID = strings.TrimSpace(modelID)
|
|
||||||
if modelID == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if real, found := config.RealModelName(modelID); found {
|
|
||||||
toPreload = append(toPreload, real)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config.Hooks.OnStartup.Preload = toPreload
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate API keys (env macros already substituted at string level)
|
|
||||||
for i, apikey := range config.RequiredAPIKeys {
|
|
||||||
if apikey == "" {
|
|
||||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
|
||||||
}
|
|
||||||
if strings.Contains(apikey, " ") {
|
|
||||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
|
||||||
}
|
|
||||||
config.RequiredAPIKeys[i] = apikey
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process peers with global macro substitution
|
|
||||||
for peerName, peerConfig := range config.Peers {
|
|
||||||
// Substitute global macros (LIFO order)
|
|
||||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
|
||||||
entry := config.Macros[i]
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
|
||||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
|
||||||
|
|
||||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
|
||||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
|
||||||
|
|
||||||
// Substitute in setParams (type-preserving)
|
|
||||||
if len(peerConfig.Filters.SetParams) > 0 {
|
|
||||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
|
||||||
}
|
|
||||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate no unknown macros remain
|
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
|
||||||
}
|
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
|
||||||
}
|
|
||||||
if len(peerConfig.Filters.SetParams) > 0 {
|
|
||||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config.Peers[peerName] = peerConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewrites the yaml to include a default group with any orphaned models
|
// rewrites the yaml to include a default group with any orphaned models
|
||||||
func AddDefaultGroupToConfig(config Config) Config {
|
func AddDefaultGroupToConfig(config Config) Config {
|
||||||
|
|
||||||
@@ -692,233 +255,3 @@ func AddDefaultGroupToConfig(config Config) Config {
|
|||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
|
||||||
var cleanedLines []string
|
|
||||||
for _, line := range strings.Split(cmdStr, "\n") {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
// Skip comment lines
|
|
||||||
if strings.HasPrefix(trimmed, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Handle trailing backslashes by replacing with space
|
|
||||||
if strings.HasSuffix(trimmed, "\\") {
|
|
||||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
|
||||||
} else {
|
|
||||||
cleanedLines = append(cleanedLines, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// put it back together
|
|
||||||
cmdStr = strings.Join(cleanedLines, "\n")
|
|
||||||
|
|
||||||
// Split the command into arguments
|
|
||||||
var args []string
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
args = shlex.Windows.Split(cmdStr)
|
|
||||||
} else {
|
|
||||||
args = shlex.Posix.Split(cmdStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the command is not empty
|
|
||||||
if len(args) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty command")
|
|
||||||
}
|
|
||||||
|
|
||||||
return args, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func StripComments(cmdStr string) string {
|
|
||||||
var cleanedLines []string
|
|
||||||
for _, line := range strings.Split(cmdStr, "\n") {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
// Skip comment lines
|
|
||||||
if strings.HasPrefix(trimmed, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cleanedLines = append(cleanedLines, line)
|
|
||||||
}
|
|
||||||
return strings.Join(cleanedLines, "\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateMacro validates macro name and value constraints
|
|
||||||
func validateMacro(name string, value any) error {
|
|
||||||
if len(name) >= 64 {
|
|
||||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
|
||||||
}
|
|
||||||
if !macroNameRegex.MatchString(name) {
|
|
||||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate that value is a scalar type
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
// Check for self-reference
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", name)
|
|
||||||
if strings.Contains(v, macroSlug) {
|
|
||||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
|
||||||
}
|
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
|
||||||
// These types are allowed
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch name {
|
|
||||||
case "PORT", "MODEL_ID":
|
|
||||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
|
||||||
func validateNestedForUnknownMacros(value any, context string) error {
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
macroName := match[1]
|
|
||||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
|
||||||
}
|
|
||||||
// Check for unsubstituted env macros
|
|
||||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
|
||||||
for _, match := range envMatches {
|
|
||||||
varName := match[1]
|
|
||||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case map[string]any:
|
|
||||||
for _, val := range v {
|
|
||||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case []any:
|
|
||||||
for _, val := range v {
|
|
||||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Scalar types don't contain macros
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
|
||||||
// This is called once per macro, allowing LIFO substitution order
|
|
||||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
|
||||||
macroStr := fmt.Sprintf("%v", macroValue)
|
|
||||||
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
// Check if this is a direct macro substitution
|
|
||||||
if v == macroSlug {
|
|
||||||
return macroValue, nil
|
|
||||||
}
|
|
||||||
// Handle string interpolation
|
|
||||||
if strings.Contains(v, macroSlug) {
|
|
||||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
|
||||||
}
|
|
||||||
return v, nil
|
|
||||||
|
|
||||||
case map[string]any:
|
|
||||||
// Recursively process map values
|
|
||||||
newMap := make(map[string]any)
|
|
||||||
for key, val := range v {
|
|
||||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newMap[key] = newVal
|
|
||||||
}
|
|
||||||
return newMap, nil
|
|
||||||
|
|
||||||
case []any:
|
|
||||||
// Recursively process slice elements
|
|
||||||
newSlice := make([]any, len(v))
|
|
||||||
for i, val := range v {
|
|
||||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newSlice[i] = newVal
|
|
||||||
}
|
|
||||||
return newSlice, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Return scalar types as-is
|
|
||||||
return value, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
|
||||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
|
||||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
|
||||||
// (which strips comments) and only checking the comment-free version for macros.
|
|
||||||
func substituteEnvMacros(s string) (string, error) {
|
|
||||||
// Unmarshal and remarshal to strip YAML comments
|
|
||||||
var raw any
|
|
||||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
|
||||||
// If YAML is invalid, fall back to scanning the original string
|
|
||||||
// so the user gets the env var error rather than a confusing YAML parse error
|
|
||||||
return substituteEnvMacrosInString(s, s)
|
|
||||||
}
|
|
||||||
clean, err := yaml.Marshal(raw)
|
|
||||||
if err != nil {
|
|
||||||
return substituteEnvMacrosInString(s, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
return substituteEnvMacrosInString(s, string(clean))
|
|
||||||
}
|
|
||||||
|
|
||||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
|
||||||
// them in target. This separation allows scanning comment-free YAML while
|
|
||||||
// substituting in the original string.
|
|
||||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
|
||||||
result := target
|
|
||||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
fullMatch := match[0] // ${env.VAR_NAME}
|
|
||||||
varName := match[1] // VAR_NAME
|
|
||||||
|
|
||||||
value, exists := os.LookupEnv(varName)
|
|
||||||
if !exists {
|
|
||||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sanitize the value for safe YAML substitution
|
|
||||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
result = strings.ReplaceAll(result, fullMatch, value)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
|
||||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
|
||||||
// for compatibility with double-quoted YAML strings.
|
|
||||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
|
||||||
// Reject values that would break YAML structure regardless of quoting context
|
|
||||||
if strings.ContainsAny(value, "\n\r\x00") {
|
|
||||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
|
||||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
|
||||||
// In double-quoted contexts, they are interpreted correctly.
|
|
||||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
|
||||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
|
||||||
|
|
||||||
return value, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -277,7 +277,7 @@ groups:
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Scheduler: SchedulerConfig{
|
Scheduler: SchedulerConfig{
|
||||||
Use: "fifo",
|
Use: "serial",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -777,22 +777,27 @@ func TestConfig_APIKeys_Invalid(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "blank spaces only",
|
name: "blank spaces only",
|
||||||
content: `apiKeys: [" "]`,
|
content: `apiKeys: [" "]`,
|
||||||
expectedErr: "api key cannot contain spaces: ` `",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "contains leading space",
|
name: "contains leading space",
|
||||||
content: `apiKeys: [" key123"]`,
|
content: `apiKeys: [" key123"]`,
|
||||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "contains trailing space",
|
name: "contains trailing space",
|
||||||
content: `apiKeys: ["key123 "]`,
|
content: `apiKeys: ["key123 "]`,
|
||||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "contains middle space",
|
name: "contains middle space",
|
||||||
content: `apiKeys: ["key 123"]`,
|
content: `apiKeys: ["key 123"]`,
|
||||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "space in second key reports correct index",
|
||||||
|
content: `apiKeys: ["valid-key", "bad key"]`,
|
||||||
|
expectedErr: "apiKeys[1]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty in list with valid keys",
|
name: "empty in list with valid keys",
|
||||||
@@ -1567,7 +1572,7 @@ groups:
|
|||||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||||
// default group injected for orphaned models (none here) still leaves g1
|
// default group injected for orphaned models (none here) still leaves g1
|
||||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
||||||
@@ -1626,7 +1631,7 @@ func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
|
|||||||
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
||||||
|
|||||||
@@ -266,7 +266,7 @@ groups:
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Scheduler: SchedulerConfig{
|
Scheduler: SchedulerConfig{
|
||||||
Use: "fifo",
|
Use: "serial",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,441 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
yamlStr := string(data)
|
||||||
|
|
||||||
|
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||||
|
// This is safe because env values are simple strings without YAML formatting
|
||||||
|
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal into full Config with defaults
|
||||||
|
config := Config{
|
||||||
|
HealthCheckTimeout: 120,
|
||||||
|
StartPort: 5800,
|
||||||
|
LogLevel: "info",
|
||||||
|
LogTimeFormat: "",
|
||||||
|
LogToStdout: LogToStdoutProxy,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
CaptureBuffer: 5,
|
||||||
|
GlobalTTL: 0,
|
||||||
|
}
|
||||||
|
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.HealthCheckTimeout < 15 {
|
||||||
|
config.HealthCheckTimeout = 15
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults for performance config when section is missing
|
||||||
|
if config.Performance.Every == 0 {
|
||||||
|
config.Performance.Every = 5 * time.Second
|
||||||
|
}
|
||||||
|
if err = config.Performance.Validate(); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("performance: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.StartPort < 1 {
|
||||||
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.GlobalTTL < 0 {
|
||||||
|
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply default for upstream.ignorePaths when not specified. The default
|
||||||
|
// matches common static-asset suffixes so they do not trigger a swap.
|
||||||
|
if len(config.Upstream.IgnorePaths) == 0 {
|
||||||
|
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch config.LogToStdout {
|
||||||
|
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate the aliases map
|
||||||
|
config.aliases = make(map[string]string)
|
||||||
|
for modelName, modelConfig := range config.Models {
|
||||||
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
if _, found := config.aliases[alias]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||||
|
}
|
||||||
|
config.aliases[alias] = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate global macros
|
||||||
|
for _, macro := range config.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get and sort all model IDs for consistent port assignment
|
||||||
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
|
for modelId := range config.Models {
|
||||||
|
modelIds = append(modelIds, modelId)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIds)
|
||||||
|
|
||||||
|
nextPort := config.StartPort
|
||||||
|
for _, modelId := range modelIds {
|
||||||
|
modelConfig := config.Models[modelId]
|
||||||
|
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||||
|
|
||||||
|
// Strip comments from command fields
|
||||||
|
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||||
|
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||||
|
|
||||||
|
// set model TTL to globalTTL it is the default value
|
||||||
|
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||||
|
modelConfig.UnloadAfter = config.GlobalTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelConfig.UnloadAfter < 0 {
|
||||||
|
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate model macros
|
||||||
|
for _, macro := range modelConfig.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
mergedMacros = append(mergedMacros, config.Macros...)
|
||||||
|
|
||||||
|
// Add model macros (override globals with same name)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Substitute remaining macros in model fields (LIFO order)
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||||
|
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute macros in SetParamsByID keys and values
|
||||||
|
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||||
|
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||||
|
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||||
|
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||||
|
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
newParamMap, ok := newValAny.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||||
|
}
|
||||||
|
newSetParamsByID[newKey] = newParamMap
|
||||||
|
}
|
||||||
|
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Substitute in metadata (type-preserving)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle PORT macro - only allocate if cmd uses it
|
||||||
|
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||||
|
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||||
|
if cmdHasPort || proxyHasPort {
|
||||||
|
if !cmdHasPort && proxyHasPort {
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
macroSlug := "${PORT}"
|
||||||
|
macroStr := fmt.Sprintf("%v", nextPort)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||||
|
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||||
|
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate no unknown macros remain
|
||||||
|
fieldMap := map[string]string{
|
||||||
|
"cmd": modelConfig.Cmd,
|
||||||
|
"cmdStop": modelConfig.CmdStop,
|
||||||
|
"proxy": modelConfig.Proxy,
|
||||||
|
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||||
|
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||||
|
"name": modelConfig.Name,
|
||||||
|
"description": modelConfig.Description,
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, fieldValue := range fieldMap {
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
if macroName == "PID" && fieldName == "cmdStop" {
|
||||||
|
continue // replaced at runtime
|
||||||
|
}
|
||||||
|
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||||
|
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate SetParamsByID keys and values
|
||||||
|
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||||
|
}
|
||||||
|
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||||
|
for key := range modelConfig.Filters.SetParamsByID {
|
||||||
|
if key == modelId {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := config.Models[key]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||||
|
}
|
||||||
|
if existingModel, exists := config.aliases[key]; exists {
|
||||||
|
if existingModel != modelId {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||||
|
}
|
||||||
|
continue // already registered as explicit alias for this model
|
||||||
|
}
|
||||||
|
config.aliases[key] = modelId
|
||||||
|
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelConfig.SendLoadingState == nil {
|
||||||
|
v := config.SendLoadingState
|
||||||
|
modelConfig.SendLoadingState = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Models[modelId] = modelConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||||
|
// the new `routing.router` block are mutually exclusive: a config may use
|
||||||
|
// either style, never both.
|
||||||
|
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||||
|
rtr := config.Routing.Router
|
||||||
|
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||||
|
|
||||||
|
if hasTopLevel && hasRouting {
|
||||||
|
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasTopLevel {
|
||||||
|
// Both groups and matrix may be defined under routing.router.settings;
|
||||||
|
// routing.router.use selects which one is active, so there is no conflict.
|
||||||
|
rs := config.Routing.Router.Settings
|
||||||
|
switch config.Routing.Router.Use {
|
||||||
|
case "matrix":
|
||||||
|
if rs.Matrix == nil {
|
||||||
|
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||||
|
}
|
||||||
|
config.Matrix = rs.Matrix
|
||||||
|
case "group", "":
|
||||||
|
config.Groups = rs.Groups
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// groups XOR matrix
|
||||||
|
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Matrix != nil {
|
||||||
|
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||||
|
}
|
||||||
|
config.Matrix.ExpandedSets = expandedSets
|
||||||
|
} else {
|
||||||
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
|
||||||
|
// Validate group members
|
||||||
|
memberUsage := make(map[string]string)
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||||
|
// and new-style configs converge here. The Matrix pointer is shared so
|
||||||
|
// ExpandedSets stays in one place.
|
||||||
|
if config.Matrix != nil {
|
||||||
|
config.Routing.Router.Use = "matrix"
|
||||||
|
} else {
|
||||||
|
config.Routing.Router.Use = "group"
|
||||||
|
}
|
||||||
|
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||||
|
config.Routing.Router.Settings.Groups = config.Groups
|
||||||
|
|
||||||
|
// This fork defaults to the "serial" scheduler: one model loaded at a time,
|
||||||
|
// requests served in strict arrival order. Set use: fifo for the upstream
|
||||||
|
// throughput-oriented behavior that batches same-model requests.
|
||||||
|
if config.Routing.Scheduler.Use == "" {
|
||||||
|
config.Routing.Scheduler.Use = "serial"
|
||||||
|
}
|
||||||
|
switch config.Routing.Scheduler.Use {
|
||||||
|
case "fifo", "serial":
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo, serial)", config.Routing.Scheduler.Use)
|
||||||
|
}
|
||||||
|
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||||
|
if _, found := config.RealModelName(modelID); !found {
|
||||||
|
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up hooks preload
|
||||||
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
var toPreload []string
|
||||||
|
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if real, found := config.RealModelName(modelID); found {
|
||||||
|
toPreload = append(toPreload, real)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Hooks.OnStartup.Preload = toPreload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate API keys (env macros already substituted at string level)
|
||||||
|
for i, apikey := range config.RequiredAPIKeys {
|
||||||
|
if apikey == "" {
|
||||||
|
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||||
|
}
|
||||||
|
if strings.Contains(apikey, " ") {
|
||||||
|
return Config{}, fmt.Errorf("apiKeys[%d]: api key cannot contain spaces", i)
|
||||||
|
}
|
||||||
|
config.RequiredAPIKeys[i] = apikey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process peers with global macro substitution
|
||||||
|
for peerName, peerConfig := range config.Peers {
|
||||||
|
// Substitute global macros (LIFO order)
|
||||||
|
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||||
|
entry := config.Macros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||||
|
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in setParams (type-preserving)
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||||
|
}
|
||||||
|
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate no unknown macros remain
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Peers[peerName] = peerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||||
|
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||||
|
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// validateMacro validates macro name and value constraints
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
if len(name) >= 64 {
|
||||||
|
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||||
|
}
|
||||||
|
if !macroNameRegex.MatchString(name) {
|
||||||
|
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that value is a scalar type
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check for self-reference
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||||
|
// These types are allowed
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "PORT", "MODEL_ID":
|
||||||
|
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||||
|
func validateNestedForUnknownMacros(value any, context string) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||||
|
}
|
||||||
|
// Check for unsubstituted env macros
|
||||||
|
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range envMatches {
|
||||||
|
varName := match[1]
|
||||||
|
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Scalar types don't contain macros
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check if this is a direct macro substitution
|
||||||
|
if v == macroSlug {
|
||||||
|
return macroValue, nil
|
||||||
|
}
|
||||||
|
// Handle string interpolation
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
// Recursively process map values
|
||||||
|
newMap := make(map[string]any)
|
||||||
|
for key, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newMap[key] = newVal
|
||||||
|
}
|
||||||
|
return newMap, nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// Recursively process slice elements
|
||||||
|
newSlice := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSlice[i] = newVal
|
||||||
|
}
|
||||||
|
return newSlice, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Return scalar types as-is
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||||
|
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||||
|
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||||
|
// (which strips comments) and only checking the comment-free version for macros.
|
||||||
|
func substituteEnvMacros(s string) (string, error) {
|
||||||
|
// Unmarshal and remarshal to strip YAML comments
|
||||||
|
var raw any
|
||||||
|
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||||
|
// If YAML is invalid, fall back to scanning the original string
|
||||||
|
// so the user gets the env var error rather than a confusing YAML parse error
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
clean, err := yaml.Marshal(raw)
|
||||||
|
if err != nil {
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
return substituteEnvMacrosInString(s, string(clean))
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||||
|
// them in target. This separation allows scanning comment-free YAML while
|
||||||
|
// substituting in the original string.
|
||||||
|
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||||
|
result := target
|
||||||
|
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
fullMatch := match[0] // ${env.VAR_NAME}
|
||||||
|
varName := match[1] // VAR_NAME
|
||||||
|
|
||||||
|
value, exists := os.LookupEnv(varName)
|
||||||
|
if !exists {
|
||||||
|
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize the value for safe YAML substitution
|
||||||
|
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = strings.ReplaceAll(result, fullMatch, value)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||||
|
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||||
|
// for compatibility with double-quoted YAML strings.
|
||||||
|
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||||
|
// Reject values that would break YAML structure regardless of quoting context
|
||||||
|
if strings.ContainsAny(value, "\n\r\x00") {
|
||||||
|
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||||
|
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||||
|
// In double-quoted contexts, they are interpreted correctly.
|
||||||
|
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||||
|
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,300 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// identityMapPaths is the set of dotted paths whose direct children are
|
||||||
|
// identity-keyed maps. A child key present in two sources is a hard error;
|
||||||
|
// such keys name discrete entities (a model, a group, a peer, etc.) and a
|
||||||
|
// duplicate means the user has split one entity across files by mistake.
|
||||||
|
var identityMapPaths = map[string]bool{
|
||||||
|
"models": true,
|
||||||
|
"groups": true,
|
||||||
|
"profiles": true,
|
||||||
|
"peers": true,
|
||||||
|
"matrix": true,
|
||||||
|
"routing.router.settings.groups": true,
|
||||||
|
"routing.router.settings.matrix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfigSources loads and merges configuration from -config (optional)
|
||||||
|
// and -config-dir (optional). At least one must be provided. The -config file
|
||||||
|
// is loaded first; *.yml/*.yaml files directly under -config-dir are then
|
||||||
|
// merged in sorted filename order. The merged document is passed through the
|
||||||
|
// existing LoadConfigFromReader pipeline unchanged.
|
||||||
|
func LoadConfigSources(configPath, configDir string) (Config, error) {
|
||||||
|
if configPath == "" && configDir == "" {
|
||||||
|
return Config{}, fmt.Errorf("at least one of -config or -config-dir must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
var sourcePaths []string
|
||||||
|
|
||||||
|
if configPath != "" {
|
||||||
|
sourcePaths = append(sourcePaths, configPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if configDir != "" {
|
||||||
|
dirFiles, err := listYAMLFiles(configDir)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("-config-dir %s: %w", configDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if configPath != "" {
|
||||||
|
absConfig, err := filepath.Abs(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to resolve -config path: %w", err)
|
||||||
|
}
|
||||||
|
for _, f := range dirFiles {
|
||||||
|
absF, err := filepath.Abs(f)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to resolve config dir file %s: %w", f, err)
|
||||||
|
}
|
||||||
|
if absConfig == absF {
|
||||||
|
return Config{}, fmt.Errorf("-config path %s is also present in -config-dir %s; remove it from one", configPath, configDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sourcePaths = append(sourcePaths, dirFiles...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sourcePaths) == 0 {
|
||||||
|
return Config{}, fmt.Errorf("no configuration sources found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var merged *yaml.Node
|
||||||
|
for _, p := range sourcePaths {
|
||||||
|
node, err := parseSource(p)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
if node == nil {
|
||||||
|
continue // empty file
|
||||||
|
}
|
||||||
|
if merged == nil {
|
||||||
|
merged = node
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := mergeNodes(merged, node, "", p); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if merged == nil {
|
||||||
|
// All sources were empty; run the pipeline on empty input so defaults
|
||||||
|
// and validation still apply (e.g. startPort, performance defaults).
|
||||||
|
return LoadConfigFromReader(strings.NewReader(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := yaml.Marshal(merged)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to marshal merged config: %w", err)
|
||||||
|
}
|
||||||
|
return LoadConfigFromReader(strings.NewReader(string(out)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// listYAMLFiles returns the top-level *.yml and *.yaml files in dir, sorted by
|
||||||
|
// filename for deterministic merge order. Subdirectories are not traversed.
|
||||||
|
func listYAMLFiles(dir string) ([]string, error) {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var files []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := e.Name()
|
||||||
|
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
files = append(files, filepath.Join(dir, name))
|
||||||
|
}
|
||||||
|
sort.Strings(files)
|
||||||
|
return files, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSource reads and parses one YAML config file into a root mapping node.
|
||||||
|
// Returns a nil node (no error) when the file is empty or contains only
|
||||||
|
// comments.
|
||||||
|
//
|
||||||
|
// Env macros (${env.VAR}) are substituted at the string level before YAML
|
||||||
|
// parsing so that flow-style constructs like [${env.API_KEY}] parse
|
||||||
|
// correctly — the brace would otherwise be interpreted as a flow mapping.
|
||||||
|
func parseSource(path string) (*yaml.Node, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
yamlStr, err := substituteEnvMacros(string(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
var doc yaml.Node
|
||||||
|
if err := yaml.Unmarshal([]byte(yamlStr), &doc); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
// yaml.Unmarshal into a yaml.Node yields a DocumentNode whose Content[0]
|
||||||
|
// is the actual root. Unwrap it so callers see the real top-level node.
|
||||||
|
root := &doc
|
||||||
|
if root.Kind == yaml.DocumentNode && len(root.Content) > 0 {
|
||||||
|
root = root.Content[0]
|
||||||
|
}
|
||||||
|
if root.Kind == 0 || root.Content == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if root.Kind != yaml.MappingNode {
|
||||||
|
return nil, fmt.Errorf("config %s: top-level YAML must be a mapping", path)
|
||||||
|
}
|
||||||
|
return root, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeNodes merges src into dst (both MappingNodes) in place. Keys present in
|
||||||
|
// only one side are kept; shared keys are merged recursively under the rules
|
||||||
|
// in mergeValue. srcPath is included in error messages to identify the file
|
||||||
|
// that introduced the conflict.
|
||||||
|
func mergeNodes(dst, src *yaml.Node, path, srcPath string) error {
|
||||||
|
srcIdx := indexMapping(src)
|
||||||
|
|
||||||
|
// First pass: merge shared keys in place.
|
||||||
|
for i := 0; i+1 < len(dst.Content); i += 2 {
|
||||||
|
keyNode := dst.Content[i]
|
||||||
|
dstVal := dst.Content[i+1]
|
||||||
|
key := keyNode.Value
|
||||||
|
|
||||||
|
srcVal, ok := srcIdx[key]
|
||||||
|
if !ok {
|
||||||
|
continue // dst-only key, keep as-is
|
||||||
|
}
|
||||||
|
|
||||||
|
childPath := joinPath(path, key)
|
||||||
|
|
||||||
|
if identityMapPaths[childPath] {
|
||||||
|
// Identity-keyed map: each child key names a discrete entity
|
||||||
|
// (a model, group, peer, ...). A shared child key is a hard
|
||||||
|
// error; src-only children are appended in the second pass.
|
||||||
|
if err := mergeIdentityMap(dstVal, srcVal, childPath, key, srcPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mergeValue(dstVal, srcVal, childPath, srcPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: append src-only keys.
|
||||||
|
dstIdx := indexMapping(dst)
|
||||||
|
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||||
|
keyNode := src.Content[i]
|
||||||
|
srcVal := src.Content[i+1]
|
||||||
|
key := keyNode.Value
|
||||||
|
|
||||||
|
if _, ok := dstIdx[key]; ok {
|
||||||
|
continue // already merged above
|
||||||
|
}
|
||||||
|
keyCopy := *keyNode
|
||||||
|
valCopy := *srcVal
|
||||||
|
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeIdentityMap merges two identity-keyed mapping nodes (e.g. `models`,
|
||||||
|
// `groups`, `peers`). Any child key present in both sides is a duplicate
|
||||||
|
// entity and produces an error naming the conflicting key and source file.
|
||||||
|
// src-only keys are appended to dst.
|
||||||
|
func mergeIdentityMap(dst, src *yaml.Node, path, mapName, srcPath string) error {
|
||||||
|
if dst.Kind != yaml.MappingNode || src.Kind != yaml.MappingNode {
|
||||||
|
return fmt.Errorf("conflict at %q: expected a mapping, introduced by %s", path, srcPath)
|
||||||
|
}
|
||||||
|
dstIdx := indexMapping(dst)
|
||||||
|
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||||
|
keyNode := src.Content[i]
|
||||||
|
srcVal := src.Content[i+1]
|
||||||
|
key := keyNode.Value
|
||||||
|
if _, dup := dstIdx[key]; dup {
|
||||||
|
return fmt.Errorf("duplicate %s %q found in %s (already defined in another config source)", mapName, key, srcPath)
|
||||||
|
}
|
||||||
|
keyCopy := *keyNode
|
||||||
|
valCopy := *srcVal
|
||||||
|
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeValue merges srcVal into dstVal (both pointing into the parent's
|
||||||
|
// Content slice). Mapping+Mapping recurses; Sequence+Sequence concatenates;
|
||||||
|
// Scalar+Scalar errors on value mismatch; null on either side yields to the
|
||||||
|
// non-null side.
|
||||||
|
func mergeValue(dstVal, srcVal *yaml.Node, path, srcPath string) error {
|
||||||
|
switch {
|
||||||
|
case dstVal.Kind == yaml.MappingNode && srcVal.Kind == yaml.MappingNode:
|
||||||
|
return mergeNodes(dstVal, srcVal, path, srcPath)
|
||||||
|
|
||||||
|
case dstVal.Kind == yaml.SequenceNode && srcVal.Kind == yaml.SequenceNode:
|
||||||
|
dstVal.Content = append(dstVal.Content, srcVal.Content...)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case dstVal.Kind == yaml.ScalarNode && srcVal.Kind == yaml.ScalarNode:
|
||||||
|
if isNullScalar(dstVal) {
|
||||||
|
*dstVal = *srcVal
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if isNullScalar(srcVal) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if dstVal.Value == srcVal.Value && dstVal.Tag == srcVal.Tag {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("conflict at %q: %s sets a different value than a previous source", path, srcPath)
|
||||||
|
|
||||||
|
case isNull(dstVal):
|
||||||
|
*dstVal = *srcVal
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case isNull(srcVal):
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("conflict at %q: incompatible YAML node kinds (kind %d vs %d) introduced by %s", path, dstVal.Kind, srcVal.Kind, srcPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNull reports whether n represents a YAML null (empty or !!null).
|
||||||
|
func isNull(n *yaml.Node) bool {
|
||||||
|
if n == nil || n.Kind == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return isNullScalar(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNullScalar(n *yaml.Node) bool {
|
||||||
|
return n.Kind == yaml.ScalarNode && (n.Tag == "!!null" || n.Tag == "") && n.Value == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// indexMapping builds a key -> value-node index for a mapping node.
|
||||||
|
func indexMapping(n *yaml.Node) map[string]*yaml.Node {
|
||||||
|
idx := make(map[string]*yaml.Node, len(n.Content)/2)
|
||||||
|
for i := 0; i+1 < len(n.Content); i += 2 {
|
||||||
|
idx[n.Content[i].Value] = n.Content[i+1]
|
||||||
|
}
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinPath(parent, key string) string {
|
||||||
|
if parent == "" {
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
return parent + "." + key
|
||||||
|
}
|
||||||
@@ -0,0 +1,304 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeYAML writes content to a file named name inside dir. Returns the full
|
||||||
|
// path of the written file.
|
||||||
|
func writeYAML(t *testing.T, dir, name, content string) string {
|
||||||
|
t.Helper()
|
||||||
|
p := filepath.Join(dir, name)
|
||||||
|
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0o755))
|
||||||
|
require.NoError(t, os.WriteFile(p, []byte(content), 0o644))
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelCfg builds a single-model YAML snippet indented for nesting under a
|
||||||
|
// `models:` key. The proxy uses a fixed port so tests don't depend on
|
||||||
|
// ${PORT} allocation.
|
||||||
|
func modelCfg(id, cmd string) string {
|
||||||
|
return " " + id + ":\n cmd: " + cmd + "\n proxy: \"http://localhost:9999\"\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_NeitherProvided(t *testing.T) {
|
||||||
|
_, err := LoadConfigSources("", "")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "at least one of -config or -config-dir")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ConfigOnly(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "config.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("model1", "echo hi")+`
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
members: ["model1"]
|
||||||
|
`)
|
||||||
|
cfg, err := LoadConfigSources(cfgPath, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, id, ok := cfg.FindConfig("model1")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "model1", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DirOnly(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("alpha", "echo a"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("beta", "echo b"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, want := range []string{"alpha", "beta"} {
|
||||||
|
_, _, ok := cfg.FindConfig(want)
|
||||||
|
assert.True(t, ok, "model %s should be present", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ConfigPlusDirAdditive(t *testing.T) {
|
||||||
|
// -config lives outside -config-dir; both contribute models additively.
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "config.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||||
|
cfgDir := t.TempDir()
|
||||||
|
writeYAML(t, cfgDir, "extra.yaml", "models:\n"+modelCfg("ext", "echo ext"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, want := range []string{"base", "ext"} {
|
||||||
|
_, _, ok := cfg.FindConfig(want)
|
||||||
|
assert.True(t, ok, "model %s should be present after merge", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadConfigSources_ConfigInDirOverlap verifies that a -config file that
|
||||||
|
// is also a member of -config-dir is rejected.
|
||||||
|
func TestLoadConfigSources_ConfigInDirOverlap(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||||
|
|
||||||
|
_, err := LoadConfigSources(cfgPath, dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "is also present in -config-dir")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DuplicateModelID(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("dup", "echo a"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("dup", "echo b"))
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `duplicate models "dup"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DuplicateGroupID(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m1", "echo m1")+"groups:\n g1:\n members: [m1]\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m2", "echo m2")+"groups:\n g1:\n members: [m2]\n")
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `duplicate groups "g1"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DuplicatePeer(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
peerA := "peers:\n remote:\n proxy: http://x:1\n models: [m1]\n"
|
||||||
|
peerB := "peers:\n remote:\n proxy: http://x:2\n models: [m2]\n"
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\n"+peerA)
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\n"+peerB)
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `duplicate peers "remote"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ScalarConflict(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 200\n")
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `conflict at "globalTTL"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ScalarSameValueNoConflict(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 100\n")
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 100, cfg.GlobalTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_MacrosConcatenate(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "macros:\n LOW: 1\nmodels:\n"+modelCfg("m1", "echo ${LOW}"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "macros:\n HIGH: 2\nmodels:\n"+modelCfg("m2", "echo ${HIGH}"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Both macros are available globally after merge.
|
||||||
|
low, ok := cfg.Macros.Get("LOW")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, 1, low)
|
||||||
|
high, ok := cfg.Macros.Get("HIGH")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, 2, high)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_APIKeysConcatenate(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\napiKeys: [key-a]\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\napiKeys: [key-b]\n")
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.ElementsMatch(t, []string{"key-a", "key-b"}, cfg.RequiredAPIKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_RoutingGroupsMerge(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m1", "echo m1")+`
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
settings:
|
||||||
|
groups:
|
||||||
|
groupA:
|
||||||
|
members: [m1]
|
||||||
|
`)
|
||||||
|
writeYAML(t, dir, "b.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m2", "echo m2")+`
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
settings:
|
||||||
|
groups:
|
||||||
|
groupB:
|
||||||
|
members: [m2]
|
||||||
|
`)
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
groups := cfg.Routing.Router.Settings.Groups
|
||||||
|
assert.Contains(t, groups, "groupA")
|
||||||
|
assert.Contains(t, groups, "groupB")
|
||||||
|
// default group added by pipeline for orphaned/leftover routing groups...
|
||||||
|
// here both groups reference distinct models
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EnvMacrosSubstituted(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
// Use ${PORT} in cmd so the pipeline allocates a port and substitutes it;
|
||||||
|
// verifies env/macro substitution runs on the merged document.
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: serve --port ${PORT}\n proxy: \"http://localhost:${PORT}\"\n")
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m := cfg.Models["m1"]
|
||||||
|
assert.NotContains(t, m.Cmd, "${PORT}", "PORT macro should have been substituted")
|
||||||
|
assert.NotContains(t, m.Proxy, "${PORT}", "PORT macro should have been substituted in proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EnvMacroInFlowStyleList(t *testing.T) {
|
||||||
|
// Regression: flow-style lists with ${env.*} must parse. Previously
|
||||||
|
// parseSource unmarshalled before env substitution, so the brace in
|
||||||
|
// [${env.API_KEY}] was misread as a flow mapping and parsing failed.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: echo hi\n proxy: \"http://localhost:9999\"\n")
|
||||||
|
writeYAML(t, dir, "keys.yaml", "apiKeys: [${env.TEST_API_KEY}]\nmodels:\n m2:\n cmd: echo hi\n proxy: \"http://localhost:9998\"\n")
|
||||||
|
|
||||||
|
t.Setenv("TEST_API_KEY", "secret123")
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, cfg.RequiredAPIKeys, "secret123")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_SortedOrderDeterministic(t *testing.T) {
|
||||||
|
// Two files defining distinct models, scanned in z..a order by filename.
|
||||||
|
// Determine merged result is the same regardless of how the FS returns them.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "z.yaml", "models:\n"+modelCfg("zmodel", "echo z"))
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("amodel", "echo a"))
|
||||||
|
|
||||||
|
const runs = 3
|
||||||
|
for i := 0; i < runs; i++ {
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// startPort-based allocation: first allocated model gets 5800.
|
||||||
|
// Sorted order means amodel gets 5800, zmodel gets 5801.
|
||||||
|
_, _, ok := cfg.FindConfig("amodel")
|
||||||
|
assert.True(t, ok)
|
||||||
|
_, _, ok = cfg.FindConfig("zmodel")
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EmptyDirWithConfig(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgDir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, cfg.Models, "m1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EmptyDirOnly(t *testing.T) {
|
||||||
|
// An empty -config-dir with no -config is an error: there is nothing to
|
||||||
|
// load and silently producing an empty config would mask the misconfig.
|
||||||
|
cfgDir := t.TempDir()
|
||||||
|
_, err := LoadConfigSources("", cfgDir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no configuration sources found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_AssertNoUnknownMacrosAfterMerge(t *testing.T) {
|
||||||
|
// Macros defined in one file should not satisfy unknown-macro validation in
|
||||||
|
// another — they do, because merge concats global macros before validation
|
||||||
|
// runs. This test documents that a macro from file A is usable in file B.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "macros.yaml", "macros:\n SHARED: hello\nmodels:\n"+modelCfg("dummy", "echo dummy"))
|
||||||
|
writeYAML(t, dir, "use.yaml", "models:\n"+modelCfg("user", "echo ${SHARED}"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m := cfg.Models["user"]
|
||||||
|
assert.Contains(t, m.Cmd, "hello")
|
||||||
|
assert.NotContains(t, m.Cmd, "${SHARED}")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_KindMismatchErrors(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "startPort: 5800\nmodels:\n"+modelCfg("m1", "echo m1"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "startPort: [5800, 5801]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "incompatible YAML node kinds")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_NullYieldsToValue(t *testing.T) {
|
||||||
|
// File A: routing.router block absent (null on root for routing);
|
||||||
|
// file B: defines routing.router.settings.groups. Merge should keep B's.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "routing:\n router:\n settings:\n groups:\n g1:\n members: [m1]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||||
|
}
|
||||||
@@ -92,9 +92,14 @@ type Effects interface {
|
|||||||
StopProcesses(timeout time.Duration, ids []string)
|
StopProcesses(timeout time.Duration, ids []string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured
|
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured from
|
||||||
// from conf and bound to the given planner and effects. Currently only "fifo"
|
// conf and bound to the given planner and effects. Supported values are "fifo"
|
||||||
// (the default) is supported.
|
// (throughput-oriented, batches same-model requests) and "serial" (strict
|
||||||
|
// one-model-at-a-time, exact arrival order).
|
||||||
|
//
|
||||||
|
// The deployment default is applied by config loading (LoadConfig sets Use to
|
||||||
|
// "serial" when unset). The "" fallback here is the library default and remains
|
||||||
|
// "fifo" so callers that build a Config directly keep the original behavior.
|
||||||
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
|
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
|
||||||
use := conf.Routing.Scheduler.Use
|
use := conf.Routing.Scheduler.Use
|
||||||
if use == "" {
|
if use == "" {
|
||||||
@@ -103,6 +108,9 @@ func New(conf config.Config, name string, logger *logmon.Monitor, planner Swappe
|
|||||||
switch use {
|
switch use {
|
||||||
case "fifo":
|
case "fifo":
|
||||||
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
|
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
|
||||||
|
case "serial":
|
||||||
|
// Serial ignores the group planner: it always evicts every other model.
|
||||||
|
return NewSerial(name, logger, eff), nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
|
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,253 @@
|
|||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Serial is a strict one-model-at-a-time scheduler. Unlike FIFO it never reorders
|
||||||
|
// or batches: requests run in exact arrival order and at most one request runs at
|
||||||
|
// any instant. When the next request targets a model other than the one loaded,
|
||||||
|
// every other running model is evicted and the target is loaded before it runs,
|
||||||
|
// so a single model occupies memory at a time — at the cost of throughput.
|
||||||
|
//
|
||||||
|
// Example: A B C A is served as A B C A. The final A reloads its model even
|
||||||
|
// though it ran first, because B and C displaced it in between. (FIFO, by
|
||||||
|
// contrast, would batch the two A requests: A A B C.)
|
||||||
|
//
|
||||||
|
// Serial ignores group/eviction policy entirely: it always evicts every other
|
||||||
|
// running model, regardless of how groups are configured. That is what makes the
|
||||||
|
// single-model guarantee a property of the scheduler rather than of the config.
|
||||||
|
//
|
||||||
|
// Like FIFO, every method runs on the router's single run-loop goroutine, so no
|
||||||
|
// internal locking is needed.
|
||||||
|
type Serial struct {
|
||||||
|
name string
|
||||||
|
logger *logmon.Monitor
|
||||||
|
effects Effects
|
||||||
|
|
||||||
|
// queued holds requests in strict arrival order. It is never reordered.
|
||||||
|
queued []HandlerReq
|
||||||
|
|
||||||
|
// active is the one request currently being processed (loading or serving),
|
||||||
|
// or nil when idle. phase is meaningful only while active != nil.
|
||||||
|
active *HandlerReq
|
||||||
|
phase serialPhase
|
||||||
|
}
|
||||||
|
|
||||||
|
// serialPhase is the lifecycle stage of the active request.
|
||||||
|
type serialPhase int
|
||||||
|
|
||||||
|
const (
|
||||||
|
phaseIdle serialPhase = iota
|
||||||
|
phaseSwapping // waiting for OnSwapDone for active.Model
|
||||||
|
phaseServing // waiting for OnServeDone for active.Model
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewSerial builds a Serial scheduler. It takes no Swapper: eviction is always
|
||||||
|
// "stop every other running model", so the group planner is not consulted.
|
||||||
|
func NewSerial(name string, logger *logmon.Monitor, eff Effects) *Serial {
|
||||||
|
return &Serial{
|
||||||
|
name: name,
|
||||||
|
logger: logger,
|
||||||
|
effects: eff,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnRequest validates the model and appends the request to the tail of the queue,
|
||||||
|
// then tries to start the next job. Unknown models fail immediately.
|
||||||
|
func (s *Serial) OnRequest(req HandlerReq) {
|
||||||
|
if _, ok := s.effects.ModelState(req.Model); !ok {
|
||||||
|
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
|
||||||
|
s.effects.GrantError(req, ErrModelNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.queued = append(s.queued, req)
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
s.startNext()
|
||||||
|
}
|
||||||
|
|
||||||
|
// startNext begins processing the head of the queue when nothing is active. It
|
||||||
|
// fast-paths a request whose model is already the sole loaded-and-ready process;
|
||||||
|
// otherwise it launches a swap that evicts every other running model first. The
|
||||||
|
// loop skips over requests for models that vanished (e.g. a config reload) and
|
||||||
|
// requests whose caller disconnected before they could be served.
|
||||||
|
func (s *Serial) startNext() {
|
||||||
|
if s.active != nil {
|
||||||
|
return // a job is already loading or serving
|
||||||
|
}
|
||||||
|
for len(s.queued) > 0 {
|
||||||
|
req := s.queued[0]
|
||||||
|
s.queued = s.queued[1:]
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
|
||||||
|
state, ok := s.effects.ModelState(req.Model)
|
||||||
|
if !ok {
|
||||||
|
s.effects.GrantError(req, ErrModelNotFound)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
r := req
|
||||||
|
s.active = &r
|
||||||
|
|
||||||
|
evict := s.otherRunning(req.Model)
|
||||||
|
if state == process.StateReady && len(evict) == 0 {
|
||||||
|
// Already loaded and the only model running — serve immediately.
|
||||||
|
s.logger.Debugf("%s: serving model %s (already loaded)", s.name, req.Model)
|
||||||
|
if s.serve() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue // caller gone; pick the next request
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Debugf("%s: swapping to model %s, evicting %v", s.name, req.Model, evict)
|
||||||
|
s.phase = phaseSwapping
|
||||||
|
s.effects.StartSwap(req.Model, evict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// serve hands the active request its tracked handler. It returns true when the
|
||||||
|
// request is now serving (await OnServeDone); false when the caller had already
|
||||||
|
// disconnected, in which case active is cleared so the next job can start.
|
||||||
|
func (s *Serial) serve() bool {
|
||||||
|
if s.effects.GrantServe(*s.active, s.active.Model) {
|
||||||
|
s.phase = phaseServing
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
s.logger.Debugf("%s: caller for model %s gone before serve", s.name, s.active.Model)
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnSwapDone fires when the load for the active request completes. On success the
|
||||||
|
// request is served; on failure its caller receives the error and the queue
|
||||||
|
// advances. A SwapDone that does not match the active load (e.g. its request was
|
||||||
|
// unloaded or cancelled mid-load) is ignored.
|
||||||
|
func (s *Serial) OnSwapDone(ev SwapDone) {
|
||||||
|
if s.active == nil || s.phase != phaseSwapping || s.active.Model != ev.ModelID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ev.Err != nil {
|
||||||
|
s.logger.Debugf("%s: swap for model %s failed: %v", s.name, ev.ModelID, ev.Err)
|
||||||
|
s.effects.GrantError(*s.active, ev.Err)
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
s.startNext()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !s.serve() {
|
||||||
|
s.startNext() // caller vanished while the model loaded; move on
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnServeDone fires when the active request's handler returns. The slot is freed
|
||||||
|
// and the next queued request begins.
|
||||||
|
func (s *Serial) OnServeDone(ev ServeDoneEvent) {
|
||||||
|
if s.active == nil || s.phase != phaseServing {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
s.startNext()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnCancel removes a disconnected client's request from the queue. A request that
|
||||||
|
// is already active is left to finish: if it was loading, OnSwapDone's serve()
|
||||||
|
// will find the caller gone (GrantServe false) and advance; if it was serving,
|
||||||
|
// its handler returns normally and reaches OnServeDone.
|
||||||
|
func (s *Serial) OnCancel(req HandlerReq) {
|
||||||
|
if len(s.queued) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
kept := s.queued[:0]
|
||||||
|
removed := false
|
||||||
|
for _, q := range s.queued {
|
||||||
|
if q.Respond == req.Respond {
|
||||||
|
removed = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, q)
|
||||||
|
}
|
||||||
|
s.queued = kept
|
||||||
|
if removed {
|
||||||
|
s.logger.Debugf("%s: cancelled request for model %s pruned from queue", s.name, req.Model)
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnUnload reconciles state for an unload, stops the targeted processes, and
|
||||||
|
// advances the queue. It mirrors the FIFO contract: queued requests for unloaded
|
||||||
|
// models are failed; an active *loading* request for an unloaded model is failed
|
||||||
|
// (its swap goroutine is left to finish and its SwapDone is then ignored); an
|
||||||
|
// active *serving* request is left for its handler to end when StopProcesses
|
||||||
|
// kills the upstream. The Stop is synchronous so callers of Unload can rely on
|
||||||
|
// the processes being stopped on return.
|
||||||
|
func (s *Serial) OnUnload(targets []string, timeout time.Duration) {
|
||||||
|
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
|
||||||
|
|
||||||
|
targetSet := make(map[string]bool, len(targets))
|
||||||
|
for _, id := range targets {
|
||||||
|
targetSet[id] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.active != nil && s.phase == phaseSwapping && targetSet[s.active.Model] {
|
||||||
|
s.effects.GrantError(*s.active, unloadErr)
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.queued) > 0 {
|
||||||
|
kept := s.queued[:0]
|
||||||
|
for _, q := range s.queued {
|
||||||
|
if targetSet[q.Model] {
|
||||||
|
s.effects.GrantError(q, unloadErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, q)
|
||||||
|
}
|
||||||
|
s.queued = kept
|
||||||
|
broadcastQueuePositions(s.queued)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.effects.StopProcesses(timeout, targets)
|
||||||
|
|
||||||
|
// A still-serving active request advances via OnServeDone when its killed
|
||||||
|
// handler returns; only start the next job when nothing is active now.
|
||||||
|
if s.active == nil {
|
||||||
|
s.startNext()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnShutdown grants err to every request the scheduler still holds: an active
|
||||||
|
// loading request and all queued requests. A serving request is torn down with
|
||||||
|
// its process by the baseRouter.
|
||||||
|
func (s *Serial) OnShutdown(err error) {
|
||||||
|
if s.active != nil && s.phase == phaseSwapping {
|
||||||
|
s.effects.GrantError(*s.active, err)
|
||||||
|
s.active = nil
|
||||||
|
s.phase = phaseIdle
|
||||||
|
}
|
||||||
|
for _, q := range s.queued {
|
||||||
|
s.effects.GrantError(q, err)
|
||||||
|
}
|
||||||
|
s.queued = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// otherRunning returns every running model except target, sorted for
|
||||||
|
// deterministic eviction.
|
||||||
|
func (s *Serial) otherRunning(target string) []string {
|
||||||
|
var out []string
|
||||||
|
for id := range s.effects.RunningModels() {
|
||||||
|
if id != target {
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(out)
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -0,0 +1,391 @@
|
|||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||||
|
"github.com/mostlygeek/llama-swap/internal/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Serial methods all run on the router's single run-loop goroutine, so these
|
||||||
|
// tests drive them directly and synchronously, reusing fakeEffects and the
|
||||||
|
// req/reqCh helpers from fifo_test.go. A load completes via OnSwapDone and a
|
||||||
|
// served request finishes via OnServeDone — the events the run loop delivers.
|
||||||
|
|
||||||
|
func newSerial(eff Effects) *Serial {
|
||||||
|
return NewSerial("test", logmon.NewWriter(io.Discard), eff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lastStart returns the most recent StartSwap record.
|
||||||
|
func lastStart(t *testing.T, eff *fakeEffects) startRec {
|
||||||
|
t.Helper()
|
||||||
|
if len(eff.starts) == 0 {
|
||||||
|
t.Fatal("no StartSwap recorded")
|
||||||
|
}
|
||||||
|
return eff.starts[len(eff.starts)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func sameSet(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m := map[string]int{}
|
||||||
|
for _, x := range a {
|
||||||
|
m[x]++
|
||||||
|
}
|
||||||
|
for _, x := range b {
|
||||||
|
m[x]--
|
||||||
|
}
|
||||||
|
for _, v := range m {
|
||||||
|
if v != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// servedOrder returns the model IDs of every successful serve grant in order.
|
||||||
|
func servedOrder(eff *fakeEffects) []string {
|
||||||
|
var out []string
|
||||||
|
for _, g := range eff.grants {
|
||||||
|
if g.err == nil && g.serve {
|
||||||
|
out = append(out, g.model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_FastPath_AlreadyLoaded(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
|
||||||
|
if got := len(eff.starts); got != 0 {
|
||||||
|
t.Errorf("StartSwap calls=%d want 0 (already loaded, no swap)", got)
|
||||||
|
}
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Errorf("served(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_ColdStart_LoadsThenServes(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
if got := eff.startsFor("a"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
if got := eff.served("a"); got != 0 {
|
||||||
|
t.Errorf("served(a)=%d want 0 before load completes", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Errorf("served(a)=%d want 1 after load", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_UnknownModel(t *testing.T) {
|
||||||
|
eff := newFakeEffects() // no states => unknown
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("ghost"))
|
||||||
|
|
||||||
|
if len(eff.starts) != 0 {
|
||||||
|
t.Errorf("StartSwap calls=%d want 0", len(eff.starts))
|
||||||
|
}
|
||||||
|
if eff.errored("ghost") != 1 {
|
||||||
|
t.Fatalf("errored(ghost)=%d want 1", eff.errored("ghost"))
|
||||||
|
}
|
||||||
|
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
|
||||||
|
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_EvictsEveryOtherModel(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["x"] = process.StateReady // already running
|
||||||
|
eff.states["y"] = process.StateReady // also running (e.g. left over)
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
|
||||||
|
st := lastStart(t, eff)
|
||||||
|
if st.model != "a" {
|
||||||
|
t.Fatalf("loading %s want a", st.model)
|
||||||
|
}
|
||||||
|
if !sameSet(st.evict, []string{"x", "y"}) {
|
||||||
|
t.Errorf("evict=%v want [x y] (serial evicts ALL other models)", st.evict)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_OneJobAtATime verifies a second request waits while the first is
|
||||||
|
// serving, and only starts after the first finishes.
|
||||||
|
func TestSerial_OneJobAtATime(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // served immediately
|
||||||
|
s.OnRequest(req("b")) // must wait — a is serving
|
||||||
|
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 0 while a is serving", got)
|
||||||
|
}
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Fatalf("served(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// a finishes -> b may now load (evicting a).
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 after a finished", got)
|
||||||
|
}
|
||||||
|
if st := lastStart(t, eff); !sameSet(st.evict, []string{"a"}) {
|
||||||
|
t.Errorf("b evict=%v want [a]", st.evict)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_SameModelConsecutive_NoReload verifies back-to-back requests for the
|
||||||
|
// already-loaded model run without a reload, one after another.
|
||||||
|
func TestSerial_SameModelConsecutive_NoReload(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // cold load
|
||||||
|
s.OnRequest(req("a")) // queued behind the first
|
||||||
|
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"}) // first serves
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Fatalf("served(a)=%d want 1 (one at a time)", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // first done -> second serves
|
||||||
|
if got := eff.served("a"); got != 2 {
|
||||||
|
t.Fatalf("served(a)=%d want 2", got)
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("a"); got != 1 {
|
||||||
|
t.Errorf("StartSwap(a)=%d want 1 (second request must not reload)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_StrictArrivalOrder is the core guarantee: qwen36, qwen35, sdxl,
|
||||||
|
// qwen36 execute in EXACTLY that order with evictions between each model switch,
|
||||||
|
// including reloading qwen36 at the end even though it ran first.
|
||||||
|
func TestSerial_StrictArrivalOrder(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
for _, m := range []string{"qwen36", "qwen35", "sdxl"} {
|
||||||
|
eff.states[m] = process.StateStopped
|
||||||
|
}
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
for _, m := range []string{"qwen36", "qwen35", "sdxl", "qwen36"} {
|
||||||
|
s.OnRequest(req(m))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only the first job starts loading; the rest wait their turn.
|
||||||
|
if len(eff.starts) != 1 || eff.starts[0].model != "qwen36" {
|
||||||
|
t.Fatalf("starts=%+v want only [qwen36] loading first", eff.starts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// step completes the current model's load+serve and returns control to the
|
||||||
|
// scheduler, which must start the next queued model.
|
||||||
|
step := func(model string, wantEvict []string) {
|
||||||
|
t.Helper()
|
||||||
|
st := lastStart(t, eff)
|
||||||
|
if st.model != model {
|
||||||
|
t.Fatalf("loading %q want %q", st.model, model)
|
||||||
|
}
|
||||||
|
if !sameSet(st.evict, wantEvict) {
|
||||||
|
t.Fatalf("loading %q evict=%v want %v", model, st.evict, wantEvict)
|
||||||
|
}
|
||||||
|
// Simulate the eviction + load actually happening.
|
||||||
|
for _, e := range st.evict {
|
||||||
|
eff.states[e] = process.StateStopped
|
||||||
|
}
|
||||||
|
eff.states[model] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: model})
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: model})
|
||||||
|
}
|
||||||
|
|
||||||
|
step("qwen36", nil) // cold load, nothing else running
|
||||||
|
step("qwen35", []string{"qwen36"}) // evict qwen36
|
||||||
|
step("sdxl", []string{"qwen35"}) // evict qwen35
|
||||||
|
step("qwen36", []string{"sdxl"}) // RELOAD qwen36, evict sdxl
|
||||||
|
|
||||||
|
want := []string{"qwen36", "qwen35", "sdxl", "qwen36"}
|
||||||
|
if got := servedOrder(eff); !sameOrder(got, want) {
|
||||||
|
t.Fatalf("serve order=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sameOrder(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_SwapError_FailsCallerAndAdvances(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("b")) // queued behind a
|
||||||
|
|
||||||
|
// a's load fails: its caller is errored and b proceeds.
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
|
||||||
|
if eff.errored("a") != 1 {
|
||||||
|
t.Fatalf("errored(a)=%d want 1", eff.errored("a"))
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 after a's load failed", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_GrantServeFalse_Advances verifies that when the active request's
|
||||||
|
// caller has disconnected by serve time, the queue advances to the next request.
|
||||||
|
func TestSerial_GrantServeFalse_Advances(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
eff.serveResult["a"] = false // a's caller is gone by grant time
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a"))
|
||||||
|
s.OnRequest(req("b")) // queued
|
||||||
|
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails -> advance to b
|
||||||
|
|
||||||
|
if got := eff.served("a"); got != 0 {
|
||||||
|
t.Errorf("served(a)=%d want 0 (caller gone)", got)
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 (advanced after gone caller)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_OnCancel_QueuedRequest(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(reqCh("a")) // starts loading a
|
||||||
|
cancelled := reqCh("b")
|
||||||
|
s.OnRequest(cancelled) // queued behind a
|
||||||
|
if len(s.queued) != 1 {
|
||||||
|
t.Fatalf("queued=%d want 1", len(s.queued))
|
||||||
|
}
|
||||||
|
|
||||||
|
s.OnCancel(cancelled)
|
||||||
|
if len(s.queued) != 0 {
|
||||||
|
t.Fatalf("queued=%d want 0 after cancel", len(s.queued))
|
||||||
|
}
|
||||||
|
|
||||||
|
// a completes; b is gone, so nothing starts for it.
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Errorf("StartSwap(b)=%d want 0 (cancelled before its turn)", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_OnShutdown_FailsQueuedAndActiveLoad(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
eff.states["c"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // active (loading)
|
||||||
|
s.OnRequest(req("b")) // queued
|
||||||
|
s.OnRequest(req("c")) // queued
|
||||||
|
|
||||||
|
s.OnShutdown(errors.New("shutting down"))
|
||||||
|
|
||||||
|
if got := eff.errored(""); got != 3 {
|
||||||
|
t.Errorf("error grants=%d want 3 (active load + 2 queued)", got)
|
||||||
|
}
|
||||||
|
if len(s.queued) != 0 {
|
||||||
|
t.Errorf("queued=%d want 0 after shutdown", len(s.queued))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSerial_OnUnload_WhileServing verifies that unloading the model that is
|
||||||
|
// actively serving does not strand the queue: OnUnload stops the process but
|
||||||
|
// leaves the active request to end via OnServeDone, which then advances.
|
||||||
|
func TestSerial_OnUnload_WhileServing(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateReady
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // served immediately (a ready)
|
||||||
|
s.OnRequest(req("b")) // queued behind a
|
||||||
|
if got := eff.served("a"); got != 1 {
|
||||||
|
t.Fatalf("served(a)=%d want 1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unload a while it is serving: the process is stopped, but the queue must
|
||||||
|
// not advance yet — the active serve is still outstanding.
|
||||||
|
s.OnUnload([]string{"a"}, time.Second)
|
||||||
|
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||||
|
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||||
|
}
|
||||||
|
if got := eff.startsFor("b"); got != 0 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 0 before the serving request ends", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The killed handler returns -> OnServeDone advances to b.
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Fatalf("StartSwap(b)=%d want 1 after the serving request ended", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerial_OnUnload_DropsQueuedAndStops(t *testing.T) {
|
||||||
|
eff := newFakeEffects()
|
||||||
|
eff.states["a"] = process.StateStopped
|
||||||
|
eff.states["b"] = process.StateStopped
|
||||||
|
s := newSerial(eff)
|
||||||
|
|
||||||
|
s.OnRequest(req("a")) // active (loading a)
|
||||||
|
s.OnRequest(req("b")) // queued
|
||||||
|
|
||||||
|
// Unload a: its active load is failed and a is stopped.
|
||||||
|
s.OnUnload([]string{"a"}, time.Second)
|
||||||
|
|
||||||
|
if eff.errored("a") != 1 {
|
||||||
|
t.Errorf("errored(a)=%d want 1 (active load failed)", eff.errored("a"))
|
||||||
|
}
|
||||||
|
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||||
|
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||||
|
}
|
||||||
|
// b was queued and not unloaded; with a's load cancelled it now starts.
|
||||||
|
if got := eff.startsFor("b"); got != 1 {
|
||||||
|
t.Errorf("StartSwap(b)=%d want 1 after unload advanced the queue", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DirWatcher polls a directory for changes to its set of *.yml / *.yaml files.
|
||||||
|
// It fires OnChange when a file is added, removed, or has its mod time/size
|
||||||
|
// change. Like Watcher it is poll-based so it works in Docker bind-mounts and
|
||||||
|
// k8s ConfigMap projections where inotify is unreliable.
|
||||||
|
//
|
||||||
|
// The baseline poll establishes initial state and does not fire OnChange.
|
||||||
|
type DirWatcher struct {
|
||||||
|
Path string
|
||||||
|
Interval time.Duration
|
||||||
|
OnChange func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// dirSnapshot is an ordered map of file name -> file state. The ordering is
|
||||||
|
// derived from sorted filenames so two snapshots compare deterministically
|
||||||
|
// regardless of readdir order. exists reflects whether the directory was
|
||||||
|
// readable at scan time; a missing directory yields exists=false.
|
||||||
|
type dirSnapshot struct {
|
||||||
|
exists bool
|
||||||
|
names []string
|
||||||
|
states map[string]snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDirSnapshot() dirSnapshot {
|
||||||
|
return dirSnapshot{states: make(map[string]snapshot)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal reports whether two snapshots describe the same file set and per-file
|
||||||
|
// state. A missing directory (exists=false) is treated as equal to any other
|
||||||
|
// missing directory regardless of cached names.
|
||||||
|
func (s dirSnapshot) equal(other dirSnapshot) bool {
|
||||||
|
if !s.exists && !other.exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if s.exists != other.exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(s.names) != len(other.names) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, n := range s.names {
|
||||||
|
if other.names[i] != n {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, n := range s.names {
|
||||||
|
a, b := s.states[n], other.states[n]
|
||||||
|
if a.exists != b.exists || a.size != b.size || !a.modTime.Equal(b.modTime) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||||
|
// OnChange whenever the directory's YAML file set changes.
|
||||||
|
//
|
||||||
|
// Policy mirrors the single-file Watcher: disappearance (directory missing or
|
||||||
|
// empty) is treated as a transient rename-style write and stays quiet; the
|
||||||
|
// transition back to present-with-content fires OnChange.
|
||||||
|
func (w *DirWatcher) Run(ctx context.Context) {
|
||||||
|
interval := w.Interval
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = DefaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
prev := scanDir(w.Path)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
cur := scanDir(w.Path)
|
||||||
|
// Suppress transitions involving an empty or missing directory —
|
||||||
|
// these are treated as transient rename-style writes, mirroring
|
||||||
|
// the single-file Watcher. Only present-with-content →
|
||||||
|
// present-with-content (changed) or no-content →
|
||||||
|
// present-with-content fires OnChange.
|
||||||
|
prevHasContent := prev.exists && len(prev.names) > 0
|
||||||
|
curHasContent := cur.exists && len(cur.names) > 0
|
||||||
|
if curHasContent && (!prevHasContent || !prev.equal(cur)) && w.OnChange != nil {
|
||||||
|
w.OnChange()
|
||||||
|
}
|
||||||
|
prev = cur
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanDir returns a snapshot of the *.yml/*.yaml files in dir. If the
|
||||||
|
// directory cannot be read (missing, permission denied) the snapshot reports
|
||||||
|
// exists=false; the next successful scan will detect the recovery and fire
|
||||||
|
// OnChange.
|
||||||
|
func scanDir(dir string) dirSnapshot {
|
||||||
|
snap := newDirSnapshot()
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return snap // exists=false
|
||||||
|
}
|
||||||
|
snap.exists = true
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := e.Name()
|
||||||
|
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fi, err := os.Stat(filepath.Join(dir, name))
|
||||||
|
if err != nil {
|
||||||
|
// File disappeared between ReadDir and Stat; skip it — the
|
||||||
|
// next poll will observe the removal cleanly.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snap.names = append(snap.names, name)
|
||||||
|
snap.states[name] = snapshot{
|
||||||
|
exists: true,
|
||||||
|
modTime: fi.ModTime(),
|
||||||
|
size: fi.Size(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(snap.names)
|
||||||
|
return snap
|
||||||
|
}
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startDirWatcher launches w.Run in a goroutine and returns a function that
|
||||||
|
// cancels the context and waits for Run to return.
|
||||||
|
func startDirWatcher(t *testing.T, w *DirWatcher) func() {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.Run(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("DirWatcher did not stop within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeYAMLInDir(t *testing.T, dir, name, content string) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_NoFireOnBaseline(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 5)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_DetectsFileAdd(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is added")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_DetectsFileRemoval(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.Remove(filepath.Join(dir, "b.yaml")))
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_DetectsModTimeChange(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||||
|
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base, base))
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_IgnoresNonYAMLFiles(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Adding a .txt file must not fire.
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0o644))
|
||||||
|
time.Sleep(testInterval * 4)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "non-YAML files must be ignored")
|
||||||
|
|
||||||
|
// Adding a .yml file must fire.
|
||||||
|
writeYAMLInDir(t, dir, "b.yml", "b")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire for *.yml files")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_MissingDirRecovers(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Remove the directory. No fire expected on disappearance alone.
|
||||||
|
require.NoError(t, os.RemoveAll(dir))
|
||||||
|
time.Sleep(testInterval * 3)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "directory removal alone must not fire")
|
||||||
|
|
||||||
|
// Recreate the directory and a YAML file; the recovery should fire.
|
||||||
|
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||||
|
writeYAMLInDir(t, dir, "recovered.yaml", "r")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when dir returns with content")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_EmptyDirSuppressedThenRecovers(t *testing.T) {
|
||||||
|
// Present-with-content → empty (all YAML removed, dir still exists)
|
||||||
|
// must stay quiet — treated as transient per the documented policy.
|
||||||
|
// The transition back to content fires.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Remove the only YAML file. Dir still exists but is empty of YAML.
|
||||||
|
require.NoError(t, os.Remove(filepath.Join(dir, "a.yaml")))
|
||||||
|
time.Sleep(testInterval * 4)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "emptying the directory must not fire")
|
||||||
|
|
||||||
|
// Add a YAML file back; transition to present-with-content fires.
|
||||||
|
writeYAMLInDir(t, dir, "c.yaml", "c")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when content returns")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
w := &DirWatcher{Path: dir, Interval: testInterval}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() { w.Run(ctx); close(done) }()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Run did not return within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
+37
-19
@@ -55,7 +55,8 @@ var logTimeFormats = map[string]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flagConfig := flag.String("config", "", "path to config file (required)")
|
flagConfig := flag.String("config", "", "path to config file")
|
||||||
|
flagConfigDir := flag.String("config-dir", "", "directory of *.yml/*.yaml config files (additive to -config)")
|
||||||
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
|
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
|
||||||
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||||
flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
|
flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||||
@@ -68,8 +69,8 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
if *flagConfig == "" {
|
if *flagConfig == "" && *flagConfigDir == "" {
|
||||||
slog.Error("-config is required")
|
slog.Error("at least one of -config or -config-dir must be provided")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,10 +89,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
configPath := *flagConfig
|
cfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||||
cfg, err := config.LoadConfig(configPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to load config", "path", configPath, "error", err)
|
slog.Error("failed to load config", "config", *flagConfig, "config-dir", *flagConfigDir, "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,7 +187,7 @@ func main() {
|
|||||||
|
|
||||||
proxyLog.Info("reloading configuration")
|
proxyLog.Info("reloading configuration")
|
||||||
|
|
||||||
newCfg, err := config.LoadConfig(configPath)
|
newCfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
proxyLog.Warnf("failed to reload config: %v", err)
|
proxyLog.Warnf("failed to reload config: %v", err)
|
||||||
return
|
return
|
||||||
@@ -230,19 +230,37 @@ func main() {
|
|||||||
defer watcherCancel()
|
defer watcherCancel()
|
||||||
|
|
||||||
if *flagWatchConfig {
|
if *flagWatchConfig {
|
||||||
absConfigPath, err := filepath.Abs(configPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("watch-config: failed to resolve config path", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
||||||
go func() {
|
|
||||||
(&configwatcher.Watcher{
|
if *flagConfig != "" {
|
||||||
Path: absConfigPath,
|
absConfigPath, err := filepath.Abs(*flagConfig)
|
||||||
Interval: configwatcher.DefaultInterval,
|
if err != nil {
|
||||||
OnChange: reload,
|
slog.Error("watch-config: failed to resolve config path", "error", err)
|
||||||
}).Run(watcherCtx)
|
os.Exit(1)
|
||||||
}()
|
}
|
||||||
|
go func() {
|
||||||
|
(&configwatcher.Watcher{
|
||||||
|
Path: absConfigPath,
|
||||||
|
Interval: configwatcher.DefaultInterval,
|
||||||
|
OnChange: reload,
|
||||||
|
}).Run(watcherCtx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if *flagConfigDir != "" {
|
||||||
|
absConfigDir, err := filepath.Abs(*flagConfigDir)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("watch-config: failed to resolve config-dir path", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
(&configwatcher.DirWatcher{
|
||||||
|
Path: absConfigDir,
|
||||||
|
Interval: configwatcher.DefaultInterval,
|
||||||
|
OnChange: reload,
|
||||||
|
}).Run(watcherCtx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user