From 32bc7813261000228ea5274d89a8c89e7cce2725 Mon Sep 17 00:00:00 2001 From: Benson Wong <83972+mostlygeek@users.noreply.github.com> Date: Wed, 24 Jun 2026 20:48:51 -0700 Subject: [PATCH] internal/config,watcher: add -config-dir (#873) Over time the llama-swap configuration file can get really long and challenging to work with. The -config-dir flag is used for a directory of configuration YAML fragments. These fragments are merged together and into a full configuration and tested for validity. All previous configuration functionality remains unchanged. --- internal/config/commands.go | 57 +++ internal/config/config.go | 667 ---------------------------- internal/config/config_test.go | 13 +- internal/config/load.go | 436 ++++++++++++++++++ internal/config/macros.go | 198 +++++++++ internal/config/merge.go | 300 +++++++++++++ internal/config/merge_test.go | 304 +++++++++++++ internal/watcher/dirwatcher.go | 137 ++++++ internal/watcher/dirwatcher_test.go | 199 +++++++++ llama-swap.go | 56 ++- 10 files changed, 1677 insertions(+), 690 deletions(-) create mode 100644 internal/config/commands.go create mode 100644 internal/config/load.go create mode 100644 internal/config/macros.go create mode 100644 internal/config/merge.go create mode 100644 internal/config/merge_test.go create mode 100644 internal/watcher/dirwatcher.go create mode 100644 internal/watcher/dirwatcher_test.go diff --git a/internal/config/commands.go b/internal/config/commands.go new file mode 100644 index 00000000..066228b7 --- /dev/null +++ b/internal/config/commands.go @@ -0,0 +1,57 @@ +package config + +import ( + "fmt" + "runtime" + "strings" + + "github.com/billziss-gh/golib/shlex" +) + +func SanitizeCommand(cmdStr string) ([]string, error) { + var cleanedLines []string + for _, line := range strings.Split(cmdStr, "\n") { + trimmed := strings.TrimSpace(line) + // Skip comment lines + if strings.HasPrefix(trimmed, "#") { + continue + } + // Handle trailing backslashes by replacing with space + if strings.HasSuffix(trimmed, "\\") { + cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ") + } else { + cleanedLines = append(cleanedLines, line) + } + } + + // put it back together + cmdStr = strings.Join(cleanedLines, "\n") + + // Split the command into arguments + var args []string + if runtime.GOOS == "windows" { + args = shlex.Windows.Split(cmdStr) + } else { + args = shlex.Posix.Split(cmdStr) + } + + // Ensure the command is not empty + if len(args) == 0 { + return nil, fmt.Errorf("empty command") + } + + return args, nil +} + +func StripComments(cmdStr string) string { + var cleanedLines []string + for _, line := range strings.Split(cmdStr, "\n") { + trimmed := strings.TrimSpace(line) + // Skip comment lines + if strings.HasPrefix(trimmed, "#") { + continue + } + cleanedLines = append(cleanedLines, line) + } + return strings.Join(cleanedLines, "\n") +} diff --git a/internal/config/config.go b/internal/config/config.go index 3d34b060..090e3512 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,16 +2,9 @@ package config import ( "fmt" - "io" - "net/url" "os" - "regexp" - "runtime" "sort" - "strings" - "time" - "github.com/billziss-gh/golib/shlex" "gopkg.in/yaml.v3" ) @@ -85,12 +78,6 @@ type GroupConfig struct { Members []string `yaml:"members"` } -var ( - macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) - macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`) - envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`) -) - // set default values for GroupConfig func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { type rawGroupConfig GroupConfig @@ -224,430 +211,6 @@ func LoadConfig(path string) (Config, error) { return LoadConfigFromReader(file) } -func LoadConfigFromReader(r io.Reader) (Config, error) { - data, err := io.ReadAll(r) - if err != nil { - return Config{}, err - } - yamlStr := string(data) - - // Phase 1: Substitute all ${env.VAR} macros at string level - // This is safe because env values are simple strings without YAML formatting - yamlStr, err = substituteEnvMacros(yamlStr) - if err != nil { - return Config{}, err - } - - // Unmarshal into full Config with defaults - config := Config{ - HealthCheckTimeout: 120, - StartPort: 5800, - LogLevel: "info", - LogTimeFormat: "", - LogToStdout: LogToStdoutProxy, - MetricsMaxInMemory: 1000, - CaptureBuffer: 5, - GlobalTTL: 0, - } - if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil { - return Config{}, err - } - - if config.HealthCheckTimeout < 15 { - config.HealthCheckTimeout = 15 - } - - // Apply defaults for performance config when section is missing - if config.Performance.Every == 0 { - config.Performance.Every = 5 * time.Second - } - if err = config.Performance.Validate(); err != nil { - return Config{}, fmt.Errorf("performance: %w", err) - } - - if config.StartPort < 1 { - return Config{}, fmt.Errorf("startPort must be greater than 1") - } - - if config.GlobalTTL < 0 { - return Config{}, fmt.Errorf("globalTTL must be >= 0") - } - - // Apply default for upstream.ignorePaths when not specified. The default - // matches common static-asset suffixes so they do not trigger a swap. - if len(config.Upstream.IgnorePaths) == 0 { - config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths() - } - - switch config.LogToStdout { - case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone: - default: - return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none") - } - - // Populate the aliases map - config.aliases = make(map[string]string) - for modelName, modelConfig := range config.Models { - for _, alias := range modelConfig.Aliases { - if _, found := config.aliases[alias]; found { - return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName) - } - config.aliases[alias] = modelName - } - } - - // Validate global macros - for _, macro := range config.Macros { - if err = validateMacro(macro.Name, macro.Value); err != nil { - return Config{}, err - } - } - - // Get and sort all model IDs for consistent port assignment - modelIds := make([]string, 0, len(config.Models)) - for modelId := range config.Models { - modelIds = append(modelIds, modelId) - } - sort.Strings(modelIds) - - nextPort := config.StartPort - for _, modelId := range modelIds { - modelConfig := config.Models[modelId] - modelConfig.HealthCheckTimeout = config.HealthCheckTimeout - - // Strip comments from command fields - modelConfig.Cmd = StripComments(modelConfig.Cmd) - modelConfig.CmdStop = StripComments(modelConfig.CmdStop) - - // set model TTL to globalTTL it is the default value - if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL { - modelConfig.UnloadAfter = config.GlobalTTL - } - - if modelConfig.UnloadAfter < 0 { - return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter) - } - - // Validate model macros - for _, macro := range modelConfig.Macros { - if err = validateMacro(macro.Name, macro.Value); err != nil { - return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error()) - } - } - - // Build merged macro list: MODEL_ID + global macros + model macros (model overrides global) - mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1) - mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId}) - mergedMacros = append(mergedMacros, config.Macros...) - - // Add model macros (override globals with same name) - for _, entry := range modelConfig.Macros { - found := false - for i, existing := range mergedMacros { - if existing.Name == entry.Name { - mergedMacros[i] = entry - found = true - break - } - } - if !found { - mergedMacros = append(mergedMacros, entry) - } - } - - // Substitute remaining macros in model fields (LIFO order) - for i := len(mergedMacros) - 1; i >= 0; i-- { - entry := mergedMacros[i] - macroSlug := fmt.Sprintf("${%s}", entry.Name) - macroStr := fmt.Sprintf("%v", entry.Value) - - modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr) - modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr) - modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr) - modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr) - modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr) - modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr) - modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr) - - // Substitute macros in SetParamsByID keys and values - if len(modelConfig.Filters.SetParamsByID) > 0 { - newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID)) - for key, paramMap := range modelConfig.Filters.SetParamsByID { - newKey := strings.ReplaceAll(key, macroSlug, macroStr) - newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value) - if err != nil { - return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error()) - } - newParamMap, ok := newValAny.(map[string]any) - if !ok { - return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId) - } - newSetParamsByID[newKey] = newParamMap - } - modelConfig.Filters.SetParamsByID = newSetParamsByID - } - - // Substitute in metadata (type-preserving) - if len(modelConfig.Metadata) > 0 { - result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value) - if err != nil { - return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error()) - } - modelConfig.Metadata = result.(map[string]any) - } - } - - // Handle PORT macro - only allocate if cmd uses it - cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}") - proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}") - if cmdHasPort || proxyHasPort { - if !cmdHasPort && proxyHasPort { - return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId) - } - - macroSlug := "${PORT}" - macroStr := fmt.Sprintf("%v", nextPort) - - modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr) - modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr) - modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr) - modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr) - modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr) - - if len(modelConfig.Metadata) > 0 { - result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort) - if err != nil { - return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error()) - } - modelConfig.Metadata = result.(map[string]any) - } - - nextPort++ - } - - // Validate no unknown macros remain - fieldMap := map[string]string{ - "cmd": modelConfig.Cmd, - "cmdStop": modelConfig.CmdStop, - "proxy": modelConfig.Proxy, - "checkEndpoint": modelConfig.CheckEndpoint, - "filters.stripParams": modelConfig.Filters.StripParams, - "name": modelConfig.Name, - "description": modelConfig.Description, - } - - for fieldName, fieldValue := range fieldMap { - matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1) - for _, match := range matches { - macroName := match[1] - if macroName == "PID" && fieldName == "cmdStop" { - continue // replaced at runtime - } - if macroName == "PORT" || macroName == "MODEL_ID" { - return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName) - } - return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName) - } - } - - if len(modelConfig.Metadata) > 0 { - if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil { - return Config{}, err - } - } - - if err = modelConfig.Capabilities.Validate(); err != nil { - return Config{}, fmt.Errorf("model %s: %w", modelId, err) - } - - // Validate SetParamsByID keys and values - for key, paramMap := range modelConfig.Filters.SetParamsByID { - if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 { - return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId) - } - if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil { - return Config{}, err - } - } - - // Auto-register setParamsByID keys as aliases (skip the model's own ID) - for key := range modelConfig.Filters.SetParamsByID { - if key == modelId { - continue - } - if _, exists := config.Models[key]; exists { - return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key) - } - if existingModel, exists := config.aliases[key]; exists { - if existingModel != modelId { - return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel) - } - continue // already registered as explicit alias for this model - } - config.aliases[key] = modelId - modelConfig.Aliases = append(modelConfig.Aliases, key) - } - - if _, err := url.Parse(modelConfig.Proxy); err != nil { - return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err) - } - - if modelConfig.SendLoadingState == nil { - v := config.SendLoadingState - modelConfig.SendLoadingState = &v - } - - config.Models[modelId] = modelConfig - } - - // Normalize routing config. The legacy top-level `matrix`/`groups` keys and - // the new `routing.router` block are mutually exclusive: a config may use - // either style, never both. - hasTopLevel := config.Matrix != nil || len(config.Groups) > 0 - rtr := config.Routing.Router - hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0 - - if hasTopLevel && hasRouting { - return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them") - } - - if !hasTopLevel { - // Both groups and matrix may be defined under routing.router.settings; - // routing.router.use selects which one is active, so there is no conflict. - rs := config.Routing.Router.Settings - switch config.Routing.Router.Use { - case "matrix": - if rs.Matrix == nil { - return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set") - } - config.Matrix = rs.Matrix - case "group", "": - config.Groups = rs.Groups - default: - return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use) - } - } - - // groups XOR matrix - if config.Matrix != nil && len(config.Groups) > 0 { - return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'") - } - - if config.Matrix != nil { - expandedSets, err := ValidateMatrix(*config.Matrix, config.Models) - if err != nil { - return Config{}, fmt.Errorf("matrix: %w", err) - } - config.Matrix.ExpandedSets = expandedSets - } else { - config = AddDefaultGroupToConfig(config) - - // Validate group members - memberUsage := make(map[string]string) - for groupID, groupConfig := range config.Groups { - prevSet := make(map[string]bool) - for _, member := range groupConfig.Members { - if _, found := prevSet[member]; found { - return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID) - } - prevSet[member] = true - - if existingGroup, exists := memberUsage[member]; exists { - return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID) - } - memberUsage[member] = groupID - } - } - } - - // Build the canonical Config.Routing from the effective result. Both legacy - // and new-style configs converge here. The Matrix pointer is shared so - // ExpandedSets stays in one place. - if config.Matrix != nil { - config.Routing.Router.Use = "matrix" - } else { - config.Routing.Router.Use = "group" - } - config.Routing.Router.Settings.Matrix = config.Matrix - config.Routing.Router.Settings.Groups = config.Groups - - if config.Routing.Scheduler.Use == "" { - config.Routing.Scheduler.Use = "fifo" - } - if config.Routing.Scheduler.Use != "fifo" { - return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use) - } - for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority { - if _, found := config.RealModelName(modelID); !found { - return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID) - } - } - - // Clean up hooks preload - if len(config.Hooks.OnStartup.Preload) > 0 { - var toPreload []string - for _, modelID := range config.Hooks.OnStartup.Preload { - modelID = strings.TrimSpace(modelID) - if modelID == "" { - continue - } - if real, found := config.RealModelName(modelID); found { - toPreload = append(toPreload, real) - } - } - config.Hooks.OnStartup.Preload = toPreload - } - - // Validate API keys (env macros already substituted at string level) - for i, apikey := range config.RequiredAPIKeys { - if apikey == "" { - return Config{}, fmt.Errorf("empty api key found in apiKeys") - } - if strings.Contains(apikey, " ") { - return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey) - } - config.RequiredAPIKeys[i] = apikey - } - - // Process peers with global macro substitution - for peerName, peerConfig := range config.Peers { - // Substitute global macros (LIFO order) - for i := len(config.Macros) - 1; i >= 0; i-- { - entry := config.Macros[i] - macroSlug := fmt.Sprintf("${%s}", entry.Name) - macroStr := fmt.Sprintf("%v", entry.Value) - - peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr) - peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr) - - // Substitute in setParams (type-preserving) - if len(peerConfig.Filters.SetParams) > 0 { - result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value) - if err != nil { - return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err) - } - peerConfig.Filters.SetParams = result.(map[string]any) - } - } - - // Validate no unknown macros remain - if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 { - return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1]) - } - if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 { - return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1]) - } - if len(peerConfig.Filters.SetParams) > 0 { - if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil { - return Config{}, err - } - } - config.Peers[peerName] = peerConfig - } - - return config, nil -} - // rewrites the yaml to include a default group with any orphaned models func AddDefaultGroupToConfig(config Config) Config { @@ -692,233 +255,3 @@ func AddDefaultGroupToConfig(config Config) Config { return config } - -func SanitizeCommand(cmdStr string) ([]string, error) { - var cleanedLines []string - for _, line := range strings.Split(cmdStr, "\n") { - trimmed := strings.TrimSpace(line) - // Skip comment lines - if strings.HasPrefix(trimmed, "#") { - continue - } - // Handle trailing backslashes by replacing with space - if strings.HasSuffix(trimmed, "\\") { - cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ") - } else { - cleanedLines = append(cleanedLines, line) - } - } - - // put it back together - cmdStr = strings.Join(cleanedLines, "\n") - - // Split the command into arguments - var args []string - if runtime.GOOS == "windows" { - args = shlex.Windows.Split(cmdStr) - } else { - args = shlex.Posix.Split(cmdStr) - } - - // Ensure the command is not empty - if len(args) == 0 { - return nil, fmt.Errorf("empty command") - } - - return args, nil -} - -func StripComments(cmdStr string) string { - var cleanedLines []string - for _, line := range strings.Split(cmdStr, "\n") { - trimmed := strings.TrimSpace(line) - // Skip comment lines - if strings.HasPrefix(trimmed, "#") { - continue - } - cleanedLines = append(cleanedLines, line) - } - return strings.Join(cleanedLines, "\n") -} - -// validateMacro validates macro name and value constraints -func validateMacro(name string, value any) error { - if len(name) >= 64 { - return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name) - } - if !macroNameRegex.MatchString(name) { - return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name) - } - - // Validate that value is a scalar type - switch v := value.(type) { - case string: - // Check for self-reference - macroSlug := fmt.Sprintf("${%s}", name) - if strings.Contains(v, macroSlug) { - return fmt.Errorf("macro '%s' contains self-reference", name) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: - // These types are allowed - default: - return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value) - } - - switch name { - case "PORT", "MODEL_ID": - return fmt.Errorf("macro name '%s' is reserved", name) - } - - return nil -} - -// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures -func validateNestedForUnknownMacros(value any, context string) error { - switch v := value.(type) { - case string: - matches := macroPatternRegex.FindAllStringSubmatch(v, -1) - for _, match := range matches { - macroName := match[1] - return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName) - } - // Check for unsubstituted env macros - envMatches := envMacroRegex.FindAllStringSubmatch(v, -1) - for _, match := range envMatches { - varName := match[1] - return fmt.Errorf("%s: environment variable '%s' not set", context, varName) - } - return nil - - case map[string]any: - for _, val := range v { - if err := validateNestedForUnknownMacros(val, context); err != nil { - return err - } - } - return nil - - case []any: - for _, val := range v { - if err := validateNestedForUnknownMacros(val, context); err != nil { - return err - } - } - return nil - - default: - // Scalar types don't contain macros - return nil - } -} - -// substituteMacroInValue recursively substitutes a single macro in a value structure -// This is called once per macro, allowing LIFO substitution order -func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) { - macroSlug := fmt.Sprintf("${%s}", macroName) - macroStr := fmt.Sprintf("%v", macroValue) - - switch v := value.(type) { - case string: - // Check if this is a direct macro substitution - if v == macroSlug { - return macroValue, nil - } - // Handle string interpolation - if strings.Contains(v, macroSlug) { - return strings.ReplaceAll(v, macroSlug, macroStr), nil - } - return v, nil - - case map[string]any: - // Recursively process map values - newMap := make(map[string]any) - for key, val := range v { - newVal, err := substituteMacroInValue(val, macroName, macroValue) - if err != nil { - return nil, err - } - newMap[key] = newVal - } - return newMap, nil - - case []any: - // Recursively process slice elements - newSlice := make([]any, len(v)) - for i, val := range v { - newVal, err := substituteMacroInValue(val, macroName, macroValue) - if err != nil { - return nil, err - } - newSlice[i] = newVal - } - return newSlice, nil - - default: - // Return scalar types as-is - return value, nil - } -} - -// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values. -// Returns error if any referenced env var is not set or contains invalid characters. -// Env macros inside YAML comments are ignored by unmarshalling the YAML first -// (which strips comments) and only checking the comment-free version for macros. -func substituteEnvMacros(s string) (string, error) { - // Unmarshal and remarshal to strip YAML comments - var raw any - if err := yaml.Unmarshal([]byte(s), &raw); err != nil { - // If YAML is invalid, fall back to scanning the original string - // so the user gets the env var error rather than a confusing YAML parse error - return substituteEnvMacrosInString(s, s) - } - clean, err := yaml.Marshal(raw) - if err != nil { - return substituteEnvMacrosInString(s, s) - } - - return substituteEnvMacrosInString(s, string(clean)) -} - -// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes -// them in target. This separation allows scanning comment-free YAML while -// substituting in the original string. -func substituteEnvMacrosInString(target, scanStr string) (string, error) { - result := target - matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1) - for _, match := range matches { - fullMatch := match[0] // ${env.VAR_NAME} - varName := match[1] // VAR_NAME - - value, exists := os.LookupEnv(varName) - if !exists { - return "", fmt.Errorf("environment variable '%s' is not set", varName) - } - - // Sanitize the value for safe YAML substitution - value, err := sanitizeEnvValueForYAML(value, varName) - if err != nil { - return "", err - } - - result = strings.ReplaceAll(result, fullMatch, value) - } - return result, nil -} - -// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution. -// It rejects values with characters that break YAML structure and escapes quotes/backslashes -// for compatibility with double-quoted YAML strings. -func sanitizeEnvValueForYAML(value, varName string) (string, error) { - // Reject values that would break YAML structure regardless of quoting context - if strings.ContainsAny(value, "\n\r\x00") { - return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName) - } - - // Escape backslashes and double quotes for safe use in double-quoted YAML strings. - // In unquoted contexts, these escapes appear literally (harmless for most use cases). - // In double-quoted contexts, they are interpreted correctly. - value = strings.ReplaceAll(value, `\`, `\\`) - value = strings.ReplaceAll(value, `"`, `\"`) - - return value, nil -} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index da63db73..2125d242 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -777,22 +777,27 @@ func TestConfig_APIKeys_Invalid(t *testing.T) { { name: "blank spaces only", content: `apiKeys: [" "]`, - expectedErr: "api key cannot contain spaces: ` `", + expectedErr: "apiKeys[0]: api key cannot contain spaces", }, { name: "contains leading space", content: `apiKeys: [" key123"]`, - expectedErr: "api key cannot contain spaces: ` key123`", + expectedErr: "apiKeys[0]: api key cannot contain spaces", }, { name: "contains trailing space", content: `apiKeys: ["key123 "]`, - expectedErr: "api key cannot contain spaces: `key123 `", + expectedErr: "apiKeys[0]: api key cannot contain spaces", }, { name: "contains middle space", content: `apiKeys: ["key 123"]`, - expectedErr: "api key cannot contain spaces: `key 123`", + expectedErr: "apiKeys[0]: api key cannot contain spaces", + }, + { + name: "space in second key reports correct index", + content: `apiKeys: ["valid-key", "bad key"]`, + expectedErr: "apiKeys[1]: api key cannot contain spaces", }, { name: "empty in list with valid keys", diff --git a/internal/config/load.go b/internal/config/load.go new file mode 100644 index 00000000..6e9585d4 --- /dev/null +++ b/internal/config/load.go @@ -0,0 +1,436 @@ +package config + +import ( + "fmt" + "io" + "net/url" + "sort" + "strings" + "time" + + "gopkg.in/yaml.v3" +) + +func LoadConfigFromReader(r io.Reader) (Config, error) { + data, err := io.ReadAll(r) + if err != nil { + return Config{}, err + } + yamlStr := string(data) + + // Phase 1: Substitute all ${env.VAR} macros at string level + // This is safe because env values are simple strings without YAML formatting + yamlStr, err = substituteEnvMacros(yamlStr) + if err != nil { + return Config{}, err + } + + // Unmarshal into full Config with defaults + config := Config{ + HealthCheckTimeout: 120, + StartPort: 5800, + LogLevel: "info", + LogTimeFormat: "", + LogToStdout: LogToStdoutProxy, + MetricsMaxInMemory: 1000, + CaptureBuffer: 5, + GlobalTTL: 0, + } + if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil { + return Config{}, err + } + + if config.HealthCheckTimeout < 15 { + config.HealthCheckTimeout = 15 + } + + // Apply defaults for performance config when section is missing + if config.Performance.Every == 0 { + config.Performance.Every = 5 * time.Second + } + if err = config.Performance.Validate(); err != nil { + return Config{}, fmt.Errorf("performance: %w", err) + } + + if config.StartPort < 1 { + return Config{}, fmt.Errorf("startPort must be greater than 1") + } + + if config.GlobalTTL < 0 { + return Config{}, fmt.Errorf("globalTTL must be >= 0") + } + + // Apply default for upstream.ignorePaths when not specified. The default + // matches common static-asset suffixes so they do not trigger a swap. + if len(config.Upstream.IgnorePaths) == 0 { + config.Upstream.IgnorePaths = DefaultUpstreamIgnorePaths() + } + + switch config.LogToStdout { + case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone: + default: + return Config{}, fmt.Errorf("logToStdout must be one of: proxy, upstream, both, none") + } + + // Populate the aliases map + config.aliases = make(map[string]string) + for modelName, modelConfig := range config.Models { + for _, alias := range modelConfig.Aliases { + if _, found := config.aliases[alias]; found { + return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName) + } + config.aliases[alias] = modelName + } + } + + // Validate global macros + for _, macro := range config.Macros { + if err = validateMacro(macro.Name, macro.Value); err != nil { + return Config{}, err + } + } + + // Get and sort all model IDs for consistent port assignment + modelIds := make([]string, 0, len(config.Models)) + for modelId := range config.Models { + modelIds = append(modelIds, modelId) + } + sort.Strings(modelIds) + + nextPort := config.StartPort + for _, modelId := range modelIds { + modelConfig := config.Models[modelId] + modelConfig.HealthCheckTimeout = config.HealthCheckTimeout + + // Strip comments from command fields + modelConfig.Cmd = StripComments(modelConfig.Cmd) + modelConfig.CmdStop = StripComments(modelConfig.CmdStop) + + // set model TTL to globalTTL it is the default value + if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL { + modelConfig.UnloadAfter = config.GlobalTTL + } + + if modelConfig.UnloadAfter < 0 { + return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter) + } + + // Validate model macros + for _, macro := range modelConfig.Macros { + if err = validateMacro(macro.Name, macro.Value); err != nil { + return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error()) + } + } + + // Build merged macro list: MODEL_ID + global macros + model macros (model overrides global) + mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1) + mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId}) + mergedMacros = append(mergedMacros, config.Macros...) + + // Add model macros (override globals with same name) + for _, entry := range modelConfig.Macros { + found := false + for i, existing := range mergedMacros { + if existing.Name == entry.Name { + mergedMacros[i] = entry + found = true + break + } + } + if !found { + mergedMacros = append(mergedMacros, entry) + } + } + + // Substitute remaining macros in model fields (LIFO order) + for i := len(mergedMacros) - 1; i >= 0; i-- { + entry := mergedMacros[i] + macroSlug := fmt.Sprintf("${%s}", entry.Name) + macroStr := fmt.Sprintf("%v", entry.Value) + + modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr) + modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr) + modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr) + modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr) + modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr) + modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr) + modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr) + + // Substitute macros in SetParamsByID keys and values + if len(modelConfig.Filters.SetParamsByID) > 0 { + newSetParamsByID := make(map[string]map[string]any, len(modelConfig.Filters.SetParamsByID)) + for key, paramMap := range modelConfig.Filters.SetParamsByID { + newKey := strings.ReplaceAll(key, macroSlug, macroStr) + newValAny, err := substituteMacroInValue(any(paramMap), entry.Name, entry.Value) + if err != nil { + return Config{}, fmt.Errorf("model %s filters.setParamsByID: %s", modelId, err.Error()) + } + newParamMap, ok := newValAny.(map[string]any) + if !ok { + return Config{}, fmt.Errorf("model %s filters.setParamsByID: unexpected type after macro substitution", modelId) + } + newSetParamsByID[newKey] = newParamMap + } + modelConfig.Filters.SetParamsByID = newSetParamsByID + } + + // Substitute in metadata (type-preserving) + if len(modelConfig.Metadata) > 0 { + result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value) + if err != nil { + return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error()) + } + modelConfig.Metadata = result.(map[string]any) + } + } + + // Handle PORT macro - only allocate if cmd uses it + cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}") + proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}") + if cmdHasPort || proxyHasPort { + if !cmdHasPort && proxyHasPort { + return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId) + } + + macroSlug := "${PORT}" + macroStr := fmt.Sprintf("%v", nextPort) + + modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr) + modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr) + modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr) + modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr) + modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr) + + if len(modelConfig.Metadata) > 0 { + result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort) + if err != nil { + return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error()) + } + modelConfig.Metadata = result.(map[string]any) + } + + nextPort++ + } + + // Validate no unknown macros remain + fieldMap := map[string]string{ + "cmd": modelConfig.Cmd, + "cmdStop": modelConfig.CmdStop, + "proxy": modelConfig.Proxy, + "checkEndpoint": modelConfig.CheckEndpoint, + "filters.stripParams": modelConfig.Filters.StripParams, + "name": modelConfig.Name, + "description": modelConfig.Description, + } + + for fieldName, fieldValue := range fieldMap { + matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1) + for _, match := range matches { + macroName := match[1] + if macroName == "PID" && fieldName == "cmdStop" { + continue // replaced at runtime + } + if macroName == "PORT" || macroName == "MODEL_ID" { + return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName) + } + return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName) + } + } + + if len(modelConfig.Metadata) > 0 { + if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil { + return Config{}, err + } + } + + if err = modelConfig.Capabilities.Validate(); err != nil { + return Config{}, fmt.Errorf("model %s: %w", modelId, err) + } + + // Validate SetParamsByID keys and values + for key, paramMap := range modelConfig.Filters.SetParamsByID { + if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 { + return Config{}, fmt.Errorf("unknown macro '${%s}' found in model %s filters.setParamsByID key", matches[0][1], modelId) + } + if err := validateNestedForUnknownMacros(any(paramMap), fmt.Sprintf("model %s filters.setParamsByID[%s]", modelId, key)); err != nil { + return Config{}, err + } + } + + // Auto-register setParamsByID keys as aliases (skip the model's own ID) + for key := range modelConfig.Filters.SetParamsByID { + if key == modelId { + continue + } + if _, exists := config.Models[key]; exists { + return Config{}, fmt.Errorf("model %s filters.setParamsByID: key '%s' conflicts with an existing model ID", modelId, key) + } + if existingModel, exists := config.aliases[key]; exists { + if existingModel != modelId { + return Config{}, fmt.Errorf("duplicate alias '%s' in model %s filters.setParamsByID, already used by model %s", key, modelId, existingModel) + } + continue // already registered as explicit alias for this model + } + config.aliases[key] = modelId + modelConfig.Aliases = append(modelConfig.Aliases, key) + } + + if _, err := url.Parse(modelConfig.Proxy); err != nil { + return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err) + } + + if modelConfig.SendLoadingState == nil { + v := config.SendLoadingState + modelConfig.SendLoadingState = &v + } + + config.Models[modelId] = modelConfig + } + + // Normalize routing config. The legacy top-level `matrix`/`groups` keys and + // the new `routing.router` block are mutually exclusive: a config may use + // either style, never both. + hasTopLevel := config.Matrix != nil || len(config.Groups) > 0 + rtr := config.Routing.Router + hasRouting := rtr.Use != "" || rtr.Settings.Matrix != nil || len(rtr.Settings.Groups) > 0 + + if hasTopLevel && hasRouting { + return Config{}, fmt.Errorf("config uses both the legacy top-level 'matrix'/'groups' keys and the new 'routing.router' block; please migrate the top-level keys into 'routing.router' and remove them") + } + + if !hasTopLevel { + // Both groups and matrix may be defined under routing.router.settings; + // routing.router.use selects which one is active, so there is no conflict. + rs := config.Routing.Router.Settings + switch config.Routing.Router.Use { + case "matrix": + if rs.Matrix == nil { + return Config{}, fmt.Errorf("routing.router.use is 'matrix' but routing.router.settings.matrix is not set") + } + config.Matrix = rs.Matrix + case "group", "": + config.Groups = rs.Groups + default: + return Config{}, fmt.Errorf("routing.router.use: unknown router %q (valid: group, matrix)", config.Routing.Router.Use) + } + } + + // groups XOR matrix + if config.Matrix != nil && len(config.Groups) > 0 { + return Config{}, fmt.Errorf("config cannot use both 'groups' and 'matrix'") + } + + if config.Matrix != nil { + expandedSets, err := ValidateMatrix(*config.Matrix, config.Models) + if err != nil { + return Config{}, fmt.Errorf("matrix: %w", err) + } + config.Matrix.ExpandedSets = expandedSets + } else { + config = AddDefaultGroupToConfig(config) + + // Validate group members + memberUsage := make(map[string]string) + for groupID, groupConfig := range config.Groups { + prevSet := make(map[string]bool) + for _, member := range groupConfig.Members { + if _, found := prevSet[member]; found { + return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID) + } + prevSet[member] = true + + if existingGroup, exists := memberUsage[member]; exists { + return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID) + } + memberUsage[member] = groupID + } + } + } + + // Build the canonical Config.Routing from the effective result. Both legacy + // and new-style configs converge here. The Matrix pointer is shared so + // ExpandedSets stays in one place. + if config.Matrix != nil { + config.Routing.Router.Use = "matrix" + } else { + config.Routing.Router.Use = "group" + } + config.Routing.Router.Settings.Matrix = config.Matrix + config.Routing.Router.Settings.Groups = config.Groups + + if config.Routing.Scheduler.Use == "" { + config.Routing.Scheduler.Use = "fifo" + } + if config.Routing.Scheduler.Use != "fifo" { + return Config{}, fmt.Errorf("routing.scheduler.use: unknown scheduler %q (valid: fifo)", config.Routing.Scheduler.Use) + } + for modelID := range config.Routing.Scheduler.Settings.Fifo.Priority { + if _, found := config.RealModelName(modelID); !found { + return Config{}, fmt.Errorf("routing.scheduler.settings.fifo.priority references unknown model %q", modelID) + } + } + + // Clean up hooks preload + if len(config.Hooks.OnStartup.Preload) > 0 { + var toPreload []string + for _, modelID := range config.Hooks.OnStartup.Preload { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + continue + } + if real, found := config.RealModelName(modelID); found { + toPreload = append(toPreload, real) + } + } + config.Hooks.OnStartup.Preload = toPreload + } + + // Validate API keys (env macros already substituted at string level) + for i, apikey := range config.RequiredAPIKeys { + if apikey == "" { + return Config{}, fmt.Errorf("empty api key found in apiKeys") + } + if strings.Contains(apikey, " ") { + return Config{}, fmt.Errorf("apiKeys[%d]: api key cannot contain spaces", i) + } + config.RequiredAPIKeys[i] = apikey + } + + // Process peers with global macro substitution + for peerName, peerConfig := range config.Peers { + // Substitute global macros (LIFO order) + for i := len(config.Macros) - 1; i >= 0; i-- { + entry := config.Macros[i] + macroSlug := fmt.Sprintf("${%s}", entry.Name) + macroStr := fmt.Sprintf("%v", entry.Value) + + peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr) + peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr) + + // Substitute in setParams (type-preserving) + if len(peerConfig.Filters.SetParams) > 0 { + result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value) + if err != nil { + return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err) + } + peerConfig.Filters.SetParams = result.(map[string]any) + } + } + + // Validate no unknown macros remain + if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 { + return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1]) + } + if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.Filters.StripParams, -1); len(matches) > 0 { + return Config{}, fmt.Errorf("peers.%s.filters.stripParams: unknown macro '${%s}'", peerName, matches[0][1]) + } + if len(peerConfig.Filters.SetParams) > 0 { + if err := validateNestedForUnknownMacros(peerConfig.Filters.SetParams, fmt.Sprintf("peers.%s.filters.setParams", peerName)); err != nil { + return Config{}, err + } + } + config.Peers[peerName] = peerConfig + } + + return config, nil +} diff --git a/internal/config/macros.go b/internal/config/macros.go new file mode 100644 index 00000000..a2b9eaf6 --- /dev/null +++ b/internal/config/macros.go @@ -0,0 +1,198 @@ +package config + +import ( + "fmt" + "os" + "regexp" + "strings" + + "gopkg.in/yaml.v3" +) + +var ( + macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`) + envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`) +) + +// validateMacro validates macro name and value constraints +func validateMacro(name string, value any) error { + if len(name) >= 64 { + return fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", name) + } + if !macroNameRegex.MatchString(name) { + return fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", name) + } + + // Validate that value is a scalar type + switch v := value.(type) { + case string: + // Check for self-reference + macroSlug := fmt.Sprintf("${%s}", name) + if strings.Contains(v, macroSlug) { + return fmt.Errorf("macro '%s' contains self-reference", name) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: + // These types are allowed + default: + return fmt.Errorf("macro '%s' has invalid type %T, must be a scalar type (string, int, float, or bool)", name, value) + } + + switch name { + case "PORT", "MODEL_ID": + return fmt.Errorf("macro name '%s' is reserved", name) + } + + return nil +} + +// validateNestedForUnknownMacros recursively checks for any remaining macro references in nested structures +func validateNestedForUnknownMacros(value any, context string) error { + switch v := value.(type) { + case string: + matches := macroPatternRegex.FindAllStringSubmatch(v, -1) + for _, match := range matches { + macroName := match[1] + return fmt.Errorf("%s: unknown macro '${%s}'", context, macroName) + } + // Check for unsubstituted env macros + envMatches := envMacroRegex.FindAllStringSubmatch(v, -1) + for _, match := range envMatches { + varName := match[1] + return fmt.Errorf("%s: environment variable '%s' not set", context, varName) + } + return nil + + case map[string]any: + for _, val := range v { + if err := validateNestedForUnknownMacros(val, context); err != nil { + return err + } + } + return nil + + case []any: + for _, val := range v { + if err := validateNestedForUnknownMacros(val, context); err != nil { + return err + } + } + return nil + + default: + // Scalar types don't contain macros + return nil + } +} + +// substituteMacroInValue recursively substitutes a single macro in a value structure +// This is called once per macro, allowing LIFO substitution order +func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) { + macroSlug := fmt.Sprintf("${%s}", macroName) + macroStr := fmt.Sprintf("%v", macroValue) + + switch v := value.(type) { + case string: + // Check if this is a direct macro substitution + if v == macroSlug { + return macroValue, nil + } + // Handle string interpolation + if strings.Contains(v, macroSlug) { + return strings.ReplaceAll(v, macroSlug, macroStr), nil + } + return v, nil + + case map[string]any: + // Recursively process map values + newMap := make(map[string]any) + for key, val := range v { + newVal, err := substituteMacroInValue(val, macroName, macroValue) + if err != nil { + return nil, err + } + newMap[key] = newVal + } + return newMap, nil + + case []any: + // Recursively process slice elements + newSlice := make([]any, len(v)) + for i, val := range v { + newVal, err := substituteMacroInValue(val, macroName, macroValue) + if err != nil { + return nil, err + } + newSlice[i] = newVal + } + return newSlice, nil + + default: + // Return scalar types as-is + return value, nil + } +} + +// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values. +// Returns error if any referenced env var is not set or contains invalid characters. +// Env macros inside YAML comments are ignored by unmarshalling the YAML first +// (which strips comments) and only checking the comment-free version for macros. +func substituteEnvMacros(s string) (string, error) { + // Unmarshal and remarshal to strip YAML comments + var raw any + if err := yaml.Unmarshal([]byte(s), &raw); err != nil { + // If YAML is invalid, fall back to scanning the original string + // so the user gets the env var error rather than a confusing YAML parse error + return substituteEnvMacrosInString(s, s) + } + clean, err := yaml.Marshal(raw) + if err != nil { + return substituteEnvMacrosInString(s, s) + } + + return substituteEnvMacrosInString(s, string(clean)) +} + +// substituteEnvMacrosInString finds ${env.VAR} macros in scanStr and substitutes +// them in target. This separation allows scanning comment-free YAML while +// substituting in the original string. +func substituteEnvMacrosInString(target, scanStr string) (string, error) { + result := target + matches := envMacroRegex.FindAllStringSubmatch(scanStr, -1) + for _, match := range matches { + fullMatch := match[0] // ${env.VAR_NAME} + varName := match[1] // VAR_NAME + + value, exists := os.LookupEnv(varName) + if !exists { + return "", fmt.Errorf("environment variable '%s' is not set", varName) + } + + // Sanitize the value for safe YAML substitution + value, err := sanitizeEnvValueForYAML(value, varName) + if err != nil { + return "", err + } + + result = strings.ReplaceAll(result, fullMatch, value) + } + return result, nil +} + +// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution. +// It rejects values with characters that break YAML structure and escapes quotes/backslashes +// for compatibility with double-quoted YAML strings. +func sanitizeEnvValueForYAML(value, varName string) (string, error) { + // Reject values that would break YAML structure regardless of quoting context + if strings.ContainsAny(value, "\n\r\x00") { + return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName) + } + + // Escape backslashes and double quotes for safe use in double-quoted YAML strings. + // In unquoted contexts, these escapes appear literally (harmless for most use cases). + // In double-quoted contexts, they are interpreted correctly. + value = strings.ReplaceAll(value, `\`, `\\`) + value = strings.ReplaceAll(value, `"`, `\"`) + + return value, nil +} diff --git a/internal/config/merge.go b/internal/config/merge.go new file mode 100644 index 00000000..91bd6165 --- /dev/null +++ b/internal/config/merge.go @@ -0,0 +1,300 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "gopkg.in/yaml.v3" +) + +// identityMapPaths is the set of dotted paths whose direct children are +// identity-keyed maps. A child key present in two sources is a hard error; +// such keys name discrete entities (a model, a group, a peer, etc.) and a +// duplicate means the user has split one entity across files by mistake. +var identityMapPaths = map[string]bool{ + "models": true, + "groups": true, + "profiles": true, + "peers": true, + "matrix": true, + "routing.router.settings.groups": true, + "routing.router.settings.matrix": true, +} + +// LoadConfigSources loads and merges configuration from -config (optional) +// and -config-dir (optional). At least one must be provided. The -config file +// is loaded first; *.yml/*.yaml files directly under -config-dir are then +// merged in sorted filename order. The merged document is passed through the +// existing LoadConfigFromReader pipeline unchanged. +func LoadConfigSources(configPath, configDir string) (Config, error) { + if configPath == "" && configDir == "" { + return Config{}, fmt.Errorf("at least one of -config or -config-dir must be provided") + } + + var sourcePaths []string + + if configPath != "" { + sourcePaths = append(sourcePaths, configPath) + } + + if configDir != "" { + dirFiles, err := listYAMLFiles(configDir) + if err != nil { + return Config{}, fmt.Errorf("-config-dir %s: %w", configDir, err) + } + + if configPath != "" { + absConfig, err := filepath.Abs(configPath) + if err != nil { + return Config{}, fmt.Errorf("failed to resolve -config path: %w", err) + } + for _, f := range dirFiles { + absF, err := filepath.Abs(f) + if err != nil { + return Config{}, fmt.Errorf("failed to resolve config dir file %s: %w", f, err) + } + if absConfig == absF { + return Config{}, fmt.Errorf("-config path %s is also present in -config-dir %s; remove it from one", configPath, configDir) + } + } + } + + sourcePaths = append(sourcePaths, dirFiles...) + } + + if len(sourcePaths) == 0 { + return Config{}, fmt.Errorf("no configuration sources found") + } + + var merged *yaml.Node + for _, p := range sourcePaths { + node, err := parseSource(p) + if err != nil { + return Config{}, err + } + if node == nil { + continue // empty file + } + if merged == nil { + merged = node + continue + } + if err := mergeNodes(merged, node, "", p); err != nil { + return Config{}, err + } + } + + if merged == nil { + // All sources were empty; run the pipeline on empty input so defaults + // and validation still apply (e.g. startPort, performance defaults). + return LoadConfigFromReader(strings.NewReader("")) + } + + out, err := yaml.Marshal(merged) + if err != nil { + return Config{}, fmt.Errorf("failed to marshal merged config: %w", err) + } + return LoadConfigFromReader(strings.NewReader(string(out))) +} + +// listYAMLFiles returns the top-level *.yml and *.yaml files in dir, sorted by +// filename for deterministic merge order. Subdirectories are not traversed. +func listYAMLFiles(dir string) ([]string, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + var files []string + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") { + continue + } + files = append(files, filepath.Join(dir, name)) + } + sort.Strings(files) + return files, nil +} + +// parseSource reads and parses one YAML config file into a root mapping node. +// Returns a nil node (no error) when the file is empty or contains only +// comments. +// +// Env macros (${env.VAR}) are substituted at the string level before YAML +// parsing so that flow-style constructs like [${env.API_KEY}] parse +// correctly — the brace would otherwise be interpreted as a flow mapping. +func parseSource(path string) (*yaml.Node, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read config %s: %w", path, err) + } + yamlStr, err := substituteEnvMacros(string(data)) + if err != nil { + return nil, fmt.Errorf("config %s: %w", path, err) + } + var doc yaml.Node + if err := yaml.Unmarshal([]byte(yamlStr), &doc); err != nil { + return nil, fmt.Errorf("failed to parse config %s: %w", path, err) + } + // yaml.Unmarshal into a yaml.Node yields a DocumentNode whose Content[0] + // is the actual root. Unwrap it so callers see the real top-level node. + root := &doc + if root.Kind == yaml.DocumentNode && len(root.Content) > 0 { + root = root.Content[0] + } + if root.Kind == 0 || root.Content == nil { + return nil, nil + } + if root.Kind != yaml.MappingNode { + return nil, fmt.Errorf("config %s: top-level YAML must be a mapping", path) + } + return root, nil +} + +// mergeNodes merges src into dst (both MappingNodes) in place. Keys present in +// only one side are kept; shared keys are merged recursively under the rules +// in mergeValue. srcPath is included in error messages to identify the file +// that introduced the conflict. +func mergeNodes(dst, src *yaml.Node, path, srcPath string) error { + srcIdx := indexMapping(src) + + // First pass: merge shared keys in place. + for i := 0; i+1 < len(dst.Content); i += 2 { + keyNode := dst.Content[i] + dstVal := dst.Content[i+1] + key := keyNode.Value + + srcVal, ok := srcIdx[key] + if !ok { + continue // dst-only key, keep as-is + } + + childPath := joinPath(path, key) + + if identityMapPaths[childPath] { + // Identity-keyed map: each child key names a discrete entity + // (a model, group, peer, ...). A shared child key is a hard + // error; src-only children are appended in the second pass. + if err := mergeIdentityMap(dstVal, srcVal, childPath, key, srcPath); err != nil { + return err + } + continue + } + + if err := mergeValue(dstVal, srcVal, childPath, srcPath); err != nil { + return err + } + } + + // Second pass: append src-only keys. + dstIdx := indexMapping(dst) + for i := 0; i+1 < len(src.Content); i += 2 { + keyNode := src.Content[i] + srcVal := src.Content[i+1] + key := keyNode.Value + + if _, ok := dstIdx[key]; ok { + continue // already merged above + } + keyCopy := *keyNode + valCopy := *srcVal + dst.Content = append(dst.Content, &keyCopy, &valCopy) + } + + return nil +} + +// mergeIdentityMap merges two identity-keyed mapping nodes (e.g. `models`, +// `groups`, `peers`). Any child key present in both sides is a duplicate +// entity and produces an error naming the conflicting key and source file. +// src-only keys are appended to dst. +func mergeIdentityMap(dst, src *yaml.Node, path, mapName, srcPath string) error { + if dst.Kind != yaml.MappingNode || src.Kind != yaml.MappingNode { + return fmt.Errorf("conflict at %q: expected a mapping, introduced by %s", path, srcPath) + } + dstIdx := indexMapping(dst) + for i := 0; i+1 < len(src.Content); i += 2 { + keyNode := src.Content[i] + srcVal := src.Content[i+1] + key := keyNode.Value + if _, dup := dstIdx[key]; dup { + return fmt.Errorf("duplicate %s %q found in %s (already defined in another config source)", mapName, key, srcPath) + } + keyCopy := *keyNode + valCopy := *srcVal + dst.Content = append(dst.Content, &keyCopy, &valCopy) + } + return nil +} + +// mergeValue merges srcVal into dstVal (both pointing into the parent's +// Content slice). Mapping+Mapping recurses; Sequence+Sequence concatenates; +// Scalar+Scalar errors on value mismatch; null on either side yields to the +// non-null side. +func mergeValue(dstVal, srcVal *yaml.Node, path, srcPath string) error { + switch { + case dstVal.Kind == yaml.MappingNode && srcVal.Kind == yaml.MappingNode: + return mergeNodes(dstVal, srcVal, path, srcPath) + + case dstVal.Kind == yaml.SequenceNode && srcVal.Kind == yaml.SequenceNode: + dstVal.Content = append(dstVal.Content, srcVal.Content...) + return nil + + case dstVal.Kind == yaml.ScalarNode && srcVal.Kind == yaml.ScalarNode: + if isNullScalar(dstVal) { + *dstVal = *srcVal + return nil + } + if isNullScalar(srcVal) { + return nil + } + if dstVal.Value == srcVal.Value && dstVal.Tag == srcVal.Tag { + return nil + } + return fmt.Errorf("conflict at %q: %s sets a different value than a previous source", path, srcPath) + + case isNull(dstVal): + *dstVal = *srcVal + return nil + + case isNull(srcVal): + return nil + + default: + return fmt.Errorf("conflict at %q: incompatible YAML node kinds (kind %d vs %d) introduced by %s", path, dstVal.Kind, srcVal.Kind, srcPath) + } +} + +// isNull reports whether n represents a YAML null (empty or !!null). +func isNull(n *yaml.Node) bool { + if n == nil || n.Kind == 0 { + return true + } + return isNullScalar(n) +} + +func isNullScalar(n *yaml.Node) bool { + return n.Kind == yaml.ScalarNode && (n.Tag == "!!null" || n.Tag == "") && n.Value == "" +} + +// indexMapping builds a key -> value-node index for a mapping node. +func indexMapping(n *yaml.Node) map[string]*yaml.Node { + idx := make(map[string]*yaml.Node, len(n.Content)/2) + for i := 0; i+1 < len(n.Content); i += 2 { + idx[n.Content[i].Value] = n.Content[i+1] + } + return idx +} + +func joinPath(parent, key string) string { + if parent == "" { + return key + } + return parent + "." + key +} diff --git a/internal/config/merge_test.go b/internal/config/merge_test.go new file mode 100644 index 00000000..e74ba10e --- /dev/null +++ b/internal/config/merge_test.go @@ -0,0 +1,304 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// writeYAML writes content to a file named name inside dir. Returns the full +// path of the written file. +func writeYAML(t *testing.T, dir, name, content string) string { + t.Helper() + p := filepath.Join(dir, name) + require.NoError(t, os.MkdirAll(filepath.Dir(p), 0o755)) + require.NoError(t, os.WriteFile(p, []byte(content), 0o644)) + return p +} + +// modelCfg builds a single-model YAML snippet indented for nesting under a +// `models:` key. The proxy uses a fixed port so tests don't depend on +// ${PORT} allocation. +func modelCfg(id, cmd string) string { + return " " + id + ":\n cmd: " + cmd + "\n proxy: \"http://localhost:9999\"\n" +} + +func TestLoadConfigSources_NeitherProvided(t *testing.T) { + _, err := LoadConfigSources("", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "at least one of -config or -config-dir") +} + +func TestLoadConfigSources_ConfigOnly(t *testing.T) { + dir := t.TempDir() + cfgPath := writeYAML(t, dir, "config.yaml", ` +models: +`+modelCfg("model1", "echo hi")+` +groups: + group1: + members: ["model1"] +`) + cfg, err := LoadConfigSources(cfgPath, "") + require.NoError(t, err) + _, id, ok := cfg.FindConfig("model1") + require.True(t, ok) + assert.Equal(t, "model1", id) +} + +func TestLoadConfigSources_DirOnly(t *testing.T) { + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("alpha", "echo a")) + writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("beta", "echo b")) + + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + for _, want := range []string{"alpha", "beta"} { + _, _, ok := cfg.FindConfig(want) + assert.True(t, ok, "model %s should be present", want) + } +} + +func TestLoadConfigSources_ConfigPlusDirAdditive(t *testing.T) { + // -config lives outside -config-dir; both contribute models additively. + dir := t.TempDir() + cfgPath := writeYAML(t, dir, "config.yaml", "models:\n"+modelCfg("base", "echo base")) + cfgDir := t.TempDir() + writeYAML(t, cfgDir, "extra.yaml", "models:\n"+modelCfg("ext", "echo ext")) + + cfg, err := LoadConfigSources(cfgPath, cfgDir) + require.NoError(t, err) + for _, want := range []string{"base", "ext"} { + _, _, ok := cfg.FindConfig(want) + assert.True(t, ok, "model %s should be present after merge", want) + } +} + +// TestLoadConfigSources_ConfigInDirOverlap verifies that a -config file that +// is also a member of -config-dir is rejected. +func TestLoadConfigSources_ConfigInDirOverlap(t *testing.T) { + dir := t.TempDir() + cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("base", "echo base")) + + _, err := LoadConfigSources(cfgPath, dir) + require.Error(t, err) + assert.Contains(t, err.Error(), "is also present in -config-dir") +} + +func TestLoadConfigSources_DuplicateModelID(t *testing.T) { + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("dup", "echo a")) + writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("dup", "echo b")) + + _, err := LoadConfigSources("", dir) + require.Error(t, err) + assert.Contains(t, err.Error(), `duplicate models "dup"`) +} + +func TestLoadConfigSources_DuplicateGroupID(t *testing.T) { + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", ` +models: +`+modelCfg("m1", "echo m1")+"groups:\n g1:\n members: [m1]\n") + writeYAML(t, dir, "b.yaml", ` +models: +`+modelCfg("m2", "echo m2")+"groups:\n g1:\n members: [m2]\n") + + _, err := LoadConfigSources("", dir) + require.Error(t, err) + assert.Contains(t, err.Error(), `duplicate groups "g1"`) +} + +func TestLoadConfigSources_DuplicatePeer(t *testing.T) { + dir := t.TempDir() + peerA := "peers:\n remote:\n proxy: http://x:1\n models: [m1]\n" + peerB := "peers:\n remote:\n proxy: http://x:2\n models: [m2]\n" + writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\n"+peerA) + writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\n"+peerB) + + _, err := LoadConfigSources("", dir) + require.Error(t, err) + assert.Contains(t, err.Error(), `duplicate peers "remote"`) +} + +func TestLoadConfigSources_ScalarConflict(t *testing.T) { + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n") + writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 200\n") + + _, err := LoadConfigSources("", dir) + require.Error(t, err) + assert.Contains(t, err.Error(), `conflict at "globalTTL"`) +} + +func TestLoadConfigSources_ScalarSameValueNoConflict(t *testing.T) { + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\nglobalTTL: 100\n") + writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\nglobalTTL: 100\n") + + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + assert.Equal(t, 100, cfg.GlobalTTL) +} + +func TestLoadConfigSources_MacrosConcatenate(t *testing.T) { + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", "macros:\n LOW: 1\nmodels:\n"+modelCfg("m1", "echo ${LOW}")) + writeYAML(t, dir, "b.yaml", "macros:\n HIGH: 2\nmodels:\n"+modelCfg("m2", "echo ${HIGH}")) + + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + // Both macros are available globally after merge. + low, ok := cfg.Macros.Get("LOW") + require.True(t, ok) + assert.Equal(t, 1, low) + high, ok := cfg.Macros.Get("HIGH") + require.True(t, ok) + assert.Equal(t, 2, high) +} + +func TestLoadConfigSources_APIKeysConcatenate(t *testing.T) { + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")+"\napiKeys: [key-a]\n") + writeYAML(t, dir, "b.yaml", "models:\n"+modelCfg("m2", "echo m2")+"\napiKeys: [key-b]\n") + + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + assert.ElementsMatch(t, []string{"key-a", "key-b"}, cfg.RequiredAPIKeys) +} + +func TestLoadConfigSources_RoutingGroupsMerge(t *testing.T) { + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", ` +models: +`+modelCfg("m1", "echo m1")+` +routing: + router: + settings: + groups: + groupA: + members: [m1] +`) + writeYAML(t, dir, "b.yaml", ` +models: +`+modelCfg("m2", "echo m2")+` +routing: + router: + settings: + groups: + groupB: + members: [m2] +`) + + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + groups := cfg.Routing.Router.Settings.Groups + assert.Contains(t, groups, "groupA") + assert.Contains(t, groups, "groupB") + // default group added by pipeline for orphaned/leftover routing groups... + // here both groups reference distinct models +} + +func TestLoadConfigSources_EnvMacrosSubstituted(t *testing.T) { + dir := t.TempDir() + // Use ${PORT} in cmd so the pipeline allocates a port and substitutes it; + // verifies env/macro substitution runs on the merged document. + writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: serve --port ${PORT}\n proxy: \"http://localhost:${PORT}\"\n") + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + m := cfg.Models["m1"] + assert.NotContains(t, m.Cmd, "${PORT}", "PORT macro should have been substituted") + assert.NotContains(t, m.Proxy, "${PORT}", "PORT macro should have been substituted in proxy") +} + +func TestLoadConfigSources_EnvMacroInFlowStyleList(t *testing.T) { + // Regression: flow-style lists with ${env.*} must parse. Previously + // parseSource unmarshalled before env substitution, so the brace in + // [${env.API_KEY}] was misread as a flow mapping and parsing failed. + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", "models:\n m1:\n cmd: echo hi\n proxy: \"http://localhost:9999\"\n") + writeYAML(t, dir, "keys.yaml", "apiKeys: [${env.TEST_API_KEY}]\nmodels:\n m2:\n cmd: echo hi\n proxy: \"http://localhost:9998\"\n") + + t.Setenv("TEST_API_KEY", "secret123") + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + assert.Contains(t, cfg.RequiredAPIKeys, "secret123") +} + +func TestLoadConfigSources_SortedOrderDeterministic(t *testing.T) { + // Two files defining distinct models, scanned in z..a order by filename. + // Determine merged result is the same regardless of how the FS returns them. + dir := t.TempDir() + writeYAML(t, dir, "z.yaml", "models:\n"+modelCfg("zmodel", "echo z")) + writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("amodel", "echo a")) + + const runs = 3 + for i := 0; i < runs; i++ { + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + // startPort-based allocation: first allocated model gets 5800. + // Sorted order means amodel gets 5800, zmodel gets 5801. + _, _, ok := cfg.FindConfig("amodel") + assert.True(t, ok) + _, _, ok = cfg.FindConfig("zmodel") + assert.True(t, ok) + } +} + +func TestLoadConfigSources_EmptyDirWithConfig(t *testing.T) { + dir := t.TempDir() + cfgDir := t.TempDir() + cfgPath := writeYAML(t, dir, "main.yaml", "models:\n"+modelCfg("m1", "echo m1")) + + cfg, err := LoadConfigSources(cfgPath, cfgDir) + require.NoError(t, err) + assert.Contains(t, cfg.Models, "m1") +} + +func TestLoadConfigSources_EmptyDirOnly(t *testing.T) { + // An empty -config-dir with no -config is an error: there is nothing to + // load and silently producing an empty config would mask the misconfig. + cfgDir := t.TempDir() + _, err := LoadConfigSources("", cfgDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "no configuration sources found") +} + +func TestLoadConfigSources_AssertNoUnknownMacrosAfterMerge(t *testing.T) { + // Macros defined in one file should not satisfy unknown-macro validation in + // another — they do, because merge concats global macros before validation + // runs. This test documents that a macro from file A is usable in file B. + dir := t.TempDir() + writeYAML(t, dir, "macros.yaml", "macros:\n SHARED: hello\nmodels:\n"+modelCfg("dummy", "echo dummy")) + writeYAML(t, dir, "use.yaml", "models:\n"+modelCfg("user", "echo ${SHARED}")) + + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + m := cfg.Models["user"] + assert.Contains(t, m.Cmd, "hello") + assert.NotContains(t, m.Cmd, "${SHARED}") +} + +func TestLoadConfigSources_KindMismatchErrors(t *testing.T) { + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", "startPort: 5800\nmodels:\n"+modelCfg("m1", "echo m1")) + writeYAML(t, dir, "b.yaml", "startPort: [5800, 5801]\nmodels:\n"+modelCfg("m2", "echo m2")) + + _, err := LoadConfigSources("", dir) + require.Error(t, err) + assert.Contains(t, err.Error(), "incompatible YAML node kinds") +} + +func TestLoadConfigSources_NullYieldsToValue(t *testing.T) { + // File A: routing.router block absent (null on root for routing); + // file B: defines routing.router.settings.groups. Merge should keep B's. + dir := t.TempDir() + writeYAML(t, dir, "a.yaml", "models:\n"+modelCfg("m1", "echo m1")) + writeYAML(t, dir, "b.yaml", "routing:\n router:\n settings:\n groups:\n g1:\n members: [m1]\nmodels:\n"+modelCfg("m2", "echo m2")) + + cfg, err := LoadConfigSources("", dir) + require.NoError(t, err) + assert.Contains(t, cfg.Routing.Router.Settings.Groups, "g1") +} diff --git a/internal/watcher/dirwatcher.go b/internal/watcher/dirwatcher.go new file mode 100644 index 00000000..65e22c2e --- /dev/null +++ b/internal/watcher/dirwatcher.go @@ -0,0 +1,137 @@ +package configwatcher + +import ( + "context" + "os" + "path/filepath" + "sort" + "strings" + "time" +) + +// DirWatcher polls a directory for changes to its set of *.yml / *.yaml files. +// It fires OnChange when a file is added, removed, or has its mod time/size +// change. Like Watcher it is poll-based so it works in Docker bind-mounts and +// k8s ConfigMap projections where inotify is unreliable. +// +// The baseline poll establishes initial state and does not fire OnChange. +type DirWatcher struct { + Path string + Interval time.Duration + OnChange func() +} + +// dirSnapshot is an ordered map of file name -> file state. The ordering is +// derived from sorted filenames so two snapshots compare deterministically +// regardless of readdir order. exists reflects whether the directory was +// readable at scan time; a missing directory yields exists=false. +type dirSnapshot struct { + exists bool + names []string + states map[string]snapshot +} + +func newDirSnapshot() dirSnapshot { + return dirSnapshot{states: make(map[string]snapshot)} +} + +// equal reports whether two snapshots describe the same file set and per-file +// state. A missing directory (exists=false) is treated as equal to any other +// missing directory regardless of cached names. +func (s dirSnapshot) equal(other dirSnapshot) bool { + if !s.exists && !other.exists { + return true + } + if s.exists != other.exists { + return false + } + if len(s.names) != len(other.names) { + return false + } + for i, n := range s.names { + if other.names[i] != n { + return false + } + } + for _, n := range s.names { + a, b := s.states[n], other.states[n] + if a.exists != b.exists || a.size != b.size || !a.modTime.Equal(b.modTime) { + return false + } + } + return true +} + +// Run blocks until ctx is canceled. It polls Path on Interval and invokes +// OnChange whenever the directory's YAML file set changes. +// +// Policy mirrors the single-file Watcher: disappearance (directory missing or +// empty) is treated as a transient rename-style write and stays quiet; the +// transition back to present-with-content fires OnChange. +func (w *DirWatcher) Run(ctx context.Context) { + interval := w.Interval + if interval <= 0 { + interval = DefaultInterval + } + + prev := scanDir(w.Path) + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + cur := scanDir(w.Path) + // Suppress transitions involving an empty or missing directory — + // these are treated as transient rename-style writes, mirroring + // the single-file Watcher. Only present-with-content → + // present-with-content (changed) or no-content → + // present-with-content fires OnChange. + prevHasContent := prev.exists && len(prev.names) > 0 + curHasContent := cur.exists && len(cur.names) > 0 + if curHasContent && (!prevHasContent || !prev.equal(cur)) && w.OnChange != nil { + w.OnChange() + } + prev = cur + } + } +} + +// scanDir returns a snapshot of the *.yml/*.yaml files in dir. If the +// directory cannot be read (missing, permission denied) the snapshot reports +// exists=false; the next successful scan will detect the recovery and fire +// OnChange. +func scanDir(dir string) dirSnapshot { + snap := newDirSnapshot() + entries, err := os.ReadDir(dir) + if err != nil { + return snap // exists=false + } + snap.exists = true + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if !strings.HasSuffix(name, ".yml") && !strings.HasSuffix(name, ".yaml") { + continue + } + fi, err := os.Stat(filepath.Join(dir, name)) + if err != nil { + // File disappeared between ReadDir and Stat; skip it — the + // next poll will observe the removal cleanly. + continue + } + snap.names = append(snap.names, name) + snap.states[name] = snapshot{ + exists: true, + modTime: fi.ModTime(), + size: fi.Size(), + } + } + sort.Strings(snap.names) + return snap +} diff --git a/internal/watcher/dirwatcher_test.go b/internal/watcher/dirwatcher_test.go new file mode 100644 index 00000000..3c619a0e --- /dev/null +++ b/internal/watcher/dirwatcher_test.go @@ -0,0 +1,199 @@ +package configwatcher + +import ( + "context" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// startDirWatcher launches w.Run in a goroutine and returns a function that +// cancels the context and waits for Run to return. +func startDirWatcher(t *testing.T, w *DirWatcher) func() { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + w.Run(ctx) + close(done) + }() + return func() { + cancel() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("DirWatcher did not stop within 2s of cancel") + } + } +} + +func writeYAMLInDir(t *testing.T, dir, name, content string) { + t.Helper() + require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644)) +} + +func TestDirWatcher_NoFireOnBaseline(t *testing.T) { + dir := t.TempDir() + writeYAMLInDir(t, dir, "a.yaml", "a") + + var n int64 + stop := startDirWatcher(t, &DirWatcher{ + Path: dir, + Interval: testInterval, + OnChange: func() { atomic.AddInt64(&n, 1) }, + }) + defer stop() + + time.Sleep(testInterval * 5) + require.Equal(t, int64(0), atomic.LoadInt64(&n), "baseline poll must not fire") +} + +func TestDirWatcher_DetectsFileAdd(t *testing.T) { + dir := t.TempDir() + writeYAMLInDir(t, dir, "a.yaml", "a") + + var n int64 + stop := startDirWatcher(t, &DirWatcher{ + Path: dir, + Interval: testInterval, + OnChange: func() { atomic.AddInt64(&n, 1) }, + }) + defer stop() + time.Sleep(testInterval * 2) + + writeYAMLInDir(t, dir, "b.yaml", "b") + require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is added") +} + +func TestDirWatcher_DetectsFileRemoval(t *testing.T) { + dir := t.TempDir() + writeYAMLInDir(t, dir, "a.yaml", "a") + writeYAMLInDir(t, dir, "b.yaml", "b") + + var n int64 + stop := startDirWatcher(t, &DirWatcher{ + Path: dir, + Interval: testInterval, + OnChange: func() { atomic.AddInt64(&n, 1) }, + }) + defer stop() + time.Sleep(testInterval * 2) + + require.NoError(t, os.Remove(filepath.Join(dir, "b.yaml"))) + require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when a file is removed") +} + +func TestDirWatcher_DetectsModTimeChange(t *testing.T) { + dir := t.TempDir() + writeYAMLInDir(t, dir, "a.yaml", "a") + + base := time.Now().Add(-1 * time.Hour).Truncate(time.Second) + require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base, base)) + + var n int64 + stop := startDirWatcher(t, &DirWatcher{ + Path: dir, + Interval: testInterval, + OnChange: func() { atomic.AddInt64(&n, 1) }, + }) + defer stop() + time.Sleep(testInterval * 2) + + require.NoError(t, os.Chtimes(filepath.Join(dir, "a.yaml"), base.Add(10*time.Second), base.Add(10*time.Second))) + require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire after mtime change") +} + +func TestDirWatcher_IgnoresNonYAMLFiles(t *testing.T) { + dir := t.TempDir() + writeYAMLInDir(t, dir, "a.yaml", "a") + + var n int64 + stop := startDirWatcher(t, &DirWatcher{ + Path: dir, + Interval: testInterval, + OnChange: func() { atomic.AddInt64(&n, 1) }, + }) + defer stop() + time.Sleep(testInterval * 2) + + // Adding a .txt file must not fire. + require.NoError(t, os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0o644)) + time.Sleep(testInterval * 4) + require.Equal(t, int64(0), atomic.LoadInt64(&n), "non-YAML files must be ignored") + + // Adding a .yml file must fire. + writeYAMLInDir(t, dir, "b.yml", "b") + require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire for *.yml files") +} + +func TestDirWatcher_MissingDirRecovers(t *testing.T) { + dir := t.TempDir() + writeYAMLInDir(t, dir, "a.yaml", "a") + + var n int64 + stop := startDirWatcher(t, &DirWatcher{ + Path: dir, + Interval: testInterval, + OnChange: func() { atomic.AddInt64(&n, 1) }, + }) + defer stop() + time.Sleep(testInterval * 2) + + // Remove the directory. No fire expected on disappearance alone. + require.NoError(t, os.RemoveAll(dir)) + time.Sleep(testInterval * 3) + require.Equal(t, int64(0), atomic.LoadInt64(&n), "directory removal alone must not fire") + + // Recreate the directory and a YAML file; the recovery should fire. + require.NoError(t, os.MkdirAll(dir, 0o755)) + writeYAMLInDir(t, dir, "recovered.yaml", "r") + require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when dir returns with content") +} + +func TestDirWatcher_EmptyDirSuppressedThenRecovers(t *testing.T) { + // Present-with-content → empty (all YAML removed, dir still exists) + // must stay quiet — treated as transient per the documented policy. + // The transition back to content fires. + dir := t.TempDir() + writeYAMLInDir(t, dir, "a.yaml", "a") + + var n int64 + stop := startDirWatcher(t, &DirWatcher{ + Path: dir, + Interval: testInterval, + OnChange: func() { atomic.AddInt64(&n, 1) }, + }) + defer stop() + time.Sleep(testInterval * 2) + + // Remove the only YAML file. Dir still exists but is empty of YAML. + require.NoError(t, os.Remove(filepath.Join(dir, "a.yaml"))) + time.Sleep(testInterval * 4) + require.Equal(t, int64(0), atomic.LoadInt64(&n), "emptying the directory must not fire") + + // Add a YAML file back; transition to present-with-content fires. + writeYAMLInDir(t, dir, "c.yaml", "c") + require.True(t, waitForCount(t, &n, 1, time.Second), "callback should fire when content returns") +} + +func TestDirWatcher_ContextCancelStopsRun(t *testing.T) { + dir := t.TempDir() + writeYAMLInDir(t, dir, "a.yaml", "a") + + w := &DirWatcher{Path: dir, Interval: testInterval} + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { w.Run(ctx); close(done) }() + + time.Sleep(testInterval * 2) + cancel() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Run did not return within 2s of cancel") + } +} diff --git a/llama-swap.go b/llama-swap.go index 69a4697c..21ec97e8 100644 --- a/llama-swap.go +++ b/llama-swap.go @@ -55,7 +55,8 @@ var logTimeFormats = map[string]string{ } func main() { - flagConfig := flag.String("config", "", "path to config file (required)") + flagConfig := flag.String("config", "", "path to config file") + flagConfigDir := flag.String("config-dir", "", "directory of *.yml/*.yaml config files (additive to -config)") flagListen := flag.String("listen", "", "listen address (default :8080 or :8443 for TLS)") flagCertFile := flag.String("tls-cert-file", "", "TLS certificate file") flagKeyFile := flag.String("tls-key-file", "", "TLS key file") @@ -68,8 +69,8 @@ func main() { os.Exit(0) } - if *flagConfig == "" { - slog.Error("-config is required") + if *flagConfig == "" && *flagConfigDir == "" { + slog.Error("at least one of -config or -config-dir must be provided") os.Exit(1) } @@ -88,10 +89,9 @@ func main() { } } - configPath := *flagConfig - cfg, err := config.LoadConfig(configPath) + cfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir) if err != nil { - slog.Error("failed to load config", "path", configPath, "error", err) + slog.Error("failed to load config", "config", *flagConfig, "config-dir", *flagConfigDir, "error", err) os.Exit(1) } @@ -187,7 +187,7 @@ func main() { proxyLog.Info("reloading configuration") - newCfg, err := config.LoadConfig(configPath) + newCfg, err := config.LoadConfigSources(*flagConfig, *flagConfigDir) if err != nil { proxyLog.Warnf("failed to reload config: %v", err) return @@ -230,19 +230,37 @@ func main() { defer watcherCancel() if *flagWatchConfig { - absConfigPath, err := filepath.Abs(configPath) - if err != nil { - slog.Error("watch-config: failed to resolve config path", "error", err) - os.Exit(1) - } proxyLog.Info("watching configuration for changes (poll-based, 2s interval)") - go func() { - (&configwatcher.Watcher{ - Path: absConfigPath, - Interval: configwatcher.DefaultInterval, - OnChange: reload, - }).Run(watcherCtx) - }() + + if *flagConfig != "" { + absConfigPath, err := filepath.Abs(*flagConfig) + if err != nil { + slog.Error("watch-config: failed to resolve config path", "error", err) + os.Exit(1) + } + go func() { + (&configwatcher.Watcher{ + Path: absConfigPath, + Interval: configwatcher.DefaultInterval, + OnChange: reload, + }).Run(watcherCtx) + }() + } + + if *flagConfigDir != "" { + absConfigDir, err := filepath.Abs(*flagConfigDir) + if err != nil { + slog.Error("watch-config: failed to resolve config-dir path", "error", err) + os.Exit(1) + } + go func() { + (&configwatcher.DirWatcher{ + Path: absConfigDir, + Interval: configwatcher.DefaultInterval, + OnChange: reload, + }).Run(watcherCtx) + }() + } } sigChan := make(chan os.Signal, 1)