Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0292c90ca1 | |||
| 617c7dc6b9 | |||
| 542b79dacf | |||
| 0a25b3bd31 | |||
| 32bc781326 |
@@ -0,0 +1,76 @@
|
||||
name: Build CUDA image (fork)
|
||||
|
||||
# Builds this fork's llama-swap (serial scheduler + embedded UI) from source and
|
||||
# layers it on a pinned llama.cpp CUDA server base, then pushes to the Gitea
|
||||
# container registry, e.g. gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
|
||||
#
|
||||
# Requires repo secrets: REGISTRY_USER, REGISTRY_PASSWORD (push to the registry).
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
llama_swap_version:
|
||||
description: "llama-swap version label (image tag prefix)"
|
||||
required: false
|
||||
default: "v230"
|
||||
llamacpp_build:
|
||||
description: "llama.cpp CUDA server build (base image tag suffix)"
|
||||
required: false
|
||||
default: "b9821"
|
||||
# Building the build definition itself kicks off a fresh image.
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- ".gitea/workflows/build-cuda-image.yml"
|
||||
- "docker/fork-cuda.Containerfile"
|
||||
|
||||
env:
|
||||
REGISTRY: gitea.stevedudenhoeffer.com
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Compute image metadata
|
||||
id: meta
|
||||
run: |
|
||||
LS_VER="${{ inputs.llama_swap_version || 'v230' }}"
|
||||
LCPP="${{ inputs.llamacpp_build || 'b9821' }}"
|
||||
{
|
||||
echo "image=${REGISTRY}/${{ github.repository }}"
|
||||
echo "tag=${LS_VER}-cuda-${LCPP}"
|
||||
echo "base_tag=server-cuda-${LCPP}"
|
||||
echo "ls_version=${LS_VER}"
|
||||
echo "build_date=$(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Gitea registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ secrets.REGISTRY_USER }}
|
||||
password: ${{ secrets.REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: docker/fork-cuda.Containerfile
|
||||
push: true
|
||||
provenance: false
|
||||
build-args: |
|
||||
BASE_TAG=${{ steps.meta.outputs.base_tag }}
|
||||
LS_VERSION=${{ steps.meta.outputs.ls_version }}
|
||||
GIT_HASH=${{ github.sha }}
|
||||
BUILD_DATE=${{ steps.meta.outputs.build_date }}
|
||||
tags: ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}
|
||||
|
||||
- name: Summary
|
||||
run: |
|
||||
echo "Pushed ${{ steps.meta.outputs.image }}:${{ steps.meta.outputs.tag }}" >> "$GITHUB_STEP_SUMMARY"
|
||||
@@ -5,16 +5,14 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
||||
## Tech stack
|
||||
|
||||
- golang
|
||||
- typescript, vite and svelt5 for UI (located in ui/)
|
||||
- typescript, vite and svelte 5 for UI (located in ui-svelte/)
|
||||
|
||||
## Workflow Tasks
|
||||
|
||||
- when summarizing changes only include details that require further action
|
||||
- just say "Done." when there is no further action
|
||||
- use the github CLI `gh` to create pull requests and work with github
|
||||
- Rules for creating pull requests:
|
||||
- keep them short and focused on changes.
|
||||
- never include a test plan
|
||||
- keep them short and focused on changes
|
||||
- skip the test plan
|
||||
- write the summary using the same style rules as commit message
|
||||
|
||||
## Testing
|
||||
@@ -30,7 +28,7 @@ llama-swap is a light weight, transparent proxy server that provides automatic m
|
||||
### Commit message example format:
|
||||
|
||||
```
|
||||
proxy: add new feature
|
||||
internal/server: add new feature
|
||||
|
||||
Add new feature that implements functionality X and Y.
|
||||
|
||||
|
||||
+3
-2
@@ -601,10 +601,11 @@
|
||||
"use": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"serial",
|
||||
"fifo"
|
||||
],
|
||||
"default": "fifo",
|
||||
"description": "Scheduler to use. Only 'fifo' is currently supported."
|
||||
"default": "serial",
|
||||
"description": "Scheduler to use. 'serial' (default on this fork): strict one-model-at-a-time, requests run in exact arrival order, switching models evicts every other model first. 'fifo': throughput-oriented, batches same-model requests and allows parallel/co-resident models."
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
|
||||
+13
-3
@@ -556,11 +556,21 @@ routing:
|
||||
# expands to: [L]
|
||||
full: "L"
|
||||
|
||||
# scheduler: how queued requests are ordered.
|
||||
# The default and only valid scheduler is "fifo"
|
||||
# scheduler: how queued requests are ordered and run.
|
||||
# - optional, default on this fork: "serial"
|
||||
# - valid values:
|
||||
# - "serial": strict one-model-at-a-time. Requests run in exact arrival
|
||||
# order; only one request runs at a time; switching to a different model
|
||||
# evicts every other running model first so a single model occupies memory
|
||||
# at a time. This ignores group/matrix co-residency entirely. The "fifo"
|
||||
# settings below (priority) do not apply.
|
||||
# - "fifo": throughput-oriented. Same-model requests are batched to reduce
|
||||
# swaps and a model serves up to its concurrencyLimit in parallel; models
|
||||
# in non-exclusive groups can run concurrently. Requests may be reordered.
|
||||
scheduler:
|
||||
use: fifo
|
||||
use: serial
|
||||
settings:
|
||||
# fifo settings only apply when use: fifo
|
||||
fifo:
|
||||
# priority: a dictionary of model ID -> priority
|
||||
# - optional, default: empty dictionary
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
# Build a CUDA llama-swap image FROM THIS FORK's source (includes the serial
|
||||
# scheduler) and layer it on a pinned llama.cpp CUDA server base. Produces e.g.:
|
||||
# gitea.stevedudenhoeffer.com/steve/llama-swap:v230-cuda-b9821
|
||||
#
|
||||
# BASE_TAG selects the llama.cpp CUDA runtime + llama-server build, e.g.
|
||||
# "server-cuda-b9821". The llama-swap binary (with the embedded Svelte UI) is
|
||||
# compiled from the repo at build time, so no GitHub release is required.
|
||||
#
|
||||
# Build context is the repo root:
|
||||
# docker build -f docker/fork-cuda.Containerfile \
|
||||
# --build-arg BASE_TAG=server-cuda-b9821 -t llama-swap:v230-cuda-b9821 .
|
||||
|
||||
ARG BASE_IMAGE=ghcr.io/ggml-org/llama.cpp
|
||||
ARG BASE_TAG=server-cuda-b9821
|
||||
|
||||
# ---- Stage 1: build the Svelte UI (embedded into the binary) ----
|
||||
FROM node:22-bookworm-slim AS ui
|
||||
WORKDIR /src/ui-svelte
|
||||
# Install deps first for layer caching. .npmrc carries legacy-peer-deps=true,
|
||||
# which the project relies on (tailwind/vite peer ranges), so copy it before
|
||||
# npm ci or the strict resolver fails with ERESOLVE.
|
||||
COPY ui-svelte/package.json ui-svelte/package-lock.json ui-svelte/.npmrc ./
|
||||
RUN npm ci
|
||||
COPY ui-svelte/ ./
|
||||
# `npm run build` is `vite build --emptyOutDir`; vite.config.ts writes to
|
||||
# ../internal/server/ui_dist, which //go:embed picks up in the next stage.
|
||||
RUN mkdir -p /src/internal/server && npm run build
|
||||
|
||||
# ---- Stage 2: build the llama-swap binary with the embedded UI ----
|
||||
FROM golang:1.26-bookworm AS build
|
||||
WORKDIR /src
|
||||
# Cache modules independently of source churn.
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
# Overlay the freshly built UI so //go:embed ui_dist ships the real assets
|
||||
# instead of the committed placeholder.
|
||||
COPY --from=ui /src/internal/server/ui_dist/ ./internal/server/ui_dist/
|
||||
ARG LS_VERSION=v230
|
||||
ARG GIT_HASH=unknown
|
||||
ARG BUILD_DATE=unknown
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags="-X main.version=${LS_VERSION} -X main.commit=${GIT_HASH} -X main.date=${BUILD_DATE}" \
|
||||
-o /out/llama-swap .
|
||||
|
||||
# ---- Stage 3: runtime image on the pinned llama.cpp CUDA base ----
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG}
|
||||
|
||||
# Run as root by default to match the upstream `vNNN-cuda-bNNNN` (non-suffixed)
|
||||
# image that ragnaros pulls today: it needs root to reach the mounted docker
|
||||
# socket for container-backed models (sd-server). Override UID/GID at build time
|
||||
# for a non-root variant.
|
||||
ARG UID=0
|
||||
ARG GID=0
|
||||
ARG USER_HOME=/root
|
||||
ENV HOME=$USER_HOME
|
||||
|
||||
RUN set -eux; \
|
||||
if [ "$UID" -ne 0 ]; then \
|
||||
if [ "$GID" -ne 0 ]; then groupadd --system --gid "$GID" app; fi; \
|
||||
useradd --system --uid "$UID" --gid "$GID" --home "$USER_HOME" app; \
|
||||
fi; \
|
||||
mkdir --parents "$HOME" /app; \
|
||||
chown --recursive "$UID:$GID" "$HOME" /app
|
||||
|
||||
COPY --from=build --chown=$UID:$GID /out/llama-swap /app/llama-swap
|
||||
COPY --chown=$UID:$GID docker/config.example.yaml /app/config.yaml
|
||||
|
||||
USER $UID:$GID
|
||||
WORKDIR /app
|
||||
ENV PATH="/app:${PATH}"
|
||||
|
||||
HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1
|
||||
ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ]
|
||||
@@ -0,0 +1,57 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
)
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
@@ -2,16 +2,9 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/billziss-gh/golib/shlex"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -85,12 +78,6 @@ type GroupConfig struct {
|
||||
Members []string `yaml:"members"`
|
||||
}
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// set default values for GroupConfig
|
||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
type rawGroupConfig GroupConfig
|
||||
@@ -224,430 +211,6 @@ func LoadConfig(path string) (Config, error) {
|
||||
return LoadConfigFromReader(file)
|
||||
}
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Apply defaults for performance config when section is missing
|
||||
if config.Performance.Every == 0 {
|
||||
config.Performance.Every = 5 * time.Second
|
||||
}
|
||||
if err = config.Performance.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("performance: %w", err)
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
// Apply default for upstream.ignorePaths when not specified. The default
|
||||
// matches common static-asset suffixes so they do not trigger a swap.
|
||||
if len(config.Upstream.IgnorePaths) == 0 {
|
||||
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||
// the new `routing.router` block are mutually exclusive: a config may use
|
||||
// either style, never both.
|
||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||
rtr := config.Routing.Router
|
||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||
|
||||
if hasTopLevel && hasRouting {
|
||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||
}
|
||||
|
||||
if !hasTopLevel {
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active, so there is no conflict.
|
||||
rs := config.Routing.Router.Settings
|
||||
switch config.Routing.Router.Use {
|
||||
case "matrix":
|
||||
if rs.Matrix == nil {
|
||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||
}
|
||||
config.Matrix = rs.Matrix
|
||||
case "group", "":
|
||||
config.Groups = rs.Groups
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||
}
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
}
|
||||
|
||||
if config.Matrix != nil {
|
||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.Matrix.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||
// and new-style configs converge here. The Matrix pointer is shared so
|
||||
// ExpandedSets stays in one place.
|
||||
if config.Matrix != nil {
|
||||
config.Routing.Router.Use = "matrix"
|
||||
} else {
|
||||
config.Routing.Router.Use = "group"
|
||||
}
|
||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||
config.Routing.Router.Settings.Groups = config.Groups
|
||||
|
||||
if config.Routing.Scheduler.Use == "" {
|
||||
config.Routing.Scheduler.Use = "fifo"
|
||||
}
|
||||
if config.Routing.Scheduler.Use != "fifo" {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
|
||||
}
|
||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||
if _, found := config.RealModelName(modelID); !found {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// rewrites the yaml to include a default group with any orphaned models
|
||||
func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
@@ -692,233 +255,3 @@ func AddDefaultGroupToConfig(config Config) Config {
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
// Handle trailing backslashes by replacing with space
|
||||
if strings.HasSuffix(trimmed, "\\") {
|
||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||
} else {
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
// put it back together
|
||||
cmdStr = strings.Join(cleanedLines, "\n")
|
||||
|
||||
// Split the command into arguments
|
||||
var args []string
|
||||
if runtime.GOOS == "windows" {
|
||||
args = shlex.Windows.Split(cmdStr)
|
||||
} else {
|
||||
args = shlex.Posix.Split(cmdStr)
|
||||
}
|
||||
|
||||
// Ensure the command is not empty
|
||||
if len(args) == 0 {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func StripComments(cmdStr string) string {
|
||||
var cleanedLines []string
|
||||
for _, line := range strings.Split(cmdStr, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip comment lines
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
cleanedLines = append(cleanedLines, line)
|
||||
}
|
||||
return strings.Join(cleanedLines, "\n")
|
||||
}
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -277,7 +277,7 @@ groups:
|
||||
},
|
||||
},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "fifo",
|
||||
Use: "serial",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -777,22 +777,27 @@ func TestConfig_APIKeys_Invalid(t *testing.T) {
|
||||
{
|
||||
name: "blank spaces only",
|
||||
content: `apiKeys: [" "]`,
|
||||
expectedErr: "api key cannot contain spaces: ` `",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains leading space",
|
||||
content: `apiKeys: [" key123"]`,
|
||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains trailing space",
|
||||
content: `apiKeys: ["key123 "]`,
|
||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "contains middle space",
|
||||
content: `apiKeys: ["key 123"]`,
|
||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
||||
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "space in second key reports correct index",
|
||||
content: `apiKeys: ["valid-key", "bad key"]`,
|
||||
expectedErr: "apiKeys[1]: api key cannot contain spaces",
|
||||
},
|
||||
{
|
||||
name: "empty in list with valid keys",
|
||||
@@ -1567,7 +1572,7 @@ groups:
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
// default group injected for orphaned models (none here) still leaves g1
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyTopLevelMatrix(t *testing.T) {
|
||||
@@ -1626,7 +1631,7 @@ func TestConfig_Routing_DefaultsToGroup(t *testing.T) {
|
||||
cfg, err := LoadConfigFromReader(strings.NewReader(twoModels))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "group", cfg.Routing.Router.Use)
|
||||
assert.Equal(t, "fifo", cfg.Routing.Scheduler.Use)
|
||||
assert.Equal(t, "serial", cfg.Routing.Scheduler.Use)
|
||||
}
|
||||
|
||||
func TestConfig_Routing_LegacyAndRoutingConflict(t *testing.T) {
|
||||
|
||||
@@ -266,7 +266,7 @@ groups:
|
||||
},
|
||||
},
|
||||
Scheduler: SchedulerConfig{
|
||||
Use: "fifo",
|
||||
Use: "serial",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,441 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
yamlStr := string(data)
|
||||
|
||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||
// This is safe because env values are simple strings without YAML formatting
|
||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
// Unmarshal into full Config with defaults
|
||||
config := Config{
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
LogTimeFormat: "",
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
if config.HealthCheckTimeout < 15 {
|
||||
config.HealthCheckTimeout = 15
|
||||
}
|
||||
|
||||
// Apply defaults for performance config when section is missing
|
||||
if config.Performance.Every == 0 {
|
||||
config.Performance.Every = 5 * time.Second
|
||||
}
|
||||
if err = config.Performance.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("performance: %w", err)
|
||||
}
|
||||
|
||||
if config.StartPort < 1 {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
// Apply default for upstream.ignorePaths when not specified. The default
|
||||
// matches common static-asset suffixes so they do not trigger a swap.
|
||||
if len(config.Upstream.IgnorePaths) == 0 {
|
||||
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||
}
|
||||
|
||||
// Populate the aliases map
|
||||
config.aliases = make(map[string]string)
|
||||
for modelName, modelConfig := range config.Models {
|
||||
for _, alias := range modelConfig.Aliases {
|
||||
if _, found := config.aliases[alias]; found {
|
||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||
}
|
||||
config.aliases[alias] = modelName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate global macros
|
||||
for _, macro := range config.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get and sort all model IDs for consistent port assignment
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
modelIds = append(modelIds, modelId)
|
||||
}
|
||||
sort.Strings(modelIds)
|
||||
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||
|
||||
// Strip comments from command fields
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||
mergedMacros = append(mergedMacros, config.Macros...)
|
||||
|
||||
// Add model macros (override globals with same name)
|
||||
for _, entry := range modelConfig.Macros {
|
||||
found := false
|
||||
for i, existing := range mergedMacros {
|
||||
if existing.Name == entry.Name {
|
||||
mergedMacros[i] = entry
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
mergedMacros = append(mergedMacros, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute remaining macros in model fields (LIFO order)
|
||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||
entry := mergedMacros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
// Substitute macros in SetParamsByID keys and values
|
||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||
}
|
||||
newParamMap, ok := newValAny.(map[string]any)
|
||||
if !ok {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||
}
|
||||
newSetParamsByID[newKey] = newParamMap
|
||||
}
|
||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||
}
|
||||
|
||||
// Substitute in metadata (type-preserving)
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle PORT macro - only allocate if cmd uses it
|
||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||
if cmdHasPort || proxyHasPort {
|
||||
if !cmdHasPort && proxyHasPort {
|
||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||
}
|
||||
|
||||
macroSlug := "${PORT}"
|
||||
macroStr := fmt.Sprintf("%v", nextPort)
|
||||
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||
}
|
||||
modelConfig.Metadata = result.(map[string]any)
|
||||
}
|
||||
|
||||
nextPort++
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
fieldMap := map[string]string{
|
||||
"cmd": modelConfig.Cmd,
|
||||
"cmdStop": modelConfig.CmdStop,
|
||||
"proxy": modelConfig.Proxy,
|
||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||
"name": modelConfig.Name,
|
||||
"description": modelConfig.Description,
|
||||
}
|
||||
|
||||
for fieldName, fieldValue := range fieldMap {
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
if macroName == "PID" && fieldName == "cmdStop" {
|
||||
continue // replaced at runtime
|
||||
}
|
||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modelConfig.Metadata) > 0 {
|
||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||
}
|
||||
|
||||
// Validate SetParamsByID keys and values
|
||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||
}
|
||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||
for key := range modelConfig.Filters.SetParamsByID {
|
||||
if key == modelId {
|
||||
continue
|
||||
}
|
||||
if _, exists := config.Models[key]; exists {
|
||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||
}
|
||||
if existingModel, exists := config.aliases[key]; exists {
|
||||
if existingModel != modelId {
|
||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||
}
|
||||
continue // already registered as explicit alias for this model
|
||||
}
|
||||
config.aliases[key] = modelId
|
||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||
}
|
||||
|
||||
if modelConfig.SendLoadingState == nil {
|
||||
v := config.SendLoadingState
|
||||
modelConfig.SendLoadingState = &v
|
||||
}
|
||||
|
||||
config.Models[modelId] = modelConfig
|
||||
}
|
||||
|
||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||
// the new `routing.router` block are mutually exclusive: a config may use
|
||||
// either style, never both.
|
||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||
rtr := config.Routing.Router
|
||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||
|
||||
if hasTopLevel && hasRouting {
|
||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||
}
|
||||
|
||||
if !hasTopLevel {
|
||||
// Both groups and matrix may be defined under routing.router.settings;
|
||||
// routing.router.use selects which one is active, so there is no conflict.
|
||||
rs := config.Routing.Router.Settings
|
||||
switch config.Routing.Router.Use {
|
||||
case "matrix":
|
||||
if rs.Matrix == nil {
|
||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||
}
|
||||
config.Matrix = rs.Matrix
|
||||
case "group", "":
|
||||
config.Groups = rs.Groups
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||
}
|
||||
}
|
||||
|
||||
// groups XOR matrix
|
||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||
}
|
||||
|
||||
if config.Matrix != nil {
|
||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||
}
|
||||
config.Matrix.ExpandedSets = expandedSets
|
||||
} else {
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
|
||||
// Validate group members
|
||||
memberUsage := make(map[string]string)
|
||||
for groupID, groupConfig := range config.Groups {
|
||||
prevSet := make(map[string]bool)
|
||||
for _, member := range groupConfig.Members {
|
||||
if _, found := prevSet[member]; found {
|
||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||
}
|
||||
prevSet[member] = true
|
||||
|
||||
if existingGroup, exists := memberUsage[member]; exists {
|
||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||
}
|
||||
memberUsage[member] = groupID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||
// and new-style configs converge here. The Matrix pointer is shared so
|
||||
// ExpandedSets stays in one place.
|
||||
if config.Matrix != nil {
|
||||
config.Routing.Router.Use = "matrix"
|
||||
} else {
|
||||
config.Routing.Router.Use = "group"
|
||||
}
|
||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||
config.Routing.Router.Settings.Groups = config.Groups
|
||||
|
||||
// This fork defaults to the "serial" scheduler: one model loaded at a time,
|
||||
// requests served in strict arrival order. Set use: fifo for the upstream
|
||||
// throughput-oriented behavior that batches same-model requests.
|
||||
if config.Routing.Scheduler.Use == "" {
|
||||
config.Routing.Scheduler.Use = "serial"
|
||||
}
|
||||
switch config.Routing.Scheduler.Use {
|
||||
case "fifo", "serial":
|
||||
default:
|
||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo, serial)", config.Routing.Scheduler.Use)
|
||||
}
|
||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||
if _, found := config.RealModelName(modelID); !found {
|
||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hooks preload
|
||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||
var toPreload []string
|
||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
if real, found := config.RealModelName(modelID); found {
|
||||
toPreload = append(toPreload, real)
|
||||
}
|
||||
}
|
||||
config.Hooks.OnStartup.Preload = toPreload
|
||||
}
|
||||
|
||||
// Validate API keys (env macros already substituted at string level)
|
||||
for i, apikey := range config.RequiredAPIKeys {
|
||||
if apikey == "" {
|
||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||
}
|
||||
if strings.Contains(apikey, " ") {
|
||||
return Config{}, fmt.Errorf("apiKeys[%d]: api key cannot contain spaces", i)
|
||||
}
|
||||
config.RequiredAPIKeys[i] = apikey
|
||||
}
|
||||
|
||||
// Process peers with global macro substitution
|
||||
for peerName, peerConfig := range config.Peers {
|
||||
// Substitute global macros (LIFO order)
|
||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||
entry := config.Macros[i]
|
||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||
|
||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||
|
||||
// Substitute in setParams (type-preserving)
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||
}
|
||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate no unknown macros remain
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||
}
|
||||
if len(peerConfig.Filters.SetParams) > 0 {
|
||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
config.Peers[peerName] = peerConfig
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||
)
|
||||
|
||||
// validateMacro validates macro name and value constraints
|
||||
func validateMacro(name string, value any) error {
|
||||
if len(name) >= 64 {
|
||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||
}
|
||||
if !macroNameRegex.MatchString(name) {
|
||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||
}
|
||||
|
||||
// Validate that value is a scalar type
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check for self-reference
|
||||
macroSlug := fmt.Sprintf("${%s}", name)
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||
// These types are allowed
|
||||
default:
|
||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "PORT", "MODEL_ID":
|
||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||
func validateNestedForUnknownMacros(value any, context string) error {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range matches {
|
||||
macroName := match[1]
|
||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||
}
|
||||
// Check for unsubstituted env macros
|
||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||
for _, match := range envMatches {
|
||||
varName := match[1]
|
||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||
}
|
||||
return nil
|
||||
|
||||
case map[string]any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
case []any:
|
||||
for _, val := range v {
|
||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
// Scalar types don't contain macros
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||
// This is called once per macro, allowing LIFO substitution order
|
||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||
macroStr := fmt.Sprintf("%v", macroValue)
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
// Check if this is a direct macro substitution
|
||||
if v == macroSlug {
|
||||
return macroValue, nil
|
||||
}
|
||||
// Handle string interpolation
|
||||
if strings.Contains(v, macroSlug) {
|
||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case map[string]any:
|
||||
// Recursively process map values
|
||||
newMap := make(map[string]any)
|
||||
for key, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newMap[key] = newVal
|
||||
}
|
||||
return newMap, nil
|
||||
|
||||
case []any:
|
||||
// Recursively process slice elements
|
||||
newSlice := make([]any, len(v))
|
||||
for i, val := range v {
|
||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSlice[i] = newVal
|
||||
}
|
||||
return newSlice, nil
|
||||
|
||||
default:
|
||||
// Return scalar types as-is
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||
// (which strips comments) and only checking the comment-free version for macros.
|
||||
func substituteEnvMacros(s string) (string, error) {
|
||||
// Unmarshal and remarshal to strip YAML comments
|
||||
var raw any
|
||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||
// If YAML is invalid, fall back to scanning the original string
|
||||
// so the user gets the env var error rather than a confusing YAML parse error
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
clean, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return substituteEnvMacrosInString(s, s)
|
||||
}
|
||||
|
||||
return substituteEnvMacrosInString(s, string(clean))
|
||||
}
|
||||
|
||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||
// them in target. This separation allows scanning comment-free YAML while
|
||||
// substituting in the original string.
|
||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||
result := target
|
||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||
for _, match := range matches {
|
||||
fullMatch := match[0] // ${env.VAR_NAME}
|
||||
varName := match[1] // VAR_NAME
|
||||
|
||||
value, exists := os.LookupEnv(varName)
|
||||
if !exists {
|
||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||
}
|
||||
|
||||
// Sanitize the value for safe YAML substitution
|
||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
result = strings.ReplaceAll(result, fullMatch, value)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||
// for compatibility with double-quoted YAML strings.
|
||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||
// Reject values that would break YAML structure regardless of quoting context
|
||||
if strings.ContainsAny(value, "\n\r\x00") {
|
||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||
}
|
||||
|
||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||
// In double-quoted contexts, they are interpreted correctly.
|
||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// identityMapPaths is the set of dotted paths whose direct children are
|
||||
// identity-keyed maps. A child key present in two sources is a hard error;
|
||||
// such keys name discrete entities (a model, a group, a peer, etc.) and a
|
||||
// duplicate means the user has split one entity across files by mistake.
|
||||
var identityMapPaths = map[string]bool{
|
||||
"models": true,
|
||||
"groups": true,
|
||||
"profiles": true,
|
||||
"peers": true,
|
||||
"matrix": true,
|
||||
"routing.router.settings.groups": true,
|
||||
"routing.router.settings.matrix": true,
|
||||
}
|
||||
|
||||
// LoadConfigSources loads and merges configuration from -config (optional)
|
||||
// and -config-dir (optional). At least one must be provided. The -config file
|
||||
// is loaded first; *.yml/*.yaml files directly under -config-dir are then
|
||||
// merged in sorted filename order. The merged document is passed through the
|
||||
// existing LoadConfigFromReader pipeline unchanged.
|
||||
func LoadConfigSources(configPath, configDir string) (Config, error) {
|
||||
if configPath == "" && configDir == "" {
|
||||
return Config{}, fmt.Errorf("at least one of -config or -config-dir must be provided")
|
||||
}
|
||||
|
||||
var sourcePaths []string
|
||||
|
||||
if configPath != "" {
|
||||
sourcePaths = append(sourcePaths, configPath)
|
||||
}
|
||||
|
||||
if configDir != "" {
|
||||
dirFiles, err := listYAMLFiles(configDir)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("-config-dir %s: %w", configDir, err)
|
||||
}
|
||||
|
||||
if configPath != "" {
|
||||
absConfig, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to resolve -config path: %w", err)
|
||||
}
|
||||
for _, f := range dirFiles {
|
||||
absF, err := filepath.Abs(f)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to resolve config dir file %s: %w", f, err)
|
||||
}
|
||||
if absConfig == absF {
|
||||
return Config{}, fmt.Errorf("-config path %s is also present in -config-dir %s; remove it from one", configPath, configDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sourcePaths = append(sourcePaths, dirFiles...)
|
||||
}
|
||||
|
||||
if len(sourcePaths) == 0 {
|
||||
return Config{}, fmt.Errorf("no configuration sources found")
|
||||
}
|
||||
|
||||
var merged *yaml.Node
|
||||
for _, p := range sourcePaths {
|
||||
node, err := parseSource(p)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
if node == nil {
|
||||
continue // empty file
|
||||
}
|
||||
if merged == nil {
|
||||
merged = node
|
||||
continue
|
||||
}
|
||||
if err := mergeNodes(merged, node, "", p); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if merged == nil {
|
||||
// All sources were empty; run the pipeline on empty input so defaults
|
||||
// and validation still apply (e.g. startPort, performance defaults).
|
||||
return LoadConfigFromReader(strings.NewReader(""))
|
||||
}
|
||||
|
||||
out, err := yaml.Marshal(merged)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to marshal merged config: %w", err)
|
||||
}
|
||||
return LoadConfigFromReader(strings.NewReader(string(out)))
|
||||
}
|
||||
|
||||
// listYAMLFiles returns the top-level *.yml and *.yaml files in dir, sorted by
|
||||
// filename for deterministic merge order. Subdirectories are not traversed.
|
||||
func listYAMLFiles(dir string) ([]string, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var files []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||
continue
|
||||
}
|
||||
files = append(files, filepath.Join(dir, name))
|
||||
}
|
||||
sort.Strings(files)
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// parseSource reads and parses one YAML config file into a root mapping node.
|
||||
// Returns a nil node (no error) when the file is empty or contains only
|
||||
// comments.
|
||||
//
|
||||
// Env macros (${env.VAR}) are substituted at the string level before YAML
|
||||
// parsing so that flow-style constructs like [${env.API_KEY}] parse
|
||||
// correctly — the brace would otherwise be interpreted as a flow mapping.
|
||||
func parseSource(path string) (*yaml.Node, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config %s: %w", path, err)
|
||||
}
|
||||
yamlStr, err := substituteEnvMacros(string(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config %s: %w", path, err)
|
||||
}
|
||||
var doc yaml.Node
|
||||
if err := yaml.Unmarshal([]byte(yamlStr), &doc); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config %s: %w", path, err)
|
||||
}
|
||||
// yaml.Unmarshal into a yaml.Node yields a DocumentNode whose Content[0]
|
||||
// is the actual root. Unwrap it so callers see the real top-level node.
|
||||
root := &doc
|
||||
if root.Kind == yaml.DocumentNode && len(root.Content) > 0 {
|
||||
root = root.Content[0]
|
||||
}
|
||||
if root.Kind == 0 || root.Content == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if root.Kind != yaml.MappingNode {
|
||||
return nil, fmt.Errorf("config %s: top-level YAML must be a mapping", path)
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// mergeNodes merges src into dst (both MappingNodes) in place. Keys present in
|
||||
// only one side are kept; shared keys are merged recursively under the rules
|
||||
// in mergeValue. srcPath is included in error messages to identify the file
|
||||
// that introduced the conflict.
|
||||
func mergeNodes(dst, src *yaml.Node, path, srcPath string) error {
|
||||
srcIdx := indexMapping(src)
|
||||
|
||||
// First pass: merge shared keys in place.
|
||||
for i := 0; i+1 < len(dst.Content); i += 2 {
|
||||
keyNode := dst.Content[i]
|
||||
dstVal := dst.Content[i+1]
|
||||
key := keyNode.Value
|
||||
|
||||
srcVal, ok := srcIdx[key]
|
||||
if !ok {
|
||||
continue // dst-only key, keep as-is
|
||||
}
|
||||
|
||||
childPath := joinPath(path, key)
|
||||
|
||||
if identityMapPaths[childPath] {
|
||||
// Identity-keyed map: each child key names a discrete entity
|
||||
// (a model, group, peer, ...). A shared child key is a hard
|
||||
// error; src-only children are appended in the second pass.
|
||||
if err := mergeIdentityMap(dstVal, srcVal, childPath, key, srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := mergeValue(dstVal, srcVal, childPath, srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: append src-only keys.
|
||||
dstIdx := indexMapping(dst)
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
keyNode := src.Content[i]
|
||||
srcVal := src.Content[i+1]
|
||||
key := keyNode.Value
|
||||
|
||||
if _, ok := dstIdx[key]; ok {
|
||||
continue // already merged above
|
||||
}
|
||||
keyCopy := *keyNode
|
||||
valCopy := *srcVal
|
||||
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeIdentityMap merges two identity-keyed mapping nodes (e.g. `models`,
|
||||
// `groups`, `peers`). Any child key present in both sides is a duplicate
|
||||
// entity and produces an error naming the conflicting key and source file.
|
||||
// src-only keys are appended to dst.
|
||||
func mergeIdentityMap(dst, src *yaml.Node, path, mapName, srcPath string) error {
|
||||
if dst.Kind != yaml.MappingNode || src.Kind != yaml.MappingNode {
|
||||
return fmt.Errorf("conflict at %q: expected a mapping, introduced by %s", path, srcPath)
|
||||
}
|
||||
dstIdx := indexMapping(dst)
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
keyNode := src.Content[i]
|
||||
srcVal := src.Content[i+1]
|
||||
key := keyNode.Value
|
||||
if _, dup := dstIdx[key]; dup {
|
||||
return fmt.Errorf("duplicate %s %q found in %s (already defined in another config source)", mapName, key, srcPath)
|
||||
}
|
||||
keyCopy := *keyNode
|
||||
valCopy := *srcVal
|
||||
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeValue merges srcVal into dstVal (both pointing into the parent's
|
||||
// Content slice). Mapping+Mapping recurses; Sequence+Sequence concatenates;
|
||||
// Scalar+Scalar errors on value mismatch; null on either side yields to the
|
||||
// non-null side.
|
||||
func mergeValue(dstVal, srcVal *yaml.Node, path, srcPath string) error {
|
||||
switch {
|
||||
case dstVal.Kind == yaml.MappingNode && srcVal.Kind == yaml.MappingNode:
|
||||
return mergeNodes(dstVal, srcVal, path, srcPath)
|
||||
|
||||
case dstVal.Kind == yaml.SequenceNode && srcVal.Kind == yaml.SequenceNode:
|
||||
dstVal.Content = append(dstVal.Content, srcVal.Content...)
|
||||
return nil
|
||||
|
||||
case dstVal.Kind == yaml.ScalarNode && srcVal.Kind == yaml.ScalarNode:
|
||||
if isNullScalar(dstVal) {
|
||||
*dstVal = *srcVal
|
||||
return nil
|
||||
}
|
||||
if isNullScalar(srcVal) {
|
||||
return nil
|
||||
}
|
||||
if dstVal.Value == srcVal.Value && dstVal.Tag == srcVal.Tag {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("conflict at %q: %s sets a different value than a previous source", path, srcPath)
|
||||
|
||||
case isNull(dstVal):
|
||||
*dstVal = *srcVal
|
||||
return nil
|
||||
|
||||
case isNull(srcVal):
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("conflict at %q: incompatible YAML node kinds (kind %d vs %d) introduced by %s", path, dstVal.Kind, srcVal.Kind, srcPath)
|
||||
}
|
||||
}
|
||||
|
||||
// isNull reports whether n represents a YAML null (empty or !!null).
|
||||
func isNull(n *yaml.Node) bool {
|
||||
if n == nil || n.Kind == 0 {
|
||||
return true
|
||||
}
|
||||
return isNullScalar(n)
|
||||
}
|
||||
|
||||
func isNullScalar(n *yaml.Node) bool {
|
||||
return n.Kind == yaml.ScalarNode && (n.Tag == "!!null" || n.Tag == "") && n.Value == ""
|
||||
}
|
||||
|
||||
// indexMapping builds a key -> value-node index for a mapping node.
|
||||
func indexMapping(n *yaml.Node) map[string]*yaml.Node {
|
||||
idx := make(map[string]*yaml.Node, len(n.Content)/2)
|
||||
for i := 0; i+1 < len(n.Content); i += 2 {
|
||||
idx[n.Content[i].Value] = n.Content[i+1]
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func joinPath(parent, key string) string {
|
||||
if parent == "" {
|
||||
return key
|
||||
}
|
||||
return parent + "." + key
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// writeYAML writes content to a file named name inside dir. Returns the full
|
||||
// path of the written file.
|
||||
func writeYAML(t *testing.T, dir, name, content string) string {
|
||||
t.Helper()
|
||||
p := filepath.Join(dir, name)
|
||||
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0o755))
|
||||
require.NoError(t, os.WriteFile(p, []byte(content), 0o644))
|
||||
return p
|
||||
}
|
||||
|
||||
// modelCfg builds a single-model YAML snippet indented for nesting under a
|
||||
// `models:` key. The proxy uses a fixed port so tests don't depend on
|
||||
// ${PORT} allocation.
|
||||
func modelCfg(id, cmd string) string {
|
||||
return " " + id + ":\n cmd: " + cmd + "\n proxy: \"http://localhost:9999\"\n"
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_NeitherProvided(t *testing.T) {
|
||||
_, err := LoadConfigSources("", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "at least one of -config or -config-dir")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ConfigOnly(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "config.yaml", `
|
||||
models:
|
||||
`+modelCfg("model1", "echo hi")+`
|
||||
groups:
|
||||
group1:
|
||||
members: ["model1"]
|
||||
`)
|
||||
cfg, err := LoadConfigSources(cfgPath, "")
|
||||
require.NoError(t, err)
|
||||
_, id, ok := cfg.FindConfig("model1")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "model1", id)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DirOnly(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("alpha", "echo a"))
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("beta", "echo b"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
for _, want := range []string{"alpha", "beta"} {
|
||||
_, _, ok := cfg.FindConfig(want)
|
||||
assert.True(t, ok, "model %s should be present", want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ConfigPlusDirAdditive(t *testing.T) {
|
||||
// -config lives outside -config-dir; both contribute models additively.
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "config.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||
cfgDir := t.TempDir()
|
||||
writeYAML(t, cfgDir, "extra.yaml", "models:\n"+modelCfg("ext", "echo ext"))
|
||||
|
||||
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||
require.NoError(t, err)
|
||||
for _, want := range []string{"base", "ext"} {
|
||||
_, _, ok := cfg.FindConfig(want)
|
||||
assert.True(t, ok, "model %s should be present after merge", want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfigSources_ConfigInDirOverlap verifies that a -config file that
|
||||
// is also a member of -config-dir is rejected.
|
||||
func TestLoadConfigSources_ConfigInDirOverlap(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||
|
||||
_, err := LoadConfigSources(cfgPath, dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "is also present in -config-dir")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicateModelID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("dup", "echo a"))
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("dup", "echo b"))
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate models "dup"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicateGroupID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", `
|
||||
models:
|
||||
`+modelCfg("m1", "echo m1")+"groups:\n g1:\n members: [m1]\n")
|
||||
writeYAML(t, dir, "b.yaml", `
|
||||
models:
|
||||
`+modelCfg("m2", "echo m2")+"groups:\n g1:\n members: [m2]\n")
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate groups "g1"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_DuplicatePeer(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
peerA := "peers:\n remote:\n proxy: http://x:1\n models: [m1]\n"
|
||||
peerB := "peers:\n remote:\n proxy: http://x:2\n models: [m2]\n"
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\n"+peerA)
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\n"+peerB)
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `duplicate peers "remote"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ScalarConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 200\n")
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), `conflict at "globalTTL"`)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_ScalarSameValueNoConflict(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 100\n")
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, cfg.GlobalTTL)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_MacrosConcatenate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "macros:\n LOW: 1\nmodels:\n"+modelCfg("m1", "echo ${LOW}"))
|
||||
writeYAML(t, dir, "b.yaml", "macros:\n HIGH: 2\nmodels:\n"+modelCfg("m2", "echo ${HIGH}"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
// Both macros are available globally after merge.
|
||||
low, ok := cfg.Macros.Get("LOW")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 1, low)
|
||||
high, ok := cfg.Macros.Get("HIGH")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 2, high)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_APIKeysConcatenate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\napiKeys: [key-a]\n")
|
||||
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\napiKeys: [key-b]\n")
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"key-a", "key-b"}, cfg.RequiredAPIKeys)
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_RoutingGroupsMerge(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", `
|
||||
models:
|
||||
`+modelCfg("m1", "echo m1")+`
|
||||
routing:
|
||||
router:
|
||||
settings:
|
||||
groups:
|
||||
groupA:
|
||||
members: [m1]
|
||||
`)
|
||||
writeYAML(t, dir, "b.yaml", `
|
||||
models:
|
||||
`+modelCfg("m2", "echo m2")+`
|
||||
routing:
|
||||
router:
|
||||
settings:
|
||||
groups:
|
||||
groupB:
|
||||
members: [m2]
|
||||
`)
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
groups := cfg.Routing.Router.Settings.Groups
|
||||
assert.Contains(t, groups, "groupA")
|
||||
assert.Contains(t, groups, "groupB")
|
||||
// default group added by pipeline for orphaned/leftover routing groups...
|
||||
// here both groups reference distinct models
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EnvMacrosSubstituted(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Use ${PORT} in cmd so the pipeline allocates a port and substitutes it;
|
||||
// verifies env/macro substitution runs on the merged document.
|
||||
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: serve --port ${PORT}\n proxy: \"http://localhost:${PORT}\"\n")
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
m := cfg.Models["m1"]
|
||||
assert.NotContains(t, m.Cmd, "${PORT}", "PORT macro should have been substituted")
|
||||
assert.NotContains(t, m.Proxy, "${PORT}", "PORT macro should have been substituted in proxy")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EnvMacroInFlowStyleList(t *testing.T) {
|
||||
// Regression: flow-style lists with ${env.*} must parse. Previously
|
||||
// parseSource unmarshalled before env substitution, so the brace in
|
||||
// [${env.API_KEY}] was misread as a flow mapping and parsing failed.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: echo hi\n proxy: \"http://localhost:9999\"\n")
|
||||
writeYAML(t, dir, "keys.yaml", "apiKeys: [${env.TEST_API_KEY}]\nmodels:\n m2:\n cmd: echo hi\n proxy: \"http://localhost:9998\"\n")
|
||||
|
||||
t.Setenv("TEST_API_KEY", "secret123")
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.RequiredAPIKeys, "secret123")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_SortedOrderDeterministic(t *testing.T) {
|
||||
// Two files defining distinct models, scanned in z..a order by filename.
|
||||
// Determine merged result is the same regardless of how the FS returns them.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "z.yaml", "models:\n"+modelCfg("zmodel", "echo z"))
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("amodel", "echo a"))
|
||||
|
||||
const runs = 3
|
||||
for i := 0; i < runs; i++ {
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
// startPort-based allocation: first allocated model gets 5800.
|
||||
// Sorted order means amodel gets 5800, zmodel gets 5801.
|
||||
_, _, ok := cfg.FindConfig("amodel")
|
||||
assert.True(t, ok)
|
||||
_, _, ok = cfg.FindConfig("zmodel")
|
||||
assert.True(t, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EmptyDirWithConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgDir := t.TempDir()
|
||||
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||
|
||||
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.Models, "m1")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_EmptyDirOnly(t *testing.T) {
|
||||
// An empty -config-dir with no -config is an error: there is nothing to
|
||||
// load and silently producing an empty config would mask the misconfig.
|
||||
cfgDir := t.TempDir()
|
||||
_, err := LoadConfigSources("", cfgDir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no configuration sources found")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_AssertNoUnknownMacrosAfterMerge(t *testing.T) {
|
||||
// Macros defined in one file should not satisfy unknown-macro validation in
|
||||
// another — they do, because merge concats global macros before validation
|
||||
// runs. This test documents that a macro from file A is usable in file B.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "macros.yaml", "macros:\n SHARED: hello\nmodels:\n"+modelCfg("dummy", "echo dummy"))
|
||||
writeYAML(t, dir, "use.yaml", "models:\n"+modelCfg("user", "echo ${SHARED}"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
m := cfg.Models["user"]
|
||||
assert.Contains(t, m.Cmd, "hello")
|
||||
assert.NotContains(t, m.Cmd, "${SHARED}")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_KindMismatchErrors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "startPort: 5800\nmodels:\n"+modelCfg("m1", "echo m1"))
|
||||
writeYAML(t, dir, "b.yaml", "startPort: [5800, 5801]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||
|
||||
_, err := LoadConfigSources("", dir)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "incompatible YAML node kinds")
|
||||
}
|
||||
|
||||
func TestLoadConfigSources_NullYieldsToValue(t *testing.T) {
|
||||
// File A: routing.router block absent (null on root for routing);
|
||||
// file B: defines routing.router.settings.groups. Merge should keep B's.
|
||||
dir := t.TempDir()
|
||||
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||
writeYAML(t, dir, "b.yaml", "routing:\n router:\n settings:\n groups:\n g1:\n members: [m1]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||
|
||||
cfg, err := LoadConfigSources("", dir)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||
}
|
||||
@@ -92,9 +92,14 @@ type Effects interface {
|
||||
StopProcesses(timeout time.Duration, ids []string)
|
||||
}
|
||||
|
||||
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured
|
||||
// from conf and bound to the given planner and effects. Currently only "fifo"
|
||||
// (the default) is supported.
|
||||
// New returns a Scheduler selected by conf.Routing.Scheduler.Use, configured from
|
||||
// conf and bound to the given planner and effects. Supported values are "fifo"
|
||||
// (throughput-oriented, batches same-model requests) and "serial" (strict
|
||||
// one-model-at-a-time, exact arrival order).
|
||||
//
|
||||
// The deployment default is applied by config loading (LoadConfig sets Use to
|
||||
// "serial" when unset). The "" fallback here is the library default and remains
|
||||
// "fifo" so callers that build a Config directly keep the original behavior.
|
||||
func New(conf config.Config, name string, logger *logmon.Monitor, planner Swapper, eff Effects) (Scheduler, error) {
|
||||
use := conf.Routing.Scheduler.Use
|
||||
if use == "" {
|
||||
@@ -103,6 +108,9 @@ func New(conf config.Config, name string, logger *logmon.Monitor, planner Swappe
|
||||
switch use {
|
||||
case "fifo":
|
||||
return NewFIFO(name, logger, planner, conf.Routing.Scheduler.Settings.Fifo, conf.Models, eff), nil
|
||||
case "serial":
|
||||
// Serial ignores the group planner: it always evicts every other model.
|
||||
return NewSerial(name, logger, eff), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported scheduler type: %q", use)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,253 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// Serial is a strict one-model-at-a-time scheduler. Unlike FIFO it never reorders
|
||||
// or batches: requests run in exact arrival order and at most one request runs at
|
||||
// any instant. When the next request targets a model other than the one loaded,
|
||||
// every other running model is evicted and the target is loaded before it runs,
|
||||
// so a single model occupies memory at a time — at the cost of throughput.
|
||||
//
|
||||
// Example: A B C A is served as A B C A. The final A reloads its model even
|
||||
// though it ran first, because B and C displaced it in between. (FIFO, by
|
||||
// contrast, would batch the two A requests: A A B C.)
|
||||
//
|
||||
// Serial ignores group/eviction policy entirely: it always evicts every other
|
||||
// running model, regardless of how groups are configured. That is what makes the
|
||||
// single-model guarantee a property of the scheduler rather than of the config.
|
||||
//
|
||||
// Like FIFO, every method runs on the router's single run-loop goroutine, so no
|
||||
// internal locking is needed.
|
||||
type Serial struct {
|
||||
name string
|
||||
logger *logmon.Monitor
|
||||
effects Effects
|
||||
|
||||
// queued holds requests in strict arrival order. It is never reordered.
|
||||
queued []HandlerReq
|
||||
|
||||
// active is the one request currently being processed (loading or serving),
|
||||
// or nil when idle. phase is meaningful only while active != nil.
|
||||
active *HandlerReq
|
||||
phase serialPhase
|
||||
}
|
||||
|
||||
// serialPhase is the lifecycle stage of the active request.
|
||||
type serialPhase int
|
||||
|
||||
const (
|
||||
phaseIdle serialPhase = iota
|
||||
phaseSwapping // waiting for OnSwapDone for active.Model
|
||||
phaseServing // waiting for OnServeDone for active.Model
|
||||
)
|
||||
|
||||
// NewSerial builds a Serial scheduler. It takes no Swapper: eviction is always
|
||||
// "stop every other running model", so the group planner is not consulted.
|
||||
func NewSerial(name string, logger *logmon.Monitor, eff Effects) *Serial {
|
||||
return &Serial{
|
||||
name: name,
|
||||
logger: logger,
|
||||
effects: eff,
|
||||
}
|
||||
}
|
||||
|
||||
// OnRequest validates the model and appends the request to the tail of the queue,
|
||||
// then tries to start the next job. Unknown models fail immediately.
|
||||
func (s *Serial) OnRequest(req HandlerReq) {
|
||||
if _, ok := s.effects.ModelState(req.Model); !ok {
|
||||
s.logger.Debugf("%s: model %s not handled by this router", s.name, req.Model)
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
return
|
||||
}
|
||||
s.queued = append(s.queued, req)
|
||||
broadcastQueuePositions(s.queued)
|
||||
s.startNext()
|
||||
}
|
||||
|
||||
// startNext begins processing the head of the queue when nothing is active. It
|
||||
// fast-paths a request whose model is already the sole loaded-and-ready process;
|
||||
// otherwise it launches a swap that evicts every other running model first. The
|
||||
// loop skips over requests for models that vanished (e.g. a config reload) and
|
||||
// requests whose caller disconnected before they could be served.
|
||||
func (s *Serial) startNext() {
|
||||
if s.active != nil {
|
||||
return // a job is already loading or serving
|
||||
}
|
||||
for len(s.queued) > 0 {
|
||||
req := s.queued[0]
|
||||
s.queued = s.queued[1:]
|
||||
broadcastQueuePositions(s.queued)
|
||||
|
||||
state, ok := s.effects.ModelState(req.Model)
|
||||
if !ok {
|
||||
s.effects.GrantError(req, ErrModelNotFound)
|
||||
continue
|
||||
}
|
||||
|
||||
r := req
|
||||
s.active = &r
|
||||
|
||||
evict := s.otherRunning(req.Model)
|
||||
if state == process.StateReady && len(evict) == 0 {
|
||||
// Already loaded and the only model running — serve immediately.
|
||||
s.logger.Debugf("%s: serving model %s (already loaded)", s.name, req.Model)
|
||||
if s.serve() {
|
||||
return
|
||||
}
|
||||
continue // caller gone; pick the next request
|
||||
}
|
||||
|
||||
s.logger.Debugf("%s: swapping to model %s, evicting %v", s.name, req.Model, evict)
|
||||
s.phase = phaseSwapping
|
||||
s.effects.StartSwap(req.Model, evict)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// serve hands the active request its tracked handler. It returns true when the
|
||||
// request is now serving (await OnServeDone); false when the caller had already
|
||||
// disconnected, in which case active is cleared so the next job can start.
|
||||
func (s *Serial) serve() bool {
|
||||
if s.effects.GrantServe(*s.active, s.active.Model) {
|
||||
s.phase = phaseServing
|
||||
return true
|
||||
}
|
||||
s.logger.Debugf("%s: caller for model %s gone before serve", s.name, s.active.Model)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
return false
|
||||
}
|
||||
|
||||
// OnSwapDone fires when the load for the active request completes. On success the
|
||||
// request is served; on failure its caller receives the error and the queue
|
||||
// advances. A SwapDone that does not match the active load (e.g. its request was
|
||||
// unloaded or cancelled mid-load) is ignored.
|
||||
func (s *Serial) OnSwapDone(ev SwapDone) {
|
||||
if s.active == nil || s.phase != phaseSwapping || s.active.Model != ev.ModelID {
|
||||
return
|
||||
}
|
||||
if ev.Err != nil {
|
||||
s.logger.Debugf("%s: swap for model %s failed: %v", s.name, ev.ModelID, ev.Err)
|
||||
s.effects.GrantError(*s.active, ev.Err)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
s.startNext()
|
||||
return
|
||||
}
|
||||
if !s.serve() {
|
||||
s.startNext() // caller vanished while the model loaded; move on
|
||||
}
|
||||
}
|
||||
|
||||
// OnServeDone fires when the active request's handler returns. The slot is freed
|
||||
// and the next queued request begins.
|
||||
func (s *Serial) OnServeDone(ev ServeDoneEvent) {
|
||||
if s.active == nil || s.phase != phaseServing {
|
||||
return
|
||||
}
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
s.startNext()
|
||||
}
|
||||
|
||||
// OnCancel removes a disconnected client's request from the queue. A request that
|
||||
// is already active is left to finish: if it was loading, OnSwapDone's serve()
|
||||
// will find the caller gone (GrantServe false) and advance; if it was serving,
|
||||
// its handler returns normally and reaches OnServeDone.
|
||||
func (s *Serial) OnCancel(req HandlerReq) {
|
||||
if len(s.queued) == 0 {
|
||||
return
|
||||
}
|
||||
kept := s.queued[:0]
|
||||
removed := false
|
||||
for _, q := range s.queued {
|
||||
if q.Respond == req.Respond {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, q)
|
||||
}
|
||||
s.queued = kept
|
||||
if removed {
|
||||
s.logger.Debugf("%s: cancelled request for model %s pruned from queue", s.name, req.Model)
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnload reconciles state for an unload, stops the targeted processes, and
|
||||
// advances the queue. It mirrors the FIFO contract: queued requests for unloaded
|
||||
// models are failed; an active *loading* request for an unloaded model is failed
|
||||
// (its swap goroutine is left to finish and its SwapDone is then ignored); an
|
||||
// active *serving* request is left for its handler to end when StopProcesses
|
||||
// kills the upstream. The Stop is synchronous so callers of Unload can rely on
|
||||
// the processes being stopped on return.
|
||||
func (s *Serial) OnUnload(targets []string, timeout time.Duration) {
|
||||
unloadErr := fmt.Errorf("%s: model unloaded", s.name)
|
||||
|
||||
targetSet := make(map[string]bool, len(targets))
|
||||
for _, id := range targets {
|
||||
targetSet[id] = true
|
||||
}
|
||||
|
||||
if s.active != nil && s.phase == phaseSwapping && targetSet[s.active.Model] {
|
||||
s.effects.GrantError(*s.active, unloadErr)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
}
|
||||
|
||||
if len(s.queued) > 0 {
|
||||
kept := s.queued[:0]
|
||||
for _, q := range s.queued {
|
||||
if targetSet[q.Model] {
|
||||
s.effects.GrantError(q, unloadErr)
|
||||
continue
|
||||
}
|
||||
kept = append(kept, q)
|
||||
}
|
||||
s.queued = kept
|
||||
broadcastQueuePositions(s.queued)
|
||||
}
|
||||
|
||||
s.effects.StopProcesses(timeout, targets)
|
||||
|
||||
// A still-serving active request advances via OnServeDone when its killed
|
||||
// handler returns; only start the next job when nothing is active now.
|
||||
if s.active == nil {
|
||||
s.startNext()
|
||||
}
|
||||
}
|
||||
|
||||
// OnShutdown grants err to every request the scheduler still holds: an active
|
||||
// loading request and all queued requests. A serving request is torn down with
|
||||
// its process by the baseRouter.
|
||||
func (s *Serial) OnShutdown(err error) {
|
||||
if s.active != nil && s.phase == phaseSwapping {
|
||||
s.effects.GrantError(*s.active, err)
|
||||
s.active = nil
|
||||
s.phase = phaseIdle
|
||||
}
|
||||
for _, q := range s.queued {
|
||||
s.effects.GrantError(q, err)
|
||||
}
|
||||
s.queued = nil
|
||||
}
|
||||
|
||||
// otherRunning returns every running model except target, sorted for
|
||||
// deterministic eviction.
|
||||
func (s *Serial) otherRunning(target string) []string {
|
||||
var out []string
|
||||
for id := range s.effects.RunningModels() {
|
||||
if id != target {
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/internal/logmon"
|
||||
"github.com/mostlygeek/llama-swap/internal/process"
|
||||
)
|
||||
|
||||
// Serial methods all run on the router's single run-loop goroutine, so these
|
||||
// tests drive them directly and synchronously, reusing fakeEffects and the
|
||||
// req/reqCh helpers from fifo_test.go. A load completes via OnSwapDone and a
|
||||
// served request finishes via OnServeDone — the events the run loop delivers.
|
||||
|
||||
func newSerial(eff Effects) *Serial {
|
||||
return NewSerial("test", logmon.NewWriter(io.Discard), eff)
|
||||
}
|
||||
|
||||
// lastStart returns the most recent StartSwap record.
|
||||
func lastStart(t *testing.T, eff *fakeEffects) startRec {
|
||||
t.Helper()
|
||||
if len(eff.starts) == 0 {
|
||||
t.Fatal("no StartSwap recorded")
|
||||
}
|
||||
return eff.starts[len(eff.starts)-1]
|
||||
}
|
||||
|
||||
func sameSet(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
m := map[string]int{}
|
||||
for _, x := range a {
|
||||
m[x]++
|
||||
}
|
||||
for _, x := range b {
|
||||
m[x]--
|
||||
}
|
||||
for _, v := range m {
|
||||
if v != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// servedOrder returns the model IDs of every successful serve grant in order.
|
||||
func servedOrder(eff *fakeEffects) []string {
|
||||
var out []string
|
||||
for _, g := range eff.grants {
|
||||
if g.err == nil && g.serve {
|
||||
out = append(out, g.model)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestSerial_FastPath_AlreadyLoaded(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
if got := len(eff.starts); got != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0 (already loaded, no swap)", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_ColdStart_LoadsThenServes(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Fatalf("StartSwap(a)=%d want 1", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 0 {
|
||||
t.Errorf("served(a)=%d want 0 before load completes", got)
|
||||
}
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Errorf("served(a)=%d want 1 after load", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_UnknownModel(t *testing.T) {
|
||||
eff := newFakeEffects() // no states => unknown
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("ghost"))
|
||||
|
||||
if len(eff.starts) != 0 {
|
||||
t.Errorf("StartSwap calls=%d want 0", len(eff.starts))
|
||||
}
|
||||
if eff.errored("ghost") != 1 {
|
||||
t.Fatalf("errored(ghost)=%d want 1", eff.errored("ghost"))
|
||||
}
|
||||
if !errors.Is(eff.grants[0].err, ErrModelNotFound) {
|
||||
t.Errorf("err=%v want ErrModelNotFound", eff.grants[0].err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_EvictsEveryOtherModel(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["x"] = process.StateReady // already running
|
||||
eff.states["y"] = process.StateReady // also running (e.g. left over)
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
|
||||
st := lastStart(t, eff)
|
||||
if st.model != "a" {
|
||||
t.Fatalf("loading %s want a", st.model)
|
||||
}
|
||||
if !sameSet(st.evict, []string{"x", "y"}) {
|
||||
t.Errorf("evict=%v want [x y] (serial evicts ALL other models)", st.evict)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_OneJobAtATime verifies a second request waits while the first is
|
||||
// serving, and only starts after the first finishes.
|
||||
func TestSerial_OneJobAtATime(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // served immediately
|
||||
s.OnRequest(req("b")) // must wait — a is serving
|
||||
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("StartSwap(b)=%d want 0 while a is serving", got)
|
||||
}
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// a finishes -> b may now load (evicting a).
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a finished", got)
|
||||
}
|
||||
if st := lastStart(t, eff); !sameSet(st.evict, []string{"a"}) {
|
||||
t.Errorf("b evict=%v want [a]", st.evict)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_SameModelConsecutive_NoReload verifies back-to-back requests for the
|
||||
// already-loaded model run without a reload, one after another.
|
||||
func TestSerial_SameModelConsecutive_NoReload(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // cold load
|
||||
s.OnRequest(req("a")) // queued behind the first
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"}) // first serves
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1 (one at a time)", got)
|
||||
}
|
||||
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"}) // first done -> second serves
|
||||
if got := eff.served("a"); got != 2 {
|
||||
t.Fatalf("served(a)=%d want 2", got)
|
||||
}
|
||||
if got := eff.startsFor("a"); got != 1 {
|
||||
t.Errorf("StartSwap(a)=%d want 1 (second request must not reload)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_StrictArrivalOrder is the core guarantee: qwen36, qwen35, sdxl,
|
||||
// qwen36 execute in EXACTLY that order with evictions between each model switch,
|
||||
// including reloading qwen36 at the end even though it ran first.
|
||||
func TestSerial_StrictArrivalOrder(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
for _, m := range []string{"qwen36", "qwen35", "sdxl"} {
|
||||
eff.states[m] = process.StateStopped
|
||||
}
|
||||
s := newSerial(eff)
|
||||
|
||||
for _, m := range []string{"qwen36", "qwen35", "sdxl", "qwen36"} {
|
||||
s.OnRequest(req(m))
|
||||
}
|
||||
|
||||
// Only the first job starts loading; the rest wait their turn.
|
||||
if len(eff.starts) != 1 || eff.starts[0].model != "qwen36" {
|
||||
t.Fatalf("starts=%+v want only [qwen36] loading first", eff.starts)
|
||||
}
|
||||
|
||||
// step completes the current model's load+serve and returns control to the
|
||||
// scheduler, which must start the next queued model.
|
||||
step := func(model string, wantEvict []string) {
|
||||
t.Helper()
|
||||
st := lastStart(t, eff)
|
||||
if st.model != model {
|
||||
t.Fatalf("loading %q want %q", st.model, model)
|
||||
}
|
||||
if !sameSet(st.evict, wantEvict) {
|
||||
t.Fatalf("loading %q evict=%v want %v", model, st.evict, wantEvict)
|
||||
}
|
||||
// Simulate the eviction + load actually happening.
|
||||
for _, e := range st.evict {
|
||||
eff.states[e] = process.StateStopped
|
||||
}
|
||||
eff.states[model] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: model})
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: model})
|
||||
}
|
||||
|
||||
step("qwen36", nil) // cold load, nothing else running
|
||||
step("qwen35", []string{"qwen36"}) // evict qwen36
|
||||
step("sdxl", []string{"qwen35"}) // evict qwen35
|
||||
step("qwen36", []string{"sdxl"}) // RELOAD qwen36, evict sdxl
|
||||
|
||||
want := []string{"qwen36", "qwen35", "sdxl", "qwen36"}
|
||||
if got := servedOrder(eff); !sameOrder(got, want) {
|
||||
t.Fatalf("serve order=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func sameOrder(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestSerial_SwapError_FailsCallerAndAdvances(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("b")) // queued behind a
|
||||
|
||||
// a's load fails: its caller is errored and b proceeds.
|
||||
s.OnSwapDone(SwapDone{ModelID: "a", Err: errors.New("boom")})
|
||||
if eff.errored("a") != 1 {
|
||||
t.Fatalf("errored(a)=%d want 1", eff.errored("a"))
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after a's load failed", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_GrantServeFalse_Advances verifies that when the active request's
|
||||
// caller has disconnected by serve time, the queue advances to the next request.
|
||||
func TestSerial_GrantServeFalse_Advances(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.serveResult["a"] = false // a's caller is gone by grant time
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a"))
|
||||
s.OnRequest(req("b")) // queued
|
||||
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"}) // grant fails -> advance to b
|
||||
|
||||
if got := eff.served("a"); got != 0 {
|
||||
t.Errorf("served(a)=%d want 0 (caller gone)", got)
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 (advanced after gone caller)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_OnCancel_QueuedRequest(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(reqCh("a")) // starts loading a
|
||||
cancelled := reqCh("b")
|
||||
s.OnRequest(cancelled) // queued behind a
|
||||
if len(s.queued) != 1 {
|
||||
t.Fatalf("queued=%d want 1", len(s.queued))
|
||||
}
|
||||
|
||||
s.OnCancel(cancelled)
|
||||
if len(s.queued) != 0 {
|
||||
t.Fatalf("queued=%d want 0 after cancel", len(s.queued))
|
||||
}
|
||||
|
||||
// a completes; b is gone, so nothing starts for it.
|
||||
eff.states["a"] = process.StateReady
|
||||
s.OnSwapDone(SwapDone{ModelID: "a"})
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Errorf("StartSwap(b)=%d want 0 (cancelled before its turn)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_OnShutdown_FailsQueuedAndActiveLoad(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
eff.states["c"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // active (loading)
|
||||
s.OnRequest(req("b")) // queued
|
||||
s.OnRequest(req("c")) // queued
|
||||
|
||||
s.OnShutdown(errors.New("shutting down"))
|
||||
|
||||
if got := eff.errored(""); got != 3 {
|
||||
t.Errorf("error grants=%d want 3 (active load + 2 queued)", got)
|
||||
}
|
||||
if len(s.queued) != 0 {
|
||||
t.Errorf("queued=%d want 0 after shutdown", len(s.queued))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerial_OnUnload_WhileServing verifies that unloading the model that is
|
||||
// actively serving does not strand the queue: OnUnload stops the process but
|
||||
// leaves the active request to end via OnServeDone, which then advances.
|
||||
func TestSerial_OnUnload_WhileServing(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateReady
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // served immediately (a ready)
|
||||
s.OnRequest(req("b")) // queued behind a
|
||||
if got := eff.served("a"); got != 1 {
|
||||
t.Fatalf("served(a)=%d want 1", got)
|
||||
}
|
||||
|
||||
// Unload a while it is serving: the process is stopped, but the queue must
|
||||
// not advance yet — the active serve is still outstanding.
|
||||
s.OnUnload([]string{"a"}, time.Second)
|
||||
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||
}
|
||||
if got := eff.startsFor("b"); got != 0 {
|
||||
t.Fatalf("StartSwap(b)=%d want 0 before the serving request ends", got)
|
||||
}
|
||||
|
||||
// The killed handler returns -> OnServeDone advances to b.
|
||||
eff.states["a"] = process.StateStopped
|
||||
s.OnServeDone(ServeDoneEvent{ModelID: "a"})
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Fatalf("StartSwap(b)=%d want 1 after the serving request ended", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerial_OnUnload_DropsQueuedAndStops(t *testing.T) {
|
||||
eff := newFakeEffects()
|
||||
eff.states["a"] = process.StateStopped
|
||||
eff.states["b"] = process.StateStopped
|
||||
s := newSerial(eff)
|
||||
|
||||
s.OnRequest(req("a")) // active (loading a)
|
||||
s.OnRequest(req("b")) // queued
|
||||
|
||||
// Unload a: its active load is failed and a is stopped.
|
||||
s.OnUnload([]string{"a"}, time.Second)
|
||||
|
||||
if eff.errored("a") != 1 {
|
||||
t.Errorf("errored(a)=%d want 1 (active load failed)", eff.errored("a"))
|
||||
}
|
||||
if len(eff.stops) != 1 || !sameSet(eff.stops[0].ids, []string{"a"}) {
|
||||
t.Errorf("StopProcesses=%+v want one call stopping [a]", eff.stops)
|
||||
}
|
||||
// b was queued and not unloaded; with a's load cancelled it now starts.
|
||||
if got := eff.startsFor("b"); got != 1 {
|
||||
t.Errorf("StartSwap(b)=%d want 1 after unload advanced the queue", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DirWatcher polls a directory for changes to its set of *.yml / *.yaml files.
|
||||
// It fires OnChange when a file is added, removed, or has its mod time/size
|
||||
// change. Like Watcher it is poll-based so it works in Docker bind-mounts and
|
||||
// k8s ConfigMap projections where inotify is unreliable.
|
||||
//
|
||||
// The baseline poll establishes initial state and does not fire OnChange.
|
||||
type DirWatcher struct {
|
||||
Path string
|
||||
Interval time.Duration
|
||||
OnChange func()
|
||||
}
|
||||
|
||||
// dirSnapshot is an ordered map of file name -> file state. The ordering is
|
||||
// derived from sorted filenames so two snapshots compare deterministically
|
||||
// regardless of readdir order. exists reflects whether the directory was
|
||||
// readable at scan time; a missing directory yields exists=false.
|
||||
type dirSnapshot struct {
|
||||
exists bool
|
||||
names []string
|
||||
states map[string]snapshot
|
||||
}
|
||||
|
||||
func newDirSnapshot() dirSnapshot {
|
||||
return dirSnapshot{states: make(map[string]snapshot)}
|
||||
}
|
||||
|
||||
// equal reports whether two snapshots describe the same file set and per-file
|
||||
// state. A missing directory (exists=false) is treated as equal to any other
|
||||
// missing directory regardless of cached names.
|
||||
func (s dirSnapshot) equal(other dirSnapshot) bool {
|
||||
if !s.exists && !other.exists {
|
||||
return true
|
||||
}
|
||||
if s.exists != other.exists {
|
||||
return false
|
||||
}
|
||||
if len(s.names) != len(other.names) {
|
||||
return false
|
||||
}
|
||||
for i, n := range s.names {
|
||||
if other.names[i] != n {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, n := range s.names {
|
||||
a, b := s.states[n], other.states[n]
|
||||
if a.exists != b.exists || a.size != b.size || !a.modTime.Equal(b.modTime) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||
// OnChange whenever the directory's YAML file set changes.
|
||||
//
|
||||
// Policy mirrors the single-file Watcher: disappearance (directory missing or
|
||||
// empty) is treated as a transient rename-style write and stays quiet; the
|
||||
// transition back to present-with-content fires OnChange.
|
||||
func (w *DirWatcher) Run(ctx context.Context) {
|
||||
interval := w.Interval
|
||||
if interval <= 0 {
|
||||
interval = DefaultInterval
|
||||
}
|
||||
|
||||
prev := scanDir(w.Path)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cur := scanDir(w.Path)
|
||||
// Suppress transitions involving an empty or missing directory —
|
||||
// these are treated as transient rename-style writes, mirroring
|
||||
// the single-file Watcher. Only present-with-content →
|
||||
// present-with-content (changed) or no-content →
|
||||
// present-with-content fires OnChange.
|
||||
prevHasContent := prev.exists && len(prev.names) > 0
|
||||
curHasContent := cur.exists && len(cur.names) > 0
|
||||
if curHasContent && (!prevHasContent || !prev.equal(cur)) && w.OnChange != nil {
|
||||
w.OnChange()
|
||||
}
|
||||
prev = cur
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// scanDir returns a snapshot of the *.yml/*.yaml files in dir. If the
|
||||
// directory cannot be read (missing, permission denied) the snapshot reports
|
||||
// exists=false; the next successful scan will detect the recovery and fire
|
||||
// OnChange.
|
||||
func scanDir(dir string) dirSnapshot {
|
||||
snap := newDirSnapshot()
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return snap // exists=false
|
||||
}
|
||||
snap.exists = true
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||
continue
|
||||
}
|
||||
fi, err := os.Stat(filepath.Join(dir, name))
|
||||
if err != nil {
|
||||
// File disappeared between ReadDir and Stat; skip it — the
|
||||
// next poll will observe the removal cleanly.
|
||||
continue
|
||||
}
|
||||
snap.names = append(snap.names, name)
|
||||
snap.states[name] = snapshot{
|
||||
exists: true,
|
||||
modTime: fi.ModTime(),
|
||||
size: fi.Size(),
|
||||
}
|
||||
}
|
||||
sort.Strings(snap.names)
|
||||
return snap
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package configwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// startDirWatcher launches w.Run in a goroutine and returns a function that
|
||||
// cancels the context and waits for Run to return.
|
||||
func startDirWatcher(t *testing.T, w *DirWatcher) func() {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("DirWatcher did not stop within 2s of cancel")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeYAMLInDir(t *testing.T, dir, name, content string) {
|
||||
t.Helper()
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644))
|
||||
}
|
||||
|
||||
func TestDirWatcher_NoFireOnBaseline(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
|
||||
time.Sleep(testInterval * 5)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||
}
|
||||
|
||||
func TestDirWatcher_DetectsFileAdd(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is added")
|
||||
}
|
||||
|
||||
func TestDirWatcher_DetectsFileRemoval(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.Remove(filepath.Join(dir, "b.yaml")))
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is removed")
|
||||
}
|
||||
|
||||
func TestDirWatcher_DetectsModTimeChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base, base))
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||
}
|
||||
|
||||
func TestDirWatcher_IgnoresNonYAMLFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Adding a .txt file must not fire.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0o644))
|
||||
time.Sleep(testInterval * 4)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "non-YAML files must be ignored")
|
||||
|
||||
// Adding a .yml file must fire.
|
||||
writeYAMLInDir(t, dir, "b.yml", "b")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire for *.yml files")
|
||||
}
|
||||
|
||||
func TestDirWatcher_MissingDirRecovers(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Remove the directory. No fire expected on disappearance alone.
|
||||
require.NoError(t, os.RemoveAll(dir))
|
||||
time.Sleep(testInterval * 3)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "directory removal alone must not fire")
|
||||
|
||||
// Recreate the directory and a YAML file; the recovery should fire.
|
||||
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||
writeYAMLInDir(t, dir, "recovered.yaml", "r")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when dir returns with content")
|
||||
}
|
||||
|
||||
func TestDirWatcher_EmptyDirSuppressedThenRecovers(t *testing.T) {
|
||||
// Present-with-content → empty (all YAML removed, dir still exists)
|
||||
// must stay quiet — treated as transient per the documented policy.
|
||||
// The transition back to content fires.
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
var n int64
|
||||
stop := startDirWatcher(t, &DirWatcher{
|
||||
Path: dir,
|
||||
Interval: testInterval,
|
||||
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||
})
|
||||
defer stop()
|
||||
time.Sleep(testInterval * 2)
|
||||
|
||||
// Remove the only YAML file. Dir still exists but is empty of YAML.
|
||||
require.NoError(t, os.Remove(filepath.Join(dir, "a.yaml")))
|
||||
time.Sleep(testInterval * 4)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&n), "emptying the directory must not fire")
|
||||
|
||||
// Add a YAML file back; transition to present-with-content fires.
|
||||
writeYAMLInDir(t, dir, "c.yaml", "c")
|
||||
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when content returns")
|
||||
}
|
||||
|
||||
func TestDirWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||
|
||||
w := &DirWatcher{Path: dir, Interval: testInterval}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() { w.Run(ctx); close(done) }()
|
||||
|
||||
time.Sleep(testInterval * 2)
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Run did not return within 2s of cancel")
|
||||
}
|
||||
}
|
||||
+27
-9
@@ -55,7 +55,8 @@ var logTimeFormats = map[string]string{
|
||||
}
|
||||
|
||||
func main() {
|
||||
flagConfig := flag.String("config", "", "path to config file (required)")
|
||||
flagConfig := flag.String("config", "", "path to config file")
|
||||
flagConfigDir := flag.String("config-dir", "", "directory of *.yml/*.yaml config files (additive to -config)")
|
||||
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
|
||||
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||
flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||
@@ -68,8 +69,8 @@ func main() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if *flagConfig == "" {
|
||||
slog.Error("-config is required")
|
||||
if *flagConfig == "" && *flagConfigDir == "" {
|
||||
slog.Error("at least one of -config or -config-dir must be provided")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -88,10 +89,9 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
configPath := *flagConfig
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
cfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||
if err != nil {
|
||||
slog.Error("failed to load config", "path", configPath, "error", err)
|
||||
slog.Error("failed to load config", "config", *flagConfig, "config-dir", *flagConfigDir, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -187,7 +187,7 @@ func main() {
|
||||
|
||||
proxyLog.Info("reloading configuration")
|
||||
|
||||
newCfg, err := config.LoadConfig(configPath)
|
||||
newCfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||
if err != nil {
|
||||
proxyLog.Warnf("failed to reload config: %v", err)
|
||||
return
|
||||
@@ -230,12 +230,14 @@ func main() {
|
||||
defer watcherCancel()
|
||||
|
||||
if *flagWatchConfig {
|
||||
absConfigPath, err := filepath.Abs(configPath)
|
||||
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
||||
|
||||
if *flagConfig != "" {
|
||||
absConfigPath, err := filepath.Abs(*flagConfig)
|
||||
if err != nil {
|
||||
slog.Error("watch-config: failed to resolve config path", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
||||
go func() {
|
||||
(&configwatcher.Watcher{
|
||||
Path: absConfigPath,
|
||||
@@ -245,6 +247,22 @@ func main() {
|
||||
}()
|
||||
}
|
||||
|
||||
if *flagConfigDir != "" {
|
||||
absConfigDir, err := filepath.Abs(*flagConfigDir)
|
||||
if err != nil {
|
||||
slog.Error("watch-config: failed to resolve config-dir path", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
go func() {
|
||||
(&configwatcher.DirWatcher{
|
||||
Path: absConfigDir,
|
||||
Interval: configwatcher.DefaultInterval,
|
||||
OnChange: reload,
|
||||
}).Run(watcherCtx)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user