From 4e850c2834a38e029321d75a58e4c8b4bcf6f124 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Sun, 18 Jan 2026 21:52:34 -0800 Subject: [PATCH] config: refactor macro substitution in configuration (#470) This commit simplifies substitution of environment variables into the configuration. There was a lot of repetitive code substituting ${env.VAR_NAME} into different fields after the configuration was parsed into a config.Config. This refactor uses a string substitution of env vars into the YAML config before it is fully parsed. This eliminates a lot of logic while maintaining backwards compatibility. --- proxy/config/config.go | 238 +++++++++--------------------------- proxy/config/config_test.go | 68 ++++++++++- 2 files changed, 127 insertions(+), 179 deletions(-) diff --git a/proxy/config/config.go b/proxy/config/config.go index 019519a1..c4387f40 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -184,8 +184,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { if err != nil { return Config{}, err } + yamlStr := string(data) - // default configuration values + // Phase 1: Substitute all ${env.VAR} macros at string level + // This is safe because env values are simple strings without YAML formatting + yamlStr, err = substituteEnvMacros(yamlStr) + if err != nil { + return Config{}, err + } + + // Unmarshal into full Config with defaults config := Config{ HealthCheckTimeout: 120, StartPort: 5800, @@ -194,13 +202,11 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { LogToStdout: LogToStdoutProxy, MetricsMaxInMemory: 1000, } - err = yaml.Unmarshal(data, &config) - if err != nil { + if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil { return Config{}, err } if config.HealthCheckTimeout < 15 { - // set a minimum of 15 seconds config.HealthCheckTimeout = 15 } @@ -225,108 +231,46 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } - /* check macro constraint rules: - - - name must fit the regex ^[a-zA-Z0-9_-]+$ - - names must be less than 64 characters (no reason, just cause) - - name can not be any reserved macros: PORT, MODEL_ID - - macro values must be less than 1024 characters - */ + // Validate global macros for _, macro := range config.Macros { if err = validateMacro(macro.Name, macro.Value); err != nil { return Config{}, err } } - // Process environment variable macros in global macro values first - for i, macro := range config.Macros { - if strVal, ok := macro.Value.(string); ok { - newVal, err := substituteEnvMacros(strVal) - if err != nil { - return Config{}, fmt.Errorf("global macro '%s': %w", macro.Name, err) - } - config.Macros[i].Value = newVal - } - } - - // Get and sort all model IDs first, makes testing more consistent + // Get and sort all model IDs for consistent port assignment modelIds := make([]string, 0, len(config.Models)) for modelId := range config.Models { modelIds = append(modelIds, modelId) } - sort.Strings(modelIds) // This guarantees stable iteration order + sort.Strings(modelIds) nextPort := config.StartPort for _, modelId := range modelIds { modelConfig := config.Models[modelId] - // Strip comments from command fields before macro expansion + // Strip comments from command fields modelConfig.Cmd = StripComments(modelConfig.Cmd) modelConfig.CmdStop = StripComments(modelConfig.CmdStop) - // Substitute environment variable macros in model fields - modelConfig.Cmd, err = substituteEnvMacros(modelConfig.Cmd) - if err != nil { - return Config{}, fmt.Errorf("model %s cmd: %w", modelId, err) - } - modelConfig.CmdStop, err = substituteEnvMacros(modelConfig.CmdStop) - if err != nil { - return Config{}, fmt.Errorf("model %s cmdStop: %w", modelId, err) - } - modelConfig.Proxy, err = substituteEnvMacros(modelConfig.Proxy) - if err != nil { - return Config{}, fmt.Errorf("model %s proxy: %w", modelId, err) - } - modelConfig.CheckEndpoint, err = substituteEnvMacros(modelConfig.CheckEndpoint) - if err != nil { - return Config{}, fmt.Errorf("model %s checkEndpoint: %w", modelId, err) - } - modelConfig.Filters.StripParams, err = substituteEnvMacros(modelConfig.Filters.StripParams) - if err != nil { - return Config{}, fmt.Errorf("model %s filters.stripParams: %w", modelId, err) - } - - // Substitute env macros in model-level macro values - for i, macro := range modelConfig.Macros { - if strVal, ok := macro.Value.(string); ok { - newVal, err := substituteEnvMacros(strVal) - if err != nil { - return Config{}, fmt.Errorf("model %s macro '%s': %w", modelId, macro.Name, err) - } - modelConfig.Macros[i].Value = newVal - } - } - - // Substitute env macros in metadata - if len(modelConfig.Metadata) > 0 { - result, err := substituteEnvMacrosInValue(modelConfig.Metadata) - if err != nil { - return Config{}, fmt.Errorf("model %s metadata: %w", modelId, err) - } - modelConfig.Metadata = result.(map[string]any) - } - - // validate model macros + // Validate model macros for _, macro := range modelConfig.Macros { if err = validateMacro(macro.Name, macro.Value); err != nil { return Config{}, fmt.Errorf("model %s: %s", modelId, err.Error()) } } - // Merge global config and model macros. Model macros take precedence - mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)) + // Build merged macro list: MODEL_ID + global macros + model macros (model overrides global) + mergedMacros := make(MacroList, 0, len(config.Macros)+len(modelConfig.Macros)+1) mergedMacros = append(mergedMacros, MacroEntry{Name: "MODEL_ID", Value: modelId}) - - // Add global macros first mergedMacros = append(mergedMacros, config.Macros...) - // Add model macros (can override global) + // Add model macros (override globals with same name) for _, entry := range modelConfig.Macros { - // Remove any existing global macro with same name found := false for i, existing := range mergedMacros { if existing.Name == entry.Name { - mergedMacros[i] = entry // Override + mergedMacros[i] = entry found = true break } @@ -336,23 +280,20 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } - // First pass: Substitute user-defined macros in reverse order (LIFO - last defined first) - // This allows later macros to reference earlier ones + // Substitute remaining macros in model fields (LIFO order) for i := len(mergedMacros) - 1; i >= 0; i-- { entry := mergedMacros[i] macroSlug := fmt.Sprintf("${%s}", entry.Name) macroStr := fmt.Sprintf("%v", entry.Value) - // Substitute in command fields modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr) modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr) modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr) modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr) modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr) - // Substitute in metadata (recursive) + // Substitute in metadata (type-preserving) if len(modelConfig.Metadata) > 0 { - var err error result, err := substituteMacroInValue(modelConfig.Metadata, entry.Name, entry.Value) if err != nil { return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error()) @@ -361,18 +302,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } - // Final pass: check if PORT macro is needed after macro expansion - // ${PORT} is a resource on the local machine so a new port is only allocated - // if it is required in either cmd or proxy keys + // Handle PORT macro - only allocate if cmd uses it cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}") proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}") - if cmdHasPort || proxyHasPort { // either has it - if !cmdHasPort && proxyHasPort { // but both don't have it + if cmdHasPort || proxyHasPort { + if !cmdHasPort && proxyHasPort { return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId) } - // Add PORT macro and substitute it - portEntry := MacroEntry{Name: "PORT", Value: nextPort} macroSlug := "${PORT}" macroStr := fmt.Sprintf("%v", nextPort) @@ -380,10 +317,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr) modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr) - // Substitute PORT in metadata if len(modelConfig.Metadata) > 0 { - var err error - result, err := substituteMacroInValue(modelConfig.Metadata, portEntry.Name, portEntry.Value) + result, err := substituteMacroInValue(modelConfig.Metadata, "PORT", nextPort) if err != nil { return Config{}, fmt.Errorf("model %s metadata: %s", modelId, err.Error()) } @@ -393,7 +328,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { nextPort++ } - // make sure there are no unknown macros that have not been replaced + // Validate no unknown macros remain fieldMap := map[string]string{ "cmd": modelConfig.Cmd, "cmdStop": modelConfig.CmdStop, @@ -407,42 +342,27 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { for _, match := range matches { macroName := match[1] if macroName == "PID" && fieldName == "cmdStop" { - continue // this is ok, has to be replaced by process later + continue // replaced at runtime } - // Reserved macros are always valid (they should have been substituted already) if macroName == "PORT" || macroName == "MODEL_ID" { return Config{}, fmt.Errorf("macro '${%s}' should have been substituted in %s.%s", macroName, modelId, fieldName) } - // Any other macro is unknown return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName) } - - // Check for unsubstituted env macros - envMatches := envMacroRegex.FindAllStringSubmatch(fieldValue, -1) - for _, match := range envMatches { - varName := match[1] - return Config{}, fmt.Errorf("environment variable '%s' not set (found in %s.%s)", varName, modelId, fieldName) - } } - // Check for unknown macros in metadata if len(modelConfig.Metadata) > 0 { if err := validateNestedForUnknownMacros(modelConfig.Metadata, fmt.Sprintf("model %s metadata", modelId)); err != nil { return Config{}, err } } - // Validate the proxy URL. if _, err := url.Parse(modelConfig.Proxy); err != nil { - return Config{}, fmt.Errorf( - "model %s: invalid proxy URL: %w", modelId, err, - ) + return Config{}, fmt.Errorf("model %s: invalid proxy URL: %w", modelId, err) } - // if sendLoadingState is nil, set it to the global config value - // see #366 if modelConfig.SendLoadingState == nil { - v := config.SendLoadingState // copy it + v := config.SendLoadingState modelConfig.SendLoadingState = &v } @@ -450,18 +370,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } config = AddDefaultGroupToConfig(config) - // check that members are all unique in the groups - memberUsage := make(map[string]string) // maps member to group it appears in + + // Validate group members + memberUsage := make(map[string]string) for groupID, groupConfig := range config.Groups { prevSet := make(map[string]bool) for _, member := range groupConfig.Members { - // Check for duplicates within this group if _, found := prevSet[member]; found { return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID) } prevSet[member] = true - // Check if member is used in another group if existingGroup, exists := memberUsage[member]; exists { return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID) } @@ -469,7 +388,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } - // clean up hooks preload + // Clean up hooks preload if len(config.Hooks.OnStartup.Preload) > 0 { var toPreload []string for _, modelID := range config.Hooks.OnStartup.Preload { @@ -481,30 +400,23 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { toPreload = append(toPreload, real) } } - config.Hooks.OnStartup.Preload = toPreload } - // check api keys validity and substitute env macros + // Validate API keys (env macros already substituted at string level) for i, apikey := range config.RequiredAPIKeys { - apikey, err = substituteEnvMacros(apikey) - if err != nil { - return Config{}, fmt.Errorf("apiKeys[%d]: %w", i, err) - } - config.RequiredAPIKeys[i] = apikey - if apikey == "" { return Config{}, fmt.Errorf("empty api key found in apiKeys") } - if strings.Contains(apikey, " ") { return Config{}, fmt.Errorf("api key cannot contain spaces: `%s`", apikey) } + config.RequiredAPIKeys[i] = apikey } - // substitute macros and env macros in peer fields + // Process peers with global macro substitution for peerName, peerConfig := range config.Peers { - // Substitute global macros first (LIFO order like models) + // Substitute global macros (LIFO order) for i := len(config.Macros) - 1; i >= 0; i-- { entry := config.Macros[i] macroSlug := fmt.Sprintf("${%s}", entry.Name) @@ -513,7 +425,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { peerConfig.ApiKey = strings.ReplaceAll(peerConfig.ApiKey, macroSlug, macroStr) peerConfig.Filters.StripParams = strings.ReplaceAll(peerConfig.Filters.StripParams, macroSlug, macroStr) - // Substitute in setParams + // Substitute in setParams (type-preserving) if len(peerConfig.Filters.SetParams) > 0 { result, err := substituteMacroInValue(peerConfig.Filters.SetParams, entry.Name, entry.Value) if err != nil { @@ -523,25 +435,6 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } - // Substitute env macros - peerConfig.ApiKey, err = substituteEnvMacros(peerConfig.ApiKey) - if err != nil { - return Config{}, fmt.Errorf("peers.%s.apiKey: %w", peerName, err) - } - - peerConfig.Filters.StripParams, err = substituteEnvMacros(peerConfig.Filters.StripParams) - if err != nil { - return Config{}, fmt.Errorf("peers.%s.filters.stripParams: %w", peerName, err) - } - - if len(peerConfig.Filters.SetParams) > 0 { - result, err := substituteEnvMacrosInValue(peerConfig.Filters.SetParams) - if err != nil { - return Config{}, fmt.Errorf("peers.%s.filters.setParams: %w", peerName, err) - } - peerConfig.Filters.SetParams = result.(map[string]any) - } - // Validate no unknown macros remain if matches := macroPatternRegex.FindAllStringSubmatch(peerConfig.ApiKey, -1); len(matches) > 0 { return Config{}, fmt.Errorf("peers.%s.apiKey: unknown macro '${%s}'", peerName, matches[0][1]) @@ -554,7 +447,6 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { return Config{}, err } } - config.Peers[peerName] = peerConfig } @@ -776,7 +668,7 @@ func substituteMacroInValue(value any, macroName string, macroValue any) (any, e } // substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values -// Returns error if any env var is not set +// Returns error if any env var is not set or contains invalid characters func substituteEnvMacros(s string) (string, error) { result := s matches := envMacroRegex.FindAllStringSubmatch(s, -1) @@ -788,40 +680,32 @@ func substituteEnvMacros(s string) (string, error) { if !exists { return "", fmt.Errorf("environment variable '%s' is not set", varName) } + + // Sanitize the value for safe YAML substitution + value, err := sanitizeEnvValueForYAML(value, varName) + if err != nil { + return "", err + } + result = strings.ReplaceAll(result, fullMatch, value) } return result, nil } -// substituteEnvMacrosInValue recursively substitutes env macros in nested structures -func substituteEnvMacrosInValue(value any) (any, error) { - switch v := value.(type) { - case string: - return substituteEnvMacros(v) - - case map[string]any: - newMap := make(map[string]any) - for key, val := range v { - newVal, err := substituteEnvMacrosInValue(val) - if err != nil { - return nil, err - } - newMap[key] = newVal - } - return newMap, nil - - case []any: - newSlice := make([]any, len(v)) - for i, val := range v { - newVal, err := substituteEnvMacrosInValue(val) - if err != nil { - return nil, err - } - newSlice[i] = newVal - } - return newSlice, nil - - default: - return value, nil +// sanitizeEnvValueForYAML ensures an environment variable value is safe for YAML substitution. +// It rejects values with characters that break YAML structure and escapes quotes/backslashes +// for compatibility with double-quoted YAML strings. +func sanitizeEnvValueForYAML(value, varName string) (string, error) { + // Reject values that would break YAML structure regardless of quoting context + if strings.ContainsAny(value, "\n\r\x00") { + return "", fmt.Errorf("environment variable '%s' contains newlines or null bytes which are not allowed in YAML substitution", varName) } + + // Escape backslashes and double quotes for safe use in double-quoted YAML strings. + // In unquoted contexts, these escapes appear literally (harmless for most use cases). + // In double-quoted contexts, they are interpreted correctly. + value = strings.ReplaceAll(value, `\`, `\\`) + value = strings.ReplaceAll(value, `"`, `\"`) + + return value, nil } diff --git a/proxy/config/config_test.go b/proxy/config/config_test.go index c77b3a78..a19cbb56 100644 --- a/proxy/config/config_test.go +++ b/proxy/config/config_test.go @@ -834,7 +834,7 @@ func TestConfig_APIKeys_EnvMacros(t *testing.T) { content := `apiKeys: ["${env.NONEXISTENT_API_KEY}"]` _, err := LoadConfigFromReader(strings.NewReader(content)) assert.Error(t, err) - assert.Contains(t, err.Error(), "apiKeys[0]") + // With string-level env substitution, error only includes var name assert.Contains(t, err.Error(), "NONEXISTENT_API_KEY") }) @@ -1056,6 +1056,70 @@ models: assert.NoError(t, err) assert.Equal(t, "server --auth admin:secret", config.Models["test"].Cmd) }) + + t.Run("env value with newline is rejected", func(t *testing.T) { + t.Setenv("TEST_MULTILINE", "line1\nline2") + + content := ` +models: + test: + cmd: "server --config ${env.TEST_MULTILINE}" + proxy: "http://localhost:8080" +` + _, err := LoadConfigFromReader(strings.NewReader(content)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "TEST_MULTILINE") + assert.Contains(t, err.Error(), "newlines") + } + }) + + t.Run("env value with carriage return is rejected", func(t *testing.T) { + t.Setenv("TEST_CR", "line1\rline2") + + content := ` +models: + test: + cmd: "server --config ${env.TEST_CR}" + proxy: "http://localhost:8080" +` + _, err := LoadConfigFromReader(strings.NewReader(content)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "TEST_CR") + assert.Contains(t, err.Error(), "newlines") + } + }) + + t.Run("env value with quotes is escaped for YAML", func(t *testing.T) { + t.Setenv("TEST_QUOTED", `value with "quotes"`) + + content := ` +models: + test: + cmd: "server --arg \"${env.TEST_QUOTED}\"" + proxy: "http://localhost:8080" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + // Quotes are escaped before YAML parsing, then YAML unescapes them + // Final result preserves the original value with quotes + assert.Contains(t, config.Models["test"].Cmd, `"quotes"`) + }) + + t.Run("env value with backslash is escaped for YAML", func(t *testing.T) { + t.Setenv("TEST_BACKSLASH", `path\to\file`) + + content := ` +models: + test: + cmd: "server --path \"${env.TEST_BACKSLASH}\"" + proxy: "http://localhost:8080" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + // Backslashes are escaped before YAML parsing, then YAML unescapes them + // Final result preserves the original value with backslashes + assert.Contains(t, config.Models["test"].Cmd, `path\to\file`) + }) } func TestConfig_PeerApiKey_EnvMacros(t *testing.T) { @@ -1086,7 +1150,7 @@ peers: ` _, err := LoadConfigFromReader(strings.NewReader(content)) assert.Error(t, err) - assert.Contains(t, err.Error(), "peers.openrouter.apiKey") + // With string-level env substitution, error only includes var name assert.Contains(t, err.Error(), "NONEXISTENT_PEER_KEY") })