internal/config,watcher: add -config-dir (#873)
Over time the llama-swap configuration file can get really long and challenging to work with. The -config-dir flag is used for a directory of configuration YAML fragments. These fragments are merged together and into a full configuration and tested for validity. All previous configuration functionality remains unchanged.
This commit is contained in:
@@ -0,0 +1,57 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/billziss-gh/golib/shlex"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SanitizeCommand(cmdStr string) ([]string, error) {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Handle trailing backslashes by replacing with space
|
||||||
|
if strings.HasSuffix(trimmed, "\\") {
|
||||||
|
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
||||||
|
} else {
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// put it back together
|
||||||
|
cmdStr = strings.Join(cleanedLines, "\n")
|
||||||
|
|
||||||
|
// Split the command into arguments
|
||||||
|
var args []string
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
args = shlex.Windows.Split(cmdStr)
|
||||||
|
} else {
|
||||||
|
args = shlex.Posix.Split(cmdStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the command is not empty
|
||||||
|
if len(args) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty command")
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripComments(cmdStr string) string {
|
||||||
|
var cleanedLines []string
|
||||||
|
for _, line := range strings.Split(cmdStr, "\n") {
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
// Skip comment lines
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cleanedLines = append(cleanedLines, line)
|
||||||
|
}
|
||||||
|
return strings.Join(cleanedLines, "\n")
|
||||||
|
}
|
||||||
@@ -2,16 +2,9 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
|
||||||
"runtime"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/billziss-gh/golib/shlex"
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -85,12 +78,6 @@ type GroupConfig struct {
|
|||||||
Members []string `yaml:"members"`
|
Members []string `yaml:"members"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
|
||||||
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
|
||||||
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
|
||||||
)
|
|
||||||
|
|
||||||
// set default values for GroupConfig
|
// set default values for GroupConfig
|
||||||
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
type rawGroupConfig GroupConfig
|
type rawGroupConfig GroupConfig
|
||||||
@@ -224,430 +211,6 @@ func LoadConfig(path string) (Config, error) {
|
|||||||
return LoadConfigFromReader(file)
|
return LoadConfigFromReader(file)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
yamlStr := string(data)
|
|
||||||
|
|
||||||
// Phase 1: Substitute all ${env.VAR} macros at string level
|
|
||||||
// This is safe because env values are simple strings without YAML formatting
|
|
||||||
yamlStr, err = substituteEnvMacros(yamlStr)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unmarshal into full Config with defaults
|
|
||||||
config := Config{
|
|
||||||
HealthCheckTimeout: 120,
|
|
||||||
StartPort: 5800,
|
|
||||||
LogLevel: "info",
|
|
||||||
LogTimeFormat: "",
|
|
||||||
LogToStdout: LogToStdoutProxy,
|
|
||||||
MetricsMaxInMemory: 1000,
|
|
||||||
CaptureBuffer: 5,
|
|
||||||
GlobalTTL: 0,
|
|
||||||
}
|
|
||||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.HealthCheckTimeout < 15 {
|
|
||||||
config.HealthCheckTimeout = 15
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply defaults for performance config when section is missing
|
|
||||||
if config.Performance.Every == 0 {
|
|
||||||
config.Performance.Every = 5 * time.Second
|
|
||||||
}
|
|
||||||
if err = config.Performance.Validate(); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("performance: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.StartPort < 1 {
|
|
||||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.GlobalTTL < 0 {
|
|
||||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply default for upstream.ignorePaths when not specified. The default
|
|
||||||
// matches common static-asset suffixes so they do not trigger a swap.
|
|
||||||
if len(config.Upstream.IgnorePaths) == 0 {
|
|
||||||
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
|
||||||
}
|
|
||||||
|
|
||||||
switch config.LogToStdout {
|
|
||||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
|
||||||
default:
|
|
||||||
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate the aliases map
|
|
||||||
config.aliases = make(map[string]string)
|
|
||||||
for modelName, modelConfig := range config.Models {
|
|
||||||
for _, alias := range modelConfig.Aliases {
|
|
||||||
if _, found := config.aliases[alias]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
|
||||||
}
|
|
||||||
config.aliases[alias] = modelName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate global macros
|
|
||||||
for _, macro := range config.Macros {
|
|
||||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get and sort all model IDs for consistent port assignment
|
|
||||||
modelIds := make([]string, 0, len(config.Models))
|
|
||||||
for modelId := range config.Models {
|
|
||||||
modelIds = append(modelIds, modelId)
|
|
||||||
}
|
|
||||||
sort.Strings(modelIds)
|
|
||||||
|
|
||||||
nextPort := config.StartPort
|
|
||||||
for _, modelId := range modelIds {
|
|
||||||
modelConfig := config.Models[modelId]
|
|
||||||
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
|
||||||
|
|
||||||
// Strip comments from command fields
|
|
||||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
|
||||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
|
||||||
|
|
||||||
// set model TTL to globalTTL it is the default value
|
|
||||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
|
||||||
modelConfig.UnloadAfter = config.GlobalTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelConfig.UnloadAfter < 0 {
|
|
||||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate model macros
|
|
||||||
for _, macro := range modelConfig.Macros {
|
|
||||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
|
||||||
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
|
||||||
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
|
||||||
mergedMacros = append(mergedMacros, config.Macros...)
|
|
||||||
|
|
||||||
// Add model macros (override globals with same name)
|
|
||||||
for _, entry := range modelConfig.Macros {
|
|
||||||
found := false
|
|
||||||
for i, existing := range mergedMacros {
|
|
||||||
if existing.Name == entry.Name {
|
|
||||||
mergedMacros[i] = entry
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
mergedMacros = append(mergedMacros, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Substitute remaining macros in model fields (LIFO order)
|
|
||||||
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
|
||||||
entry := mergedMacros[i]
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
|
||||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
|
||||||
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
|
||||||
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
|
||||||
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
|
||||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
|
||||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
|
||||||
|
|
||||||
// Substitute macros in SetParamsByID keys and values
|
|
||||||
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
|
||||||
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
|
||||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
|
||||||
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
|
||||||
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
newParamMap, ok := newValAny.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
|
||||||
}
|
|
||||||
newSetParamsByID[newKey] = newParamMap
|
|
||||||
}
|
|
||||||
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Substitute in metadata (type-preserving)
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
modelConfig.Metadata = result.(map[string]any)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle PORT macro - only allocate if cmd uses it
|
|
||||||
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
|
||||||
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
|
||||||
if cmdHasPort || proxyHasPort {
|
|
||||||
if !cmdHasPort && proxyHasPort {
|
|
||||||
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
|
||||||
}
|
|
||||||
|
|
||||||
macroSlug := "${PORT}"
|
|
||||||
macroStr := fmt.Sprintf("%v", nextPort)
|
|
||||||
|
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
|
||||||
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
|
||||||
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
|
||||||
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
|
||||||
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
|
||||||
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
|
||||||
}
|
|
||||||
modelConfig.Metadata = result.(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
nextPort++
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate no unknown macros remain
|
|
||||||
fieldMap := map[string]string{
|
|
||||||
"cmd": modelConfig.Cmd,
|
|
||||||
"cmdStop": modelConfig.CmdStop,
|
|
||||||
"proxy": modelConfig.Proxy,
|
|
||||||
"checkEndpoint": modelConfig.CheckEndpoint,
|
|
||||||
"filters.stripParams": modelConfig.Filters.StripParams,
|
|
||||||
"name": modelConfig.Name,
|
|
||||||
"description": modelConfig.Description,
|
|
||||||
}
|
|
||||||
|
|
||||||
for fieldName, fieldValue := range fieldMap {
|
|
||||||
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
macroName := match[1]
|
|
||||||
if macroName == "PID" && fieldName == "cmdStop" {
|
|
||||||
continue // replaced at runtime
|
|
||||||
}
|
|
||||||
if macroName == "PORT" || macroName == "MODEL_ID" {
|
|
||||||
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
|
||||||
}
|
|
||||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(modelConfig.Metadata) > 0 {
|
|
||||||
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = modelConfig.Capabilities.Validate(); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate SetParamsByID keys and values
|
|
||||||
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
|
||||||
}
|
|
||||||
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
|
||||||
for key := range modelConfig.Filters.SetParamsByID {
|
|
||||||
if key == modelId {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, exists := config.Models[key]; exists {
|
|
||||||
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
|
||||||
}
|
|
||||||
if existingModel, exists := config.aliases[key]; exists {
|
|
||||||
if existingModel != modelId {
|
|
||||||
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
|
||||||
}
|
|
||||||
continue // already registered as explicit alias for this model
|
|
||||||
}
|
|
||||||
config.aliases[key] = modelId
|
|
||||||
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelConfig.SendLoadingState == nil {
|
|
||||||
v := config.SendLoadingState
|
|
||||||
modelConfig.SendLoadingState = &v
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Models[modelId] = modelConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
|
||||||
// the new `routing.router` block are mutually exclusive: a config may use
|
|
||||||
// either style, never both.
|
|
||||||
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
|
||||||
rtr := config.Routing.Router
|
|
||||||
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
|
||||||
|
|
||||||
if hasTopLevel && hasRouting {
|
|
||||||
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasTopLevel {
|
|
||||||
// Both groups and matrix may be defined under routing.router.settings;
|
|
||||||
// routing.router.use selects which one is active, so there is no conflict.
|
|
||||||
rs := config.Routing.Router.Settings
|
|
||||||
switch config.Routing.Router.Use {
|
|
||||||
case "matrix":
|
|
||||||
if rs.Matrix == nil {
|
|
||||||
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
|
||||||
}
|
|
||||||
config.Matrix = rs.Matrix
|
|
||||||
case "group", "":
|
|
||||||
config.Groups = rs.Groups
|
|
||||||
default:
|
|
||||||
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// groups XOR matrix
|
|
||||||
if config.Matrix != nil && len(config.Groups) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Matrix != nil {
|
|
||||||
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("matrix: %w", err)
|
|
||||||
}
|
|
||||||
config.Matrix.ExpandedSets = expandedSets
|
|
||||||
} else {
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
|
||||||
|
|
||||||
// Validate group members
|
|
||||||
memberUsage := make(map[string]string)
|
|
||||||
for groupID, groupConfig := range config.Groups {
|
|
||||||
prevSet := make(map[string]bool)
|
|
||||||
for _, member := range groupConfig.Members {
|
|
||||||
if _, found := prevSet[member]; found {
|
|
||||||
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
|
||||||
}
|
|
||||||
prevSet[member] = true
|
|
||||||
|
|
||||||
if existingGroup, exists := memberUsage[member]; exists {
|
|
||||||
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
|
||||||
}
|
|
||||||
memberUsage[member] = groupID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the canonical Config.Routing from the effective result. Both legacy
|
|
||||||
// and new-style configs converge here. The Matrix pointer is shared so
|
|
||||||
// ExpandedSets stays in one place.
|
|
||||||
if config.Matrix != nil {
|
|
||||||
config.Routing.Router.Use = "matrix"
|
|
||||||
} else {
|
|
||||||
config.Routing.Router.Use = "group"
|
|
||||||
}
|
|
||||||
config.Routing.Router.Settings.Matrix = config.Matrix
|
|
||||||
config.Routing.Router.Settings.Groups = config.Groups
|
|
||||||
|
|
||||||
if config.Routing.Scheduler.Use == "" {
|
|
||||||
config.Routing.Scheduler.Use = "fifo"
|
|
||||||
}
|
|
||||||
if config.Routing.Scheduler.Use != "fifo" {
|
|
||||||
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
|
|
||||||
}
|
|
||||||
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
|
||||||
if _, found := config.RealModelName(modelID); !found {
|
|
||||||
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up hooks preload
|
|
||||||
if len(config.Hooks.OnStartup.Preload) > 0 {
|
|
||||||
var toPreload []string
|
|
||||||
for _, modelID := range config.Hooks.OnStartup.Preload {
|
|
||||||
modelID = strings.TrimSpace(modelID)
|
|
||||||
if modelID == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if real, found := config.RealModelName(modelID); found {
|
|
||||||
toPreload = append(toPreload, real)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config.Hooks.OnStartup.Preload = toPreload
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate API keys (env macros already substituted at string level)
|
|
||||||
for i, apikey := range config.RequiredAPIKeys {
|
|
||||||
if apikey == "" {
|
|
||||||
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
|
||||||
}
|
|
||||||
if strings.Contains(apikey, " ") {
|
|
||||||
return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey)
|
|
||||||
}
|
|
||||||
config.RequiredAPIKeys[i] = apikey
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process peers with global macro substitution
|
|
||||||
for peerName, peerConfig := range config.Peers {
|
|
||||||
// Substitute global macros (LIFO order)
|
|
||||||
for i := len(config.Macros) - 1; i >= 0; i-- {
|
|
||||||
entry := config.Macros[i]
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
|
||||||
macroStr := fmt.Sprintf("%v", entry.Value)
|
|
||||||
|
|
||||||
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
|
||||||
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
|
||||||
|
|
||||||
// Substitute in setParams (type-preserving)
|
|
||||||
if len(peerConfig.Filters.SetParams) > 0 {
|
|
||||||
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
|
||||||
}
|
|
||||||
peerConfig.Filters.SetParams = result.(map[string]any)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate no unknown macros remain
|
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
|
||||||
}
|
|
||||||
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
|
||||||
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
|
||||||
}
|
|
||||||
if len(peerConfig.Filters.SetParams) > 0 {
|
|
||||||
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
|
||||||
return Config{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config.Peers[peerName] = peerConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rewrites the yaml to include a default group with any orphaned models
|
// rewrites the yaml to include a default group with any orphaned models
|
||||||
func AddDefaultGroupToConfig(config Config) Config {
|
func AddDefaultGroupToConfig(config Config) Config {
|
||||||
|
|
||||||
@@ -692,233 +255,3 @@ func AddDefaultGroupToConfig(config Config) Config {
|
|||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func SanitizeCommand(cmdStr string) ([]string, error) {
|
|
||||||
var cleanedLines []string
|
|
||||||
for _, line := range strings.Split(cmdStr, "\n") {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
// Skip comment lines
|
|
||||||
if strings.HasPrefix(trimmed, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Handle trailing backslashes by replacing with space
|
|
||||||
if strings.HasSuffix(trimmed, "\\") {
|
|
||||||
cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ")
|
|
||||||
} else {
|
|
||||||
cleanedLines = append(cleanedLines, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// put it back together
|
|
||||||
cmdStr = strings.Join(cleanedLines, "\n")
|
|
||||||
|
|
||||||
// Split the command into arguments
|
|
||||||
var args []string
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
args = shlex.Windows.Split(cmdStr)
|
|
||||||
} else {
|
|
||||||
args = shlex.Posix.Split(cmdStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the command is not empty
|
|
||||||
if len(args) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty command")
|
|
||||||
}
|
|
||||||
|
|
||||||
return args, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func StripComments(cmdStr string) string {
|
|
||||||
var cleanedLines []string
|
|
||||||
for _, line := range strings.Split(cmdStr, "\n") {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
// Skip comment lines
|
|
||||||
if strings.HasPrefix(trimmed, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cleanedLines = append(cleanedLines, line)
|
|
||||||
}
|
|
||||||
return strings.Join(cleanedLines, "\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateMacro validates macro name and value constraints
|
|
||||||
func validateMacro(name string, value any) error {
|
|
||||||
if len(name) >= 64 {
|
|
||||||
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
|
||||||
}
|
|
||||||
if !macroNameRegex.MatchString(name) {
|
|
||||||
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate that value is a scalar type
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
// Check for self-reference
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", name)
|
|
||||||
if strings.Contains(v, macroSlug) {
|
|
||||||
return fmt.Errorf("macro '%s' contains self-reference", name)
|
|
||||||
}
|
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
|
||||||
// These types are allowed
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch name {
|
|
||||||
case "PORT", "MODEL_ID":
|
|
||||||
return fmt.Errorf("macro name '%s' is reserved", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
|
||||||
func validateNestedForUnknownMacros(value any, context string) error {
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
macroName := match[1]
|
|
||||||
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
|
||||||
}
|
|
||||||
// Check for unsubstituted env macros
|
|
||||||
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
|
||||||
for _, match := range envMatches {
|
|
||||||
varName := match[1]
|
|
||||||
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case map[string]any:
|
|
||||||
for _, val := range v {
|
|
||||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case []any:
|
|
||||||
for _, val := range v {
|
|
||||||
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Scalar types don't contain macros
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
|
||||||
// This is called once per macro, allowing LIFO substitution order
|
|
||||||
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
|
||||||
macroSlug := fmt.Sprintf("${%s}", macroName)
|
|
||||||
macroStr := fmt.Sprintf("%v", macroValue)
|
|
||||||
|
|
||||||
switch v := value.(type) {
|
|
||||||
case string:
|
|
||||||
// Check if this is a direct macro substitution
|
|
||||||
if v == macroSlug {
|
|
||||||
return macroValue, nil
|
|
||||||
}
|
|
||||||
// Handle string interpolation
|
|
||||||
if strings.Contains(v, macroSlug) {
|
|
||||||
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
|
||||||
}
|
|
||||||
return v, nil
|
|
||||||
|
|
||||||
case map[string]any:
|
|
||||||
// Recursively process map values
|
|
||||||
newMap := make(map[string]any)
|
|
||||||
for key, val := range v {
|
|
||||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newMap[key] = newVal
|
|
||||||
}
|
|
||||||
return newMap, nil
|
|
||||||
|
|
||||||
case []any:
|
|
||||||
// Recursively process slice elements
|
|
||||||
newSlice := make([]any, len(v))
|
|
||||||
for i, val := range v {
|
|
||||||
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
newSlice[i] = newVal
|
|
||||||
}
|
|
||||||
return newSlice, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Return scalar types as-is
|
|
||||||
return value, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
|
||||||
// Returns error if any referenced env var is not set or contains invalid characters.
|
|
||||||
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
|
||||||
// (which strips comments) and only checking the comment-free version for macros.
|
|
||||||
func substituteEnvMacros(s string) (string, error) {
|
|
||||||
// Unmarshal and remarshal to strip YAML comments
|
|
||||||
var raw any
|
|
||||||
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
|
||||||
// If YAML is invalid, fall back to scanning the original string
|
|
||||||
// so the user gets the env var error rather than a confusing YAML parse error
|
|
||||||
return substituteEnvMacrosInString(s, s)
|
|
||||||
}
|
|
||||||
clean, err := yaml.Marshal(raw)
|
|
||||||
if err != nil {
|
|
||||||
return substituteEnvMacrosInString(s, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
return substituteEnvMacrosInString(s, string(clean))
|
|
||||||
}
|
|
||||||
|
|
||||||
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
|
||||||
// them in target. This separation allows scanning comment-free YAML while
|
|
||||||
// substituting in the original string.
|
|
||||||
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
|
||||||
result := target
|
|
||||||
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
|
||||||
for _, match := range matches {
|
|
||||||
fullMatch := match[0] // ${env.VAR_NAME}
|
|
||||||
varName := match[1] // VAR_NAME
|
|
||||||
|
|
||||||
value, exists := os.LookupEnv(varName)
|
|
||||||
if !exists {
|
|
||||||
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sanitize the value for safe YAML substitution
|
|
||||||
value, err := sanitizeEnvValueForYAML(value, varName)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
result = strings.ReplaceAll(result, fullMatch, value)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
|
||||||
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
|
||||||
// for compatibility with double-quoted YAML strings.
|
|
||||||
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
|
||||||
// Reject values that would break YAML structure regardless of quoting context
|
|
||||||
if strings.ContainsAny(value, "\n\r\x00") {
|
|
||||||
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
|
||||||
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
|
||||||
// In double-quoted contexts, they are interpreted correctly.
|
|
||||||
value = strings.ReplaceAll(value, `\`, `\\`)
|
|
||||||
value = strings.ReplaceAll(value, `"`, `\"`)
|
|
||||||
|
|
||||||
return value, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -777,22 +777,27 @@ func TestConfig_APIKeys_Invalid(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "blank spaces only",
|
name: "blank spaces only",
|
||||||
content: `apiKeys: [" "]`,
|
content: `apiKeys: [" "]`,
|
||||||
expectedErr: "api key cannot contain spaces: ` `",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "contains leading space",
|
name: "contains leading space",
|
||||||
content: `apiKeys: [" key123"]`,
|
content: `apiKeys: [" key123"]`,
|
||||||
expectedErr: "api key cannot contain spaces: ` key123`",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "contains trailing space",
|
name: "contains trailing space",
|
||||||
content: `apiKeys: ["key123 "]`,
|
content: `apiKeys: ["key123 "]`,
|
||||||
expectedErr: "api key cannot contain spaces: `key123 `",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "contains middle space",
|
name: "contains middle space",
|
||||||
content: `apiKeys: ["key 123"]`,
|
content: `apiKeys: ["key 123"]`,
|
||||||
expectedErr: "api key cannot contain spaces: `key 123`",
|
expectedErr: "apiKeys[0]: api key cannot contain spaces",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "space in second key reports correct index",
|
||||||
|
content: `apiKeys: ["valid-key", "bad key"]`,
|
||||||
|
expectedErr: "apiKeys[1]: api key cannot contain spaces",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty in list with valid keys",
|
name: "empty in list with valid keys",
|
||||||
|
|||||||
@@ -0,0 +1,436 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
yamlStr := string(data)
|
||||||
|
|
||||||
|
// Phase 1: Substitute all ${env.VAR} macros at string level
|
||||||
|
// This is safe because env values are simple strings without YAML formatting
|
||||||
|
yamlStr, err = substituteEnvMacros(yamlStr)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal into full Config with defaults
|
||||||
|
config := Config{
|
||||||
|
HealthCheckTimeout: 120,
|
||||||
|
StartPort: 5800,
|
||||||
|
LogLevel: "info",
|
||||||
|
LogTimeFormat: "",
|
||||||
|
LogToStdout: LogToStdoutProxy,
|
||||||
|
MetricsMaxInMemory: 1000,
|
||||||
|
CaptureBuffer: 5,
|
||||||
|
GlobalTTL: 0,
|
||||||
|
}
|
||||||
|
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.HealthCheckTimeout < 15 {
|
||||||
|
config.HealthCheckTimeout = 15
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults for performance config when section is missing
|
||||||
|
if config.Performance.Every == 0 {
|
||||||
|
config.Performance.Every = 5 * time.Second
|
||||||
|
}
|
||||||
|
if err = config.Performance.Validate(); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("performance: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.StartPort < 1 {
|
||||||
|
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.GlobalTTL < 0 {
|
||||||
|
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply default for upstream.ignorePaths when not specified. The default
|
||||||
|
// matches common static-asset suffixes so they do not trigger a swap.
|
||||||
|
if len(config.Upstream.IgnorePaths) == 0 {
|
||||||
|
config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch config.LogToStdout {
|
||||||
|
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate the aliases map
|
||||||
|
config.aliases = make(map[string]string)
|
||||||
|
for modelName, modelConfig := range config.Models {
|
||||||
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
if _, found := config.aliases[alias]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName)
|
||||||
|
}
|
||||||
|
config.aliases[alias] = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate global macros
|
||||||
|
for _, macro := range config.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get and sort all model IDs for consistent port assignment
|
||||||
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
|
for modelId := range config.Models {
|
||||||
|
modelIds = append(modelIds, modelId)
|
||||||
|
}
|
||||||
|
sort.Strings(modelIds)
|
||||||
|
|
||||||
|
nextPort := config.StartPort
|
||||||
|
for _, modelId := range modelIds {
|
||||||
|
modelConfig := config.Models[modelId]
|
||||||
|
modelConfig.HealthCheckTimeout = config.HealthCheckTimeout
|
||||||
|
|
||||||
|
// Strip comments from command fields
|
||||||
|
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||||
|
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||||
|
|
||||||
|
// set model TTL to globalTTL it is the default value
|
||||||
|
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||||
|
modelConfig.UnloadAfter = config.GlobalTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelConfig.UnloadAfter < 0 {
|
||||||
|
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate model macros
|
||||||
|
for _, macro := range modelConfig.Macros {
|
||||||
|
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build merged macro list: MODEL_ID + global macros + model macros (model overrides global)
|
||||||
|
mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1)
|
||||||
|
mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId})
|
||||||
|
mergedMacros = append(mergedMacros, config.Macros...)
|
||||||
|
|
||||||
|
// Add model macros (override globals with same name)
|
||||||
|
for _, entry := range modelConfig.Macros {
|
||||||
|
found := false
|
||||||
|
for i, existing := range mergedMacros {
|
||||||
|
if existing.Name == entry.Name {
|
||||||
|
mergedMacros[i] = entry
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
mergedMacros = append(mergedMacros, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Substitute remaining macros in model fields (LIFO order)
|
||||||
|
for i := len(mergedMacros) - 1; i >= 0; i-- {
|
||||||
|
entry := mergedMacros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
|
||||||
|
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||||
|
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute macros in SetParamsByID keys and values
|
||||||
|
if len(modelConfig.Filters.SetParamsByID) > 0 {
|
||||||
|
newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID))
|
||||||
|
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||||
|
newKey := strings.ReplaceAll(key, macroSlug, macroStr)
|
||||||
|
newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
newParamMap, ok := newValAny.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId)
|
||||||
|
}
|
||||||
|
newSetParamsByID[newKey] = newParamMap
|
||||||
|
}
|
||||||
|
modelConfig.Filters.SetParamsByID = newSetParamsByID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Substitute in metadata (type-preserving)
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle PORT macro - only allocate if cmd uses it
|
||||||
|
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
|
||||||
|
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
|
||||||
|
if cmdHasPort || proxyHasPort {
|
||||||
|
if !cmdHasPort && proxyHasPort {
|
||||||
|
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
macroSlug := "${PORT}"
|
||||||
|
macroStr := fmt.Sprintf("%v", nextPort)
|
||||||
|
|
||||||
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
|
||||||
|
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
|
||||||
|
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
|
||||||
|
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
|
||||||
|
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
|
||||||
|
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error())
|
||||||
|
}
|
||||||
|
modelConfig.Metadata = result.(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextPort++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate no unknown macros remain
|
||||||
|
fieldMap := map[string]string{
|
||||||
|
"cmd": modelConfig.Cmd,
|
||||||
|
"cmdStop": modelConfig.CmdStop,
|
||||||
|
"proxy": modelConfig.Proxy,
|
||||||
|
"checkEndpoint": modelConfig.CheckEndpoint,
|
||||||
|
"filters.stripParams": modelConfig.Filters.StripParams,
|
||||||
|
"name": modelConfig.Name,
|
||||||
|
"description": modelConfig.Description,
|
||||||
|
}
|
||||||
|
|
||||||
|
for fieldName, fieldValue := range fieldMap {
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
if macroName == "PID" && fieldName == "cmdStop" {
|
||||||
|
continue // replaced at runtime
|
||||||
|
}
|
||||||
|
if macroName == "PORT" || macroName == "MODEL_ID" {
|
||||||
|
return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(modelConfig.Metadata) > 0 {
|
||||||
|
if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = modelConfig.Capabilities.Validate(); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate SetParamsByID keys and values
|
||||||
|
for key, paramMap := range modelConfig.Filters.SetParamsByID {
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId)
|
||||||
|
}
|
||||||
|
if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-register setParamsByID keys as aliases (skip the model's own ID)
|
||||||
|
for key := range modelConfig.Filters.SetParamsByID {
|
||||||
|
if key == modelId {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := config.Models[key]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key)
|
||||||
|
}
|
||||||
|
if existingModel, exists := config.aliases[key]; exists {
|
||||||
|
if existingModel != modelId {
|
||||||
|
return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel)
|
||||||
|
}
|
||||||
|
continue // already registered as explicit alias for this model
|
||||||
|
}
|
||||||
|
config.aliases[key] = modelId
|
||||||
|
modelConfig.Aliases = append(modelConfig.Aliases, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelConfig.SendLoadingState == nil {
|
||||||
|
v := config.SendLoadingState
|
||||||
|
modelConfig.SendLoadingState = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Models[modelId] = modelConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize routing config. The legacy top-level `matrix`/`groups` keys and
|
||||||
|
// the new `routing.router` block are mutually exclusive: a config may use
|
||||||
|
// either style, never both.
|
||||||
|
hasTopLevel := config.Matrix != nil || len(config.Groups) > 0
|
||||||
|
rtr := config.Routing.Router
|
||||||
|
hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0
|
||||||
|
|
||||||
|
if hasTopLevel && hasRouting {
|
||||||
|
return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasTopLevel {
|
||||||
|
// Both groups and matrix may be defined under routing.router.settings;
|
||||||
|
// routing.router.use selects which one is active, so there is no conflict.
|
||||||
|
rs := config.Routing.Router.Settings
|
||||||
|
switch config.Routing.Router.Use {
|
||||||
|
case "matrix":
|
||||||
|
if rs.Matrix == nil {
|
||||||
|
return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set")
|
||||||
|
}
|
||||||
|
config.Matrix = rs.Matrix
|
||||||
|
case "group", "":
|
||||||
|
config.Groups = rs.Groups
|
||||||
|
default:
|
||||||
|
return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// groups XOR matrix
|
||||||
|
if config.Matrix != nil && len(config.Groups) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Matrix != nil {
|
||||||
|
expandedSets, err := ValidateMatrix(*config.Matrix, config.Models)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("matrix: %w", err)
|
||||||
|
}
|
||||||
|
config.Matrix.ExpandedSets = expandedSets
|
||||||
|
} else {
|
||||||
|
config = AddDefaultGroupToConfig(config)
|
||||||
|
|
||||||
|
// Validate group members
|
||||||
|
memberUsage := make(map[string]string)
|
||||||
|
for groupID, groupConfig := range config.Groups {
|
||||||
|
prevSet := make(map[string]bool)
|
||||||
|
for _, member := range groupConfig.Members {
|
||||||
|
if _, found := prevSet[member]; found {
|
||||||
|
return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID)
|
||||||
|
}
|
||||||
|
prevSet[member] = true
|
||||||
|
|
||||||
|
if existingGroup, exists := memberUsage[member]; exists {
|
||||||
|
return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID)
|
||||||
|
}
|
||||||
|
memberUsage[member] = groupID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the canonical Config.Routing from the effective result. Both legacy
|
||||||
|
// and new-style configs converge here. The Matrix pointer is shared so
|
||||||
|
// ExpandedSets stays in one place.
|
||||||
|
if config.Matrix != nil {
|
||||||
|
config.Routing.Router.Use = "matrix"
|
||||||
|
} else {
|
||||||
|
config.Routing.Router.Use = "group"
|
||||||
|
}
|
||||||
|
config.Routing.Router.Settings.Matrix = config.Matrix
|
||||||
|
config.Routing.Router.Settings.Groups = config.Groups
|
||||||
|
|
||||||
|
if config.Routing.Scheduler.Use == "" {
|
||||||
|
config.Routing.Scheduler.Use = "fifo"
|
||||||
|
}
|
||||||
|
if config.Routing.Scheduler.Use != "fifo" {
|
||||||
|
return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use)
|
||||||
|
}
|
||||||
|
for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority {
|
||||||
|
if _, found := config.RealModelName(modelID); !found {
|
||||||
|
return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up hooks preload
|
||||||
|
if len(config.Hooks.OnStartup.Preload) > 0 {
|
||||||
|
var toPreload []string
|
||||||
|
for _, modelID := range config.Hooks.OnStartup.Preload {
|
||||||
|
modelID = strings.TrimSpace(modelID)
|
||||||
|
if modelID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if real, found := config.RealModelName(modelID); found {
|
||||||
|
toPreload = append(toPreload, real)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Hooks.OnStartup.Preload = toPreload
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate API keys (env macros already substituted at string level)
|
||||||
|
for i, apikey := range config.RequiredAPIKeys {
|
||||||
|
if apikey == "" {
|
||||||
|
return Config{}, fmt.Errorf("empty api key found in apiKeys")
|
||||||
|
}
|
||||||
|
if strings.Contains(apikey, " ") {
|
||||||
|
return Config{}, fmt.Errorf("apiKeys[%d]: api key cannot contain spaces", i)
|
||||||
|
}
|
||||||
|
config.RequiredAPIKeys[i] = apikey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process peers with global macro substitution
|
||||||
|
for peerName, peerConfig := range config.Peers {
|
||||||
|
// Substitute global macros (LIFO order)
|
||||||
|
for i := len(config.Macros) - 1; i >= 0; i-- {
|
||||||
|
entry := config.Macros[i]
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", entry.Name)
|
||||||
|
macroStr := fmt.Sprintf("%v", entry.Value)
|
||||||
|
|
||||||
|
peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr)
|
||||||
|
peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr)
|
||||||
|
|
||||||
|
// Substitute in setParams (type-preserving)
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err)
|
||||||
|
}
|
||||||
|
peerConfig.Filters.SetParams = result.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate no unknown macros remain
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 {
|
||||||
|
return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1])
|
||||||
|
}
|
||||||
|
if len(peerConfig.Filters.SetParams) > 0 {
|
||||||
|
if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.Peers[peerName] = peerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||||
|
macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`)
|
||||||
|
envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// validateMacro validates macro name and value constraints
|
||||||
|
func validateMacro(name string, value any) error {
|
||||||
|
if len(name) >= 64 {
|
||||||
|
return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name)
|
||||||
|
}
|
||||||
|
if !macroNameRegex.MatchString(name) {
|
||||||
|
return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that value is a scalar type
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check for self-reference
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", name)
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return fmt.Errorf("macro '%s' contains self-reference", name)
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
|
||||||
|
// These types are allowed
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "PORT", "MODEL_ID":
|
||||||
|
return fmt.Errorf("macro name '%s' is reserved", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures
|
||||||
|
func validateNestedForUnknownMacros(value any, context string) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
matches := macroPatternRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
macroName := match[1]
|
||||||
|
return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName)
|
||||||
|
}
|
||||||
|
// Check for unsubstituted env macros
|
||||||
|
envMatches := envMacroRegex.FindAllStringSubmatch(v, -1)
|
||||||
|
for _, match := range envMatches {
|
||||||
|
varName := match[1]
|
||||||
|
return fmt.Errorf("%s: environment variable '%s' not set", context, varName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
for _, val := range v {
|
||||||
|
if err := validateNestedForUnknownMacros(val, context); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Scalar types don't contain macros
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteMacroInValue recursively substitutes a single macro in a value structure
|
||||||
|
// This is called once per macro, allowing LIFO substitution order
|
||||||
|
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
|
||||||
|
macroSlug := fmt.Sprintf("${%s}", macroName)
|
||||||
|
macroStr := fmt.Sprintf("%v", macroValue)
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
// Check if this is a direct macro substitution
|
||||||
|
if v == macroSlug {
|
||||||
|
return macroValue, nil
|
||||||
|
}
|
||||||
|
// Handle string interpolation
|
||||||
|
if strings.Contains(v, macroSlug) {
|
||||||
|
return strings.ReplaceAll(v, macroSlug, macroStr), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case map[string]any:
|
||||||
|
// Recursively process map values
|
||||||
|
newMap := make(map[string]any)
|
||||||
|
for key, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newMap[key] = newVal
|
||||||
|
}
|
||||||
|
return newMap, nil
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
// Recursively process slice elements
|
||||||
|
newSlice := make([]any, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
newVal, err := substituteMacroInValue(val, macroName, macroValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSlice[i] = newVal
|
||||||
|
}
|
||||||
|
return newSlice, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Return scalar types as-is
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values.
|
||||||
|
// Returns error if any referenced env var is not set or contains invalid characters.
|
||||||
|
// Env macros inside YAML comments are ignored by unmarshalling the YAML first
|
||||||
|
// (which strips comments) and only checking the comment-free version for macros.
|
||||||
|
func substituteEnvMacros(s string) (string, error) {
|
||||||
|
// Unmarshal and remarshal to strip YAML comments
|
||||||
|
var raw any
|
||||||
|
if err := yaml.Unmarshal([]byte(s), &raw); err != nil {
|
||||||
|
// If YAML is invalid, fall back to scanning the original string
|
||||||
|
// so the user gets the env var error rather than a confusing YAML parse error
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
clean, err := yaml.Marshal(raw)
|
||||||
|
if err != nil {
|
||||||
|
return substituteEnvMacrosInString(s, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
return substituteEnvMacrosInString(s, string(clean))
|
||||||
|
}
|
||||||
|
|
||||||
|
// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes
|
||||||
|
// them in target. This separation allows scanning comment-free YAML while
|
||||||
|
// substituting in the original string.
|
||||||
|
func substituteEnvMacrosInString(target, scanStr string) (string, error) {
|
||||||
|
result := target
|
||||||
|
matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1)
|
||||||
|
for _, match := range matches {
|
||||||
|
fullMatch := match[0] // ${env.VAR_NAME}
|
||||||
|
varName := match[1] // VAR_NAME
|
||||||
|
|
||||||
|
value, exists := os.LookupEnv(varName)
|
||||||
|
if !exists {
|
||||||
|
return "", fmt.Errorf("environment variable '%s' is not set", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize the value for safe YAML substitution
|
||||||
|
value, err := sanitizeEnvValueForYAML(value, varName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = strings.ReplaceAll(result, fullMatch, value)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution.
|
||||||
|
// It rejects values with characters that break YAML structure and escapes quotes/backslashes
|
||||||
|
// for compatibility with double-quoted YAML strings.
|
||||||
|
func sanitizeEnvValueForYAML(value, varName string) (string, error) {
|
||||||
|
// Reject values that would break YAML structure regardless of quoting context
|
||||||
|
if strings.ContainsAny(value, "\n\r\x00") {
|
||||||
|
return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Escape backslashes and double quotes for safe use in double-quoted YAML strings.
|
||||||
|
// In unquoted contexts, these escapes appear literally (harmless for most use cases).
|
||||||
|
// In double-quoted contexts, they are interpreted correctly.
|
||||||
|
value = strings.ReplaceAll(value, `\`, `\\`)
|
||||||
|
value = strings.ReplaceAll(value, `"`, `\"`)
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,300 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// identityMapPaths is the set of dotted paths whose direct children are
|
||||||
|
// identity-keyed maps. A child key present in two sources is a hard error;
|
||||||
|
// such keys name discrete entities (a model, a group, a peer, etc.) and a
|
||||||
|
// duplicate means the user has split one entity across files by mistake.
|
||||||
|
var identityMapPaths = map[string]bool{
|
||||||
|
"models": true,
|
||||||
|
"groups": true,
|
||||||
|
"profiles": true,
|
||||||
|
"peers": true,
|
||||||
|
"matrix": true,
|
||||||
|
"routing.router.settings.groups": true,
|
||||||
|
"routing.router.settings.matrix": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfigSources loads and merges configuration from -config (optional)
|
||||||
|
// and -config-dir (optional). At least one must be provided. The -config file
|
||||||
|
// is loaded first; *.yml/*.yaml files directly under -config-dir are then
|
||||||
|
// merged in sorted filename order. The merged document is passed through the
|
||||||
|
// existing LoadConfigFromReader pipeline unchanged.
|
||||||
|
func LoadConfigSources(configPath, configDir string) (Config, error) {
|
||||||
|
if configPath == "" && configDir == "" {
|
||||||
|
return Config{}, fmt.Errorf("at least one of -config or -config-dir must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
var sourcePaths []string
|
||||||
|
|
||||||
|
if configPath != "" {
|
||||||
|
sourcePaths = append(sourcePaths, configPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if configDir != "" {
|
||||||
|
dirFiles, err := listYAMLFiles(configDir)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("-config-dir %s: %w", configDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if configPath != "" {
|
||||||
|
absConfig, err := filepath.Abs(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to resolve -config path: %w", err)
|
||||||
|
}
|
||||||
|
for _, f := range dirFiles {
|
||||||
|
absF, err := filepath.Abs(f)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to resolve config dir file %s: %w", f, err)
|
||||||
|
}
|
||||||
|
if absConfig == absF {
|
||||||
|
return Config{}, fmt.Errorf("-config path %s is also present in -config-dir %s; remove it from one", configPath, configDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sourcePaths = append(sourcePaths, dirFiles...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sourcePaths) == 0 {
|
||||||
|
return Config{}, fmt.Errorf("no configuration sources found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var merged *yaml.Node
|
||||||
|
for _, p := range sourcePaths {
|
||||||
|
node, err := parseSource(p)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
if node == nil {
|
||||||
|
continue // empty file
|
||||||
|
}
|
||||||
|
if merged == nil {
|
||||||
|
merged = node
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := mergeNodes(merged, node, "", p); err != nil {
|
||||||
|
return Config{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if merged == nil {
|
||||||
|
// All sources were empty; run the pipeline on empty input so defaults
|
||||||
|
// and validation still apply (e.g. startPort, performance defaults).
|
||||||
|
return LoadConfigFromReader(strings.NewReader(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := yaml.Marshal(merged)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to marshal merged config: %w", err)
|
||||||
|
}
|
||||||
|
return LoadConfigFromReader(strings.NewReader(string(out)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// listYAMLFiles returns the top-level *.yml and *.yaml files in dir, sorted by
|
||||||
|
// filename for deterministic merge order. Subdirectories are not traversed.
|
||||||
|
func listYAMLFiles(dir string) ([]string, error) {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var files []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := e.Name()
|
||||||
|
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
files = append(files, filepath.Join(dir, name))
|
||||||
|
}
|
||||||
|
sort.Strings(files)
|
||||||
|
return files, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSource reads and parses one YAML config file into a root mapping node.
|
||||||
|
// Returns a nil node (no error) when the file is empty or contains only
|
||||||
|
// comments.
|
||||||
|
//
|
||||||
|
// Env macros (${env.VAR}) are substituted at the string level before YAML
|
||||||
|
// parsing so that flow-style constructs like [${env.API_KEY}] parse
|
||||||
|
// correctly — the brace would otherwise be interpreted as a flow mapping.
|
||||||
|
func parseSource(path string) (*yaml.Node, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
yamlStr, err := substituteEnvMacros(string(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
var doc yaml.Node
|
||||||
|
if err := yaml.Unmarshal([]byte(yamlStr), &doc); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
// yaml.Unmarshal into a yaml.Node yields a DocumentNode whose Content[0]
|
||||||
|
// is the actual root. Unwrap it so callers see the real top-level node.
|
||||||
|
root := &doc
|
||||||
|
if root.Kind == yaml.DocumentNode && len(root.Content) > 0 {
|
||||||
|
root = root.Content[0]
|
||||||
|
}
|
||||||
|
if root.Kind == 0 || root.Content == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if root.Kind != yaml.MappingNode {
|
||||||
|
return nil, fmt.Errorf("config %s: top-level YAML must be a mapping", path)
|
||||||
|
}
|
||||||
|
return root, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeNodes merges src into dst (both MappingNodes) in place. Keys present in
|
||||||
|
// only one side are kept; shared keys are merged recursively under the rules
|
||||||
|
// in mergeValue. srcPath is included in error messages to identify the file
|
||||||
|
// that introduced the conflict.
|
||||||
|
func mergeNodes(dst, src *yaml.Node, path, srcPath string) error {
|
||||||
|
srcIdx := indexMapping(src)
|
||||||
|
|
||||||
|
// First pass: merge shared keys in place.
|
||||||
|
for i := 0; i+1 < len(dst.Content); i += 2 {
|
||||||
|
keyNode := dst.Content[i]
|
||||||
|
dstVal := dst.Content[i+1]
|
||||||
|
key := keyNode.Value
|
||||||
|
|
||||||
|
srcVal, ok := srcIdx[key]
|
||||||
|
if !ok {
|
||||||
|
continue // dst-only key, keep as-is
|
||||||
|
}
|
||||||
|
|
||||||
|
childPath := joinPath(path, key)
|
||||||
|
|
||||||
|
if identityMapPaths[childPath] {
|
||||||
|
// Identity-keyed map: each child key names a discrete entity
|
||||||
|
// (a model, group, peer, ...). A shared child key is a hard
|
||||||
|
// error; src-only children are appended in the second pass.
|
||||||
|
if err := mergeIdentityMap(dstVal, srcVal, childPath, key, srcPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mergeValue(dstVal, srcVal, childPath, srcPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: append src-only keys.
|
||||||
|
dstIdx := indexMapping(dst)
|
||||||
|
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||||
|
keyNode := src.Content[i]
|
||||||
|
srcVal := src.Content[i+1]
|
||||||
|
key := keyNode.Value
|
||||||
|
|
||||||
|
if _, ok := dstIdx[key]; ok {
|
||||||
|
continue // already merged above
|
||||||
|
}
|
||||||
|
keyCopy := *keyNode
|
||||||
|
valCopy := *srcVal
|
||||||
|
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeIdentityMap merges two identity-keyed mapping nodes (e.g. `models`,
|
||||||
|
// `groups`, `peers`). Any child key present in both sides is a duplicate
|
||||||
|
// entity and produces an error naming the conflicting key and source file.
|
||||||
|
// src-only keys are appended to dst.
|
||||||
|
func mergeIdentityMap(dst, src *yaml.Node, path, mapName, srcPath string) error {
|
||||||
|
if dst.Kind != yaml.MappingNode || src.Kind != yaml.MappingNode {
|
||||||
|
return fmt.Errorf("conflict at %q: expected a mapping, introduced by %s", path, srcPath)
|
||||||
|
}
|
||||||
|
dstIdx := indexMapping(dst)
|
||||||
|
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||||
|
keyNode := src.Content[i]
|
||||||
|
srcVal := src.Content[i+1]
|
||||||
|
key := keyNode.Value
|
||||||
|
if _, dup := dstIdx[key]; dup {
|
||||||
|
return fmt.Errorf("duplicate %s %q found in %s (already defined in another config source)", mapName, key, srcPath)
|
||||||
|
}
|
||||||
|
keyCopy := *keyNode
|
||||||
|
valCopy := *srcVal
|
||||||
|
dst.Content = append(dst.Content, &keyCopy, &valCopy)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeValue merges srcVal into dstVal (both pointing into the parent's
|
||||||
|
// Content slice). Mapping+Mapping recurses; Sequence+Sequence concatenates;
|
||||||
|
// Scalar+Scalar errors on value mismatch; null on either side yields to the
|
||||||
|
// non-null side.
|
||||||
|
func mergeValue(dstVal, srcVal *yaml.Node, path, srcPath string) error {
|
||||||
|
switch {
|
||||||
|
case dstVal.Kind == yaml.MappingNode && srcVal.Kind == yaml.MappingNode:
|
||||||
|
return mergeNodes(dstVal, srcVal, path, srcPath)
|
||||||
|
|
||||||
|
case dstVal.Kind == yaml.SequenceNode && srcVal.Kind == yaml.SequenceNode:
|
||||||
|
dstVal.Content = append(dstVal.Content, srcVal.Content...)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case dstVal.Kind == yaml.ScalarNode && srcVal.Kind == yaml.ScalarNode:
|
||||||
|
if isNullScalar(dstVal) {
|
||||||
|
*dstVal = *srcVal
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if isNullScalar(srcVal) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if dstVal.Value == srcVal.Value && dstVal.Tag == srcVal.Tag {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("conflict at %q: %s sets a different value than a previous source", path, srcPath)
|
||||||
|
|
||||||
|
case isNull(dstVal):
|
||||||
|
*dstVal = *srcVal
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case isNull(srcVal):
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("conflict at %q: incompatible YAML node kinds (kind %d vs %d) introduced by %s", path, dstVal.Kind, srcVal.Kind, srcPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNull reports whether n represents a YAML null (empty or !!null).
|
||||||
|
func isNull(n *yaml.Node) bool {
|
||||||
|
if n == nil || n.Kind == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return isNullScalar(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNullScalar(n *yaml.Node) bool {
|
||||||
|
return n.Kind == yaml.ScalarNode && (n.Tag == "!!null" || n.Tag == "") && n.Value == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// indexMapping builds a key -> value-node index for a mapping node.
|
||||||
|
func indexMapping(n *yaml.Node) map[string]*yaml.Node {
|
||||||
|
idx := make(map[string]*yaml.Node, len(n.Content)/2)
|
||||||
|
for i := 0; i+1 < len(n.Content); i += 2 {
|
||||||
|
idx[n.Content[i].Value] = n.Content[i+1]
|
||||||
|
}
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinPath(parent, key string) string {
|
||||||
|
if parent == "" {
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
return parent + "." + key
|
||||||
|
}
|
||||||
@@ -0,0 +1,304 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeYAML writes content to a file named name inside dir. Returns the full
|
||||||
|
// path of the written file.
|
||||||
|
func writeYAML(t *testing.T, dir, name, content string) string {
|
||||||
|
t.Helper()
|
||||||
|
p := filepath.Join(dir, name)
|
||||||
|
require.NoError(t, os.MkdirAll(filepath.Dir(p), 0o755))
|
||||||
|
require.NoError(t, os.WriteFile(p, []byte(content), 0o644))
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelCfg builds a single-model YAML snippet indented for nesting under a
|
||||||
|
// `models:` key. The proxy uses a fixed port so tests don't depend on
|
||||||
|
// ${PORT} allocation.
|
||||||
|
func modelCfg(id, cmd string) string {
|
||||||
|
return " " + id + ":\n cmd: " + cmd + "\n proxy: \"http://localhost:9999\"\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_NeitherProvided(t *testing.T) {
|
||||||
|
_, err := LoadConfigSources("", "")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "at least one of -config or -config-dir")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ConfigOnly(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "config.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("model1", "echo hi")+`
|
||||||
|
groups:
|
||||||
|
group1:
|
||||||
|
members: ["model1"]
|
||||||
|
`)
|
||||||
|
cfg, err := LoadConfigSources(cfgPath, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, id, ok := cfg.FindConfig("model1")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "model1", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DirOnly(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("alpha", "echo a"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("beta", "echo b"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, want := range []string{"alpha", "beta"} {
|
||||||
|
_, _, ok := cfg.FindConfig(want)
|
||||||
|
assert.True(t, ok, "model %s should be present", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ConfigPlusDirAdditive(t *testing.T) {
|
||||||
|
// -config lives outside -config-dir; both contribute models additively.
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "config.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||||
|
cfgDir := t.TempDir()
|
||||||
|
writeYAML(t, cfgDir, "extra.yaml", "models:\n"+modelCfg("ext", "echo ext"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, want := range []string{"base", "ext"} {
|
||||||
|
_, _, ok := cfg.FindConfig(want)
|
||||||
|
assert.True(t, ok, "model %s should be present after merge", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadConfigSources_ConfigInDirOverlap verifies that a -config file that
|
||||||
|
// is also a member of -config-dir is rejected.
|
||||||
|
func TestLoadConfigSources_ConfigInDirOverlap(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("base", "echo base"))
|
||||||
|
|
||||||
|
_, err := LoadConfigSources(cfgPath, dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "is also present in -config-dir")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DuplicateModelID(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("dup", "echo a"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("dup", "echo b"))
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `duplicate models "dup"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DuplicateGroupID(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m1", "echo m1")+"groups:\n g1:\n members: [m1]\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m2", "echo m2")+"groups:\n g1:\n members: [m2]\n")
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `duplicate groups "g1"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_DuplicatePeer(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
peerA := "peers:\n remote:\n proxy: http://x:1\n models: [m1]\n"
|
||||||
|
peerB := "peers:\n remote:\n proxy: http://x:2\n models: [m2]\n"
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\n"+peerA)
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\n"+peerB)
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `duplicate peers "remote"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ScalarConflict(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 200\n")
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), `conflict at "globalTTL"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_ScalarSameValueNoConflict(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 100\n")
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 100, cfg.GlobalTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_MacrosConcatenate(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "macros:\n LOW: 1\nmodels:\n"+modelCfg("m1", "echo ${LOW}"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "macros:\n HIGH: 2\nmodels:\n"+modelCfg("m2", "echo ${HIGH}"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Both macros are available globally after merge.
|
||||||
|
low, ok := cfg.Macros.Get("LOW")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, 1, low)
|
||||||
|
high, ok := cfg.Macros.Get("HIGH")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, 2, high)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_APIKeysConcatenate(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\napiKeys: [key-a]\n")
|
||||||
|
writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\napiKeys: [key-b]\n")
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.ElementsMatch(t, []string{"key-a", "key-b"}, cfg.RequiredAPIKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_RoutingGroupsMerge(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m1", "echo m1")+`
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
settings:
|
||||||
|
groups:
|
||||||
|
groupA:
|
||||||
|
members: [m1]
|
||||||
|
`)
|
||||||
|
writeYAML(t, dir, "b.yaml", `
|
||||||
|
models:
|
||||||
|
`+modelCfg("m2", "echo m2")+`
|
||||||
|
routing:
|
||||||
|
router:
|
||||||
|
settings:
|
||||||
|
groups:
|
||||||
|
groupB:
|
||||||
|
members: [m2]
|
||||||
|
`)
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
groups := cfg.Routing.Router.Settings.Groups
|
||||||
|
assert.Contains(t, groups, "groupA")
|
||||||
|
assert.Contains(t, groups, "groupB")
|
||||||
|
// default group added by pipeline for orphaned/leftover routing groups...
|
||||||
|
// here both groups reference distinct models
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EnvMacrosSubstituted(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
// Use ${PORT} in cmd so the pipeline allocates a port and substitutes it;
|
||||||
|
// verifies env/macro substitution runs on the merged document.
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: serve --port ${PORT}\n proxy: \"http://localhost:${PORT}\"\n")
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m := cfg.Models["m1"]
|
||||||
|
assert.NotContains(t, m.Cmd, "${PORT}", "PORT macro should have been substituted")
|
||||||
|
assert.NotContains(t, m.Proxy, "${PORT}", "PORT macro should have been substituted in proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EnvMacroInFlowStyleList(t *testing.T) {
|
||||||
|
// Regression: flow-style lists with ${env.*} must parse. Previously
|
||||||
|
// parseSource unmarshalled before env substitution, so the brace in
|
||||||
|
// [${env.API_KEY}] was misread as a flow mapping and parsing failed.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: echo hi\n proxy: \"http://localhost:9999\"\n")
|
||||||
|
writeYAML(t, dir, "keys.yaml", "apiKeys: [${env.TEST_API_KEY}]\nmodels:\n m2:\n cmd: echo hi\n proxy: \"http://localhost:9998\"\n")
|
||||||
|
|
||||||
|
t.Setenv("TEST_API_KEY", "secret123")
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, cfg.RequiredAPIKeys, "secret123")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_SortedOrderDeterministic(t *testing.T) {
|
||||||
|
// Two files defining distinct models, scanned in z..a order by filename.
|
||||||
|
// Determine merged result is the same regardless of how the FS returns them.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "z.yaml", "models:\n"+modelCfg("zmodel", "echo z"))
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("amodel", "echo a"))
|
||||||
|
|
||||||
|
const runs = 3
|
||||||
|
for i := 0; i < runs; i++ {
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// startPort-based allocation: first allocated model gets 5800.
|
||||||
|
// Sorted order means amodel gets 5800, zmodel gets 5801.
|
||||||
|
_, _, ok := cfg.FindConfig("amodel")
|
||||||
|
assert.True(t, ok)
|
||||||
|
_, _, ok = cfg.FindConfig("zmodel")
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EmptyDirWithConfig(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
cfgDir := t.TempDir()
|
||||||
|
cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources(cfgPath, cfgDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, cfg.Models, "m1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_EmptyDirOnly(t *testing.T) {
|
||||||
|
// An empty -config-dir with no -config is an error: there is nothing to
|
||||||
|
// load and silently producing an empty config would mask the misconfig.
|
||||||
|
cfgDir := t.TempDir()
|
||||||
|
_, err := LoadConfigSources("", cfgDir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no configuration sources found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_AssertNoUnknownMacrosAfterMerge(t *testing.T) {
|
||||||
|
// Macros defined in one file should not satisfy unknown-macro validation in
|
||||||
|
// another — they do, because merge concats global macros before validation
|
||||||
|
// runs. This test documents that a macro from file A is usable in file B.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "macros.yaml", "macros:\n SHARED: hello\nmodels:\n"+modelCfg("dummy", "echo dummy"))
|
||||||
|
writeYAML(t, dir, "use.yaml", "models:\n"+modelCfg("user", "echo ${SHARED}"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
m := cfg.Models["user"]
|
||||||
|
assert.Contains(t, m.Cmd, "hello")
|
||||||
|
assert.NotContains(t, m.Cmd, "${SHARED}")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_KindMismatchErrors(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "startPort: 5800\nmodels:\n"+modelCfg("m1", "echo m1"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "startPort: [5800, 5801]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||||
|
|
||||||
|
_, err := LoadConfigSources("", dir)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "incompatible YAML node kinds")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigSources_NullYieldsToValue(t *testing.T) {
|
||||||
|
// File A: routing.router block absent (null on root for routing);
|
||||||
|
// file B: defines routing.router.settings.groups. Merge should keep B's.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1"))
|
||||||
|
writeYAML(t, dir, "b.yaml", "routing:\n router:\n settings:\n groups:\n g1:\n members: [m1]\nmodels:\n"+modelCfg("m2", "echo m2"))
|
||||||
|
|
||||||
|
cfg, err := LoadConfigSources("", dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1")
|
||||||
|
}
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DirWatcher polls a directory for changes to its set of *.yml / *.yaml files.
|
||||||
|
// It fires OnChange when a file is added, removed, or has its mod time/size
|
||||||
|
// change. Like Watcher it is poll-based so it works in Docker bind-mounts and
|
||||||
|
// k8s ConfigMap projections where inotify is unreliable.
|
||||||
|
//
|
||||||
|
// The baseline poll establishes initial state and does not fire OnChange.
|
||||||
|
type DirWatcher struct {
|
||||||
|
Path string
|
||||||
|
Interval time.Duration
|
||||||
|
OnChange func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// dirSnapshot is an ordered map of file name -> file state. The ordering is
|
||||||
|
// derived from sorted filenames so two snapshots compare deterministically
|
||||||
|
// regardless of readdir order. exists reflects whether the directory was
|
||||||
|
// readable at scan time; a missing directory yields exists=false.
|
||||||
|
type dirSnapshot struct {
|
||||||
|
exists bool
|
||||||
|
names []string
|
||||||
|
states map[string]snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDirSnapshot() dirSnapshot {
|
||||||
|
return dirSnapshot{states: make(map[string]snapshot)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equal reports whether two snapshots describe the same file set and per-file
|
||||||
|
// state. A missing directory (exists=false) is treated as equal to any other
|
||||||
|
// missing directory regardless of cached names.
|
||||||
|
func (s dirSnapshot) equal(other dirSnapshot) bool {
|
||||||
|
if !s.exists && !other.exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if s.exists != other.exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(s.names) != len(other.names) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, n := range s.names {
|
||||||
|
if other.names[i] != n {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, n := range s.names {
|
||||||
|
a, b := s.states[n], other.states[n]
|
||||||
|
if a.exists != b.exists || a.size != b.size || !a.modTime.Equal(b.modTime) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run blocks until ctx is canceled. It polls Path on Interval and invokes
|
||||||
|
// OnChange whenever the directory's YAML file set changes.
|
||||||
|
//
|
||||||
|
// Policy mirrors the single-file Watcher: disappearance (directory missing or
|
||||||
|
// empty) is treated as a transient rename-style write and stays quiet; the
|
||||||
|
// transition back to present-with-content fires OnChange.
|
||||||
|
func (w *DirWatcher) Run(ctx context.Context) {
|
||||||
|
interval := w.Interval
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = DefaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
prev := scanDir(w.Path)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
cur := scanDir(w.Path)
|
||||||
|
// Suppress transitions involving an empty or missing directory —
|
||||||
|
// these are treated as transient rename-style writes, mirroring
|
||||||
|
// the single-file Watcher. Only present-with-content →
|
||||||
|
// present-with-content (changed) or no-content →
|
||||||
|
// present-with-content fires OnChange.
|
||||||
|
prevHasContent := prev.exists && len(prev.names) > 0
|
||||||
|
curHasContent := cur.exists && len(cur.names) > 0
|
||||||
|
if curHasContent && (!prevHasContent || !prev.equal(cur)) && w.OnChange != nil {
|
||||||
|
w.OnChange()
|
||||||
|
}
|
||||||
|
prev = cur
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanDir returns a snapshot of the *.yml/*.yaml files in dir. If the
|
||||||
|
// directory cannot be read (missing, permission denied) the snapshot reports
|
||||||
|
// exists=false; the next successful scan will detect the recovery and fire
|
||||||
|
// OnChange.
|
||||||
|
func scanDir(dir string) dirSnapshot {
|
||||||
|
snap := newDirSnapshot()
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return snap // exists=false
|
||||||
|
}
|
||||||
|
snap.exists = true
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := e.Name()
|
||||||
|
if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fi, err := os.Stat(filepath.Join(dir, name))
|
||||||
|
if err != nil {
|
||||||
|
// File disappeared between ReadDir and Stat; skip it — the
|
||||||
|
// next poll will observe the removal cleanly.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snap.names = append(snap.names, name)
|
||||||
|
snap.states[name] = snapshot{
|
||||||
|
exists: true,
|
||||||
|
modTime: fi.ModTime(),
|
||||||
|
size: fi.Size(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(snap.names)
|
||||||
|
return snap
|
||||||
|
}
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
package configwatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startDirWatcher launches w.Run in a goroutine and returns a function that
|
||||||
|
// cancels the context and waits for Run to return.
|
||||||
|
func startDirWatcher(t *testing.T, w *DirWatcher) func() {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
w.Run(ctx)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("DirWatcher did not stop within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeYAMLInDir(t *testing.T, dir, name, content string) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_NoFireOnBaseline(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 5)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_DetectsFileAdd(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is added")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_DetectsFileRemoval(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
writeYAMLInDir(t, dir, "b.yaml", "b")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.Remove(filepath.Join(dir, "b.yaml")))
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_DetectsModTimeChange(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
base := time.Now().Add(-1 * time.Hour).Truncate(time.Second)
|
||||||
|
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base, base))
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base.Add(10*time.Second), base.Add(10*time.Second)))
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_IgnoresNonYAMLFiles(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Adding a .txt file must not fire.
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0o644))
|
||||||
|
time.Sleep(testInterval * 4)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "non-YAML files must be ignored")
|
||||||
|
|
||||||
|
// Adding a .yml file must fire.
|
||||||
|
writeYAMLInDir(t, dir, "b.yml", "b")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire for *.yml files")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_MissingDirRecovers(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Remove the directory. No fire expected on disappearance alone.
|
||||||
|
require.NoError(t, os.RemoveAll(dir))
|
||||||
|
time.Sleep(testInterval * 3)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "directory removal alone must not fire")
|
||||||
|
|
||||||
|
// Recreate the directory and a YAML file; the recovery should fire.
|
||||||
|
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||||
|
writeYAMLInDir(t, dir, "recovered.yaml", "r")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when dir returns with content")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_EmptyDirSuppressedThenRecovers(t *testing.T) {
|
||||||
|
// Present-with-content → empty (all YAML removed, dir still exists)
|
||||||
|
// must stay quiet — treated as transient per the documented policy.
|
||||||
|
// The transition back to content fires.
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
var n int64
|
||||||
|
stop := startDirWatcher(t, &DirWatcher{
|
||||||
|
Path: dir,
|
||||||
|
Interval: testInterval,
|
||||||
|
OnChange: func() { atomic.AddInt64(&n, 1) },
|
||||||
|
})
|
||||||
|
defer stop()
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
|
||||||
|
// Remove the only YAML file. Dir still exists but is empty of YAML.
|
||||||
|
require.NoError(t, os.Remove(filepath.Join(dir, "a.yaml")))
|
||||||
|
time.Sleep(testInterval * 4)
|
||||||
|
require.Equal(t, int64(0), atomic.LoadInt64(&n), "emptying the directory must not fire")
|
||||||
|
|
||||||
|
// Add a YAML file back; transition to present-with-content fires.
|
||||||
|
writeYAMLInDir(t, dir, "c.yaml", "c")
|
||||||
|
require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when content returns")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirWatcher_ContextCancelStopsRun(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
writeYAMLInDir(t, dir, "a.yaml", "a")
|
||||||
|
|
||||||
|
w := &DirWatcher{Path: dir, Interval: testInterval}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() { w.Run(ctx); close(done) }()
|
||||||
|
|
||||||
|
time.Sleep(testInterval * 2)
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Run did not return within 2s of cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
+37
-19
@@ -55,7 +55,8 @@ var logTimeFormats = map[string]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flagConfig := flag.String("config", "", "path to config file (required)")
|
flagConfig := flag.String("config", "", "path to config file")
|
||||||
|
flagConfigDir := flag.String("config-dir", "", "directory of *.yml/*.yaml config files (additive to -config)")
|
||||||
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
|
flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)")
|
||||||
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file")
|
||||||
flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
|
flagKeyFile := flag.String("tls-key-file", "", "TLS key file")
|
||||||
@@ -68,8 +69,8 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
if *flagConfig == "" {
|
if *flagConfig == "" && *flagConfigDir == "" {
|
||||||
slog.Error("-config is required")
|
slog.Error("at least one of -config or -config-dir must be provided")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,10 +89,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
configPath := *flagConfig
|
cfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||||
cfg, err := config.LoadConfig(configPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to load config", "path", configPath, "error", err)
|
slog.Error("failed to load config", "config", *flagConfig, "config-dir", *flagConfigDir, "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,7 +187,7 @@ func main() {
|
|||||||
|
|
||||||
proxyLog.Info("reloading configuration")
|
proxyLog.Info("reloading configuration")
|
||||||
|
|
||||||
newCfg, err := config.LoadConfig(configPath)
|
newCfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
proxyLog.Warnf("failed to reload config: %v", err)
|
proxyLog.Warnf("failed to reload config: %v", err)
|
||||||
return
|
return
|
||||||
@@ -230,19 +230,37 @@ func main() {
|
|||||||
defer watcherCancel()
|
defer watcherCancel()
|
||||||
|
|
||||||
if *flagWatchConfig {
|
if *flagWatchConfig {
|
||||||
absConfigPath, err := filepath.Abs(configPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("watch-config: failed to resolve config path", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
proxyLog.Info("watching configuration for changes (poll-based, 2s interval)")
|
||||||
go func() {
|
|
||||||
(&configwatcher.Watcher{
|
if *flagConfig != "" {
|
||||||
Path: absConfigPath,
|
absConfigPath, err := filepath.Abs(*flagConfig)
|
||||||
Interval: configwatcher.DefaultInterval,
|
if err != nil {
|
||||||
OnChange: reload,
|
slog.Error("watch-config: failed to resolve config path", "error", err)
|
||||||
}).Run(watcherCtx)
|
os.Exit(1)
|
||||||
}()
|
}
|
||||||
|
go func() {
|
||||||
|
(&configwatcher.Watcher{
|
||||||
|
Path: absConfigPath,
|
||||||
|
Interval: configwatcher.DefaultInterval,
|
||||||
|
OnChange: reload,
|
||||||
|
}).Run(watcherCtx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if *flagConfigDir != "" {
|
||||||
|
absConfigDir, err := filepath.Abs(*flagConfigDir)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("watch-config: failed to resolve config-dir path", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
(&configwatcher.DirWatcher{
|
||||||
|
Path: absConfigDir,
|
||||||
|
Interval: configwatcher.DefaultInterval,
|
||||||
|
OnChange: reload,
|
||||||
|
}).Run(watcherCtx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user