proxy: unify filtering for local models and peers

This unifies the filtering capabilities for models and peers

- stripParams: removes params in the request
- setParams: sets params in the request

fixes #453
This commit is contained in:
Benson Wong
2026-01-15 18:59:43 -08:00
committed by GitHub
parent 3edb180c08
commit eb5bfff0b0
10 changed files with 448 additions and 33 deletions
+31 -1
View File
@@ -185,7 +185,7 @@ models:
# filters: a dictionary of filter settings
# - optional, default: empty dictionary
# - only stripParams is currently supported
# - same capabilities as peer filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
@@ -195,6 +195,16 @@ models:
# - recommended to stick to sampling parameters
stripParams: "temperature, top_p, top_k"
# setParams: a dictionary of parameters to set/override in requests
# - optional, default: empty dictionary
# - useful for enforcing specific parameter values
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
setParams:
# Example: enforce specific sampling parameters
temperature: 0.7
top_p: 0.9
# metadata: a dictionary of arbitrary values that are included in /v1/models
# - optional, default: empty dictionary
# - while metadata can contains complex types it is recommended to keep it simple
@@ -373,3 +383,23 @@ peers:
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
# filters: a dictionary of filter settings for peer requests
# - optional, default: empty dictionary
# - same capabilities as model filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for removing parameters that the peer doesn't support
# - the `model` parameter can never be removed
stripParams: "temperature, top_p"
# setParams: a dictionary of parameters to set/override in requests to this peer
# - optional, default: empty dictionary
# - useful for injecting provider-specific settings like data retention policies
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
setParams:
# Example: enforce zero-data-retention for OpenRouter
provider:
data_collection: "deny"
allow_fallbacks: false
+81
View File
@@ -0,0 +1,81 @@
package config
import (
"slices"
"sort"
"strings"
)
// ProtectedParams is a list of parameters that cannot be set or stripped via filters
// These are protected to prevent breaking the proxy's ability to route requests correctly
var ProtectedParams = []string{"model"}
// Filters contains filter settings for modifying request parameters
// Used by both models and peers
type Filters struct {
// StripParams is a comma-separated list of parameters to remove from requests
// The "model" parameter can never be removed
StripParams string `yaml:"stripParams"`
// SetParams is a dictionary of parameters to set/override in requests
// Protected params (like "model") cannot be set
SetParams map[string]any `yaml:"setParams"`
}
// SanitizedStripParams returns a sorted list of parameters to strip,
// with duplicates, empty strings, and protected params removed
func (f Filters) SanitizedStripParams() []string {
if f.StripParams == "" {
return nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
seen := make(map[string]bool)
for _, param := range params {
trimmed := strings.TrimSpace(param)
// Skip protected params, empty strings, and duplicates
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
continue
}
seen[trimmed] = true
cleaned = append(cleaned, trimmed)
}
if len(cleaned) == 0 {
return nil
}
slices.Sort(cleaned)
return cleaned
}
// SanitizedSetParams returns a copy of SetParams with protected params removed
// and keys sorted for consistent iteration order
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
if len(f.SetParams) == 0 {
return nil, nil
}
result := make(map[string]any, len(f.SetParams))
keys := make([]string, 0, len(f.SetParams))
for key, value := range f.SetParams {
// Skip protected params
if slices.Contains(ProtectedParams, key) {
continue
}
result[key] = value
keys = append(keys, key)
}
// Sort keys for consistent ordering
sort.Strings(keys)
if len(result) == 0 {
return nil, nil
}
return result, keys
}
+168
View File
@@ -0,0 +1,168 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFilters_SanitizedStripParams(t *testing.T) {
tests := []struct {
name string
stripParams string
want []string
}{
{
name: "empty string",
stripParams: "",
want: nil,
},
{
name: "single param",
stripParams: "temperature",
want: []string{"temperature"},
},
{
name: "multiple params",
stripParams: "temperature, top_p, top_k",
want: []string{"temperature", "top_k", "top_p"}, // sorted
},
{
name: "model param filtered",
stripParams: "model, temperature, top_p",
want: []string{"temperature", "top_p"},
},
{
name: "only model param",
stripParams: "model",
want: nil,
},
{
name: "duplicates removed",
stripParams: "temperature, top_p, temperature",
want: []string{"temperature", "top_p"},
},
{
name: "extra whitespace",
stripParams: " temperature , top_p ",
want: []string{"temperature", "top_p"},
},
{
name: "empty values filtered",
stripParams: "temperature,,top_p,",
want: []string{"temperature", "top_p"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{StripParams: tt.stripParams}
got := f.SanitizedStripParams()
assert.Equal(t, tt.want, got)
})
}
}
func TestFilters_SanitizedSetParams(t *testing.T) {
tests := []struct {
name string
setParams map[string]any
wantParams map[string]any
wantKeys []string
}{
{
name: "empty setParams",
setParams: nil,
wantParams: nil,
wantKeys: nil,
},
{
name: "empty map",
setParams: map[string]any{},
wantParams: nil,
wantKeys: nil,
},
{
name: "normal params",
setParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantKeys: []string{"temperature", "top_p"},
},
{
name: "protected model param filtered",
setParams: map[string]any{
"model": "should-be-filtered",
"temperature": 0.7,
},
wantParams: map[string]any{
"temperature": 0.7,
},
wantKeys: []string{"temperature"},
},
{
name: "only protected param",
setParams: map[string]any{
"model": "should-be-filtered",
},
wantParams: nil,
wantKeys: nil,
},
{
name: "complex nested values",
setParams: map[string]any{
"provider": map[string]any{
"data_collection": "deny",
"allow_fallbacks": false,
},
"transforms": []string{"middle-out"},
},
wantParams: map[string]any{
"provider": map[string]any{
"data_collection": "deny",
"allow_fallbacks": false,
},
"transforms": []string{"middle-out"},
},
wantKeys: []string{"provider", "transforms"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{SetParams: tt.setParams}
gotParams, gotKeys := f.SanitizedSetParams()
assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
for i, key := range gotKeys {
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
}
if tt.wantParams == nil {
assert.Nil(t, gotParams, "expected nil params")
return
}
assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
for key, wantValue := range tt.wantParams {
gotValue, exists := gotParams[key]
assert.True(t, exists, "missing key: %s", key)
// Simple comparison for basic types
switch v := wantValue.(type) {
case string, int, float64, bool:
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
}
}
})
}
}
func TestProtectedParams(t *testing.T) {
// Verify that "model" is protected
assert.Contains(t, ProtectedParams, "model")
}
+7 -27
View File
@@ -3,8 +3,6 @@ package config
import (
"errors"
"runtime"
"slices"
"strings"
)
type ModelConfig struct {
@@ -74,16 +72,15 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}
// ModelFilters see issue #174
// ModelFilters embeds Filters and adds legacy support for strip_params field
// See issue #174
type ModelFilters struct {
StripParams string `yaml:"stripParams"`
Filters `yaml:",inline"`
}
func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
defaults := rawModelFilters{}
if err := unmarshal(&defaults); err != nil {
return err
@@ -104,25 +101,8 @@ func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
return nil
}
// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
// Returns ([]string, error) to match existing API
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}
params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
seen := make(map[string]bool)
for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" || seen[trimmed] {
continue
}
seen[trimmed] = true
cleaned = append(cleaned, trimmed)
}
// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
return f.Filters.SanitizedStripParams(), nil
}
+32
View File
@@ -72,3 +72,35 @@ models:
assert.True(t, *config.Models["model2"].SendLoadingState)
}
}
func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
stripParams: "top_k"
setParams:
temperature: 0.7
top_p: 0.9
stop:
- "<|end|>"
- "<|stop|>"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
modelConfig := config.Models["model1"]
// Check stripParams
stripParams, err := modelConfig.Filters.SanitizedStripParams()
assert.NoError(t, err)
assert.Equal(t, []string{"top_k"}, stripParams)
// Check setParams
setParams, keys := modelConfig.Filters.SanitizedSetParams()
assert.NotNil(t, setParams)
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
assert.Equal(t, 0.7, setParams["temperature"])
assert.Equal(t, 0.9, setParams["top_p"])
}
+5 -3
View File
@@ -11,14 +11,16 @@ type PeerConfig struct {
ProxyURL *url.URL `yaml:"-"`
ApiKey string `yaml:"apiKey"`
Models []string `yaml:"models"`
Filters Filters `yaml:"filters"`
}
func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawPeerConfig PeerConfig
defaults := rawPeerConfig{
Proxy: "",
ApiKey: "",
Models: []string{},
Proxy: "",
ApiKey: "",
Models: []string{},
Filters: Filters{},
}
if err := unmarshal(&defaults); err != nil {
+70
View File
@@ -137,3 +137,73 @@ func searchSubstring(s, substr string) bool {
}
return false
}
func TestPeerConfig_WithFilters(t *testing.T) {
yamlData := `
proxy: https://openrouter.ai/api
apiKey: sk-test
models:
- model_a
filters:
setParams:
temperature: 0.7
provider:
data_collection: deny
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if config.Filters.SetParams == nil {
t.Fatal("Filters.SetParams should not be nil")
}
if config.Filters.SetParams["temperature"] != 0.7 {
t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"])
}
provider, ok := config.Filters.SetParams["provider"].(map[string]any)
if !ok {
t.Fatal("provider should be a map")
}
if provider["data_collection"] != "deny" {
t.Errorf("expected data_collection deny, got %v", provider["data_collection"])
}
}
func TestPeerConfig_WithBothFilters(t *testing.T) {
yamlData := `
proxy: https://openrouter.ai/api
apiKey: sk-test
models:
- model_a
filters:
stripParams: "temperature, top_p"
setParams:
max_tokens: 1000
`
var config PeerConfig
err := yaml.Unmarshal([]byte(yamlData), &config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check stripParams
stripParams := config.Filters.SanitizedStripParams()
if len(stripParams) != 2 {
t.Errorf("expected 2 strip params, got %d", len(stripParams))
}
if stripParams[0] != "temperature" || stripParams[1] != "top_p" {
t.Errorf("unexpected strip params: %v", stripParams)
}
// Check setParams
if config.Filters.SetParams == nil {
t.Fatal("Filters.SetParams should not be nil")
}
if config.Filters.SetParams["max_tokens"] != 1000 {
t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"])
}
}
+14
View File
@@ -106,6 +106,20 @@ func (p *PeerProxy) HasPeerModel(modelID string) bool {
return found
}
// GetPeerFilters returns the filters for a peer model, or empty filters if not found
func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters {
pp, found := p.proxyMap[modelID]
if !found {
return config.Filters{}
}
// Get the peer config using the peerID
peer, found := p.peers[pp.peerID]
if !found {
return config.Filters{}
}
return peer.Filters
}
func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig {
return p.peers
}
+37 -1
View File
@@ -650,13 +650,49 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) {
}
}
// issue #453 set/override parameters in the JSON body
setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel)
nextHandler = processGroup.ProxyRequest
} else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) {
pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel)
modelID = requestedModel
nextHandler = pm.peerProxy.ProxyRequest
// issue #453 apply filters for peer requests
peerFilters := pm.peerProxy.GetPeerFilters(requestedModel)
// Apply stripParams - remove specified parameters from request
stripParams := peerFilters.SanitizedStripParams()
for _, param := range stripParams {
pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param)
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param))
return
}
}
// Apply setParams - set/override specified parameters in request
setParams, setParamKeys := peerFilters.SanitizedSetParams()
for _, key := range setParamKeys {
pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key)
bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key])
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key))
return
}
}
nextHandler = pm.peerProxy.ProxyRequest
}
if nextHandler == nil {
+3 -1
View File
@@ -966,7 +966,9 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
func TestProxyManager_FiltersStripParams(t *testing.T) {
modelConfig := getTestSimpleResponderConfig("model1")
modelConfig.Filters = config.ModelFilters{
StripParams: "temperature, model, stream",
Filters: config.Filters{
StripParams: "temperature, model, stream",
},
}
config := config.AddDefaultGroupToConfig(config.Config{