proxy,proxy/config: add global TTL feature (#554)
Add a new configuration parameter globalTTL that all models will inherit. The default value is 0 which matches the currently functionality to never automatically unload a model. The model.ttl's default has changed to -1, which means use the global TTL value. Any model.ttl >=0 is now value with 0 meaning never unload. This allows a model to override a globalTTL > 0 and be configured to never unload. Fixes #459 Closes #512
This commit is contained in:
+10
-4
@@ -48,6 +48,12 @@
|
||||
"default": 120,
|
||||
"description": "Number of seconds to wait for a model to be ready to serve requests."
|
||||
},
|
||||
"globalTTL": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Default TTL for all models in seconds, 0 means no TTL and models will never be automatically unloaded"
|
||||
},
|
||||
"logLevel": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -177,9 +183,9 @@
|
||||
},
|
||||
"ttl": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable."
|
||||
"minimum": -1,
|
||||
"default": -1,
|
||||
"description": "Automatically unload the model after ttl seconds. -1 uses the global TTL value, 0 disables unloading. Must be >0 to enable."
|
||||
},
|
||||
"useModelName": {
|
||||
"type": "string",
|
||||
@@ -368,4 +374,4 @@
|
||||
"description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+9
-2
@@ -75,6 +75,11 @@ sendLoadingState: true
|
||||
# all fields except for Id so chat UIs can use the alias equivalent to the original.
|
||||
includeAliasesInList: false
|
||||
|
||||
# globalTTL: the default TTL in seconds before unloading a model
|
||||
# - optional, default: 0 (never automatically unload)
|
||||
# - must be >= 0
|
||||
globalTTL: 0
|
||||
|
||||
# macros: a dictionary of string substitutions
|
||||
# - optional, default: empty dictionary
|
||||
# - macros are reusable snippets
|
||||
@@ -180,8 +185,10 @@ models:
|
||||
checkEndpoint: /custom-endpoint
|
||||
|
||||
# ttl: automatically unload the model after ttl seconds
|
||||
# - optional, default: 0
|
||||
# - ttl values must be a value greater than 0
|
||||
# - optional, default: -1 (use global default)
|
||||
# - ttl values must be a value greater than or equal to 0
|
||||
# - a ttl of -1 will use the global TTL value as the default
|
||||
# - a ttl of 0 will mean never unload
|
||||
# - a value of 0 disables automatic unloading of the model
|
||||
ttl: 60
|
||||
|
||||
|
||||
@@ -124,6 +124,7 @@ type Config struct {
|
||||
LogToStdout string `yaml:"logToStdout"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
CaptureBuffer int `yaml:"captureBuffer"`
|
||||
GlobalTTL int `yaml:"globalTTL"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
@@ -203,6 +204,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
LogToStdout: LogToStdoutProxy,
|
||||
MetricsMaxInMemory: 1000,
|
||||
CaptureBuffer: 5,
|
||||
GlobalTTL: 0,
|
||||
}
|
||||
if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil {
|
||||
return Config{}, err
|
||||
@@ -216,6 +218,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
return Config{}, fmt.Errorf("startPort must be greater than 1")
|
||||
}
|
||||
|
||||
if config.GlobalTTL < 0 {
|
||||
return Config{}, fmt.Errorf("globalTTL must be >= 0")
|
||||
}
|
||||
|
||||
switch config.LogToStdout {
|
||||
case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone:
|
||||
default:
|
||||
@@ -255,6 +261,15 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
modelConfig.Cmd = StripComments(modelConfig.Cmd)
|
||||
modelConfig.CmdStop = StripComments(modelConfig.CmdStop)
|
||||
|
||||
// set model TTL to globalTTL it is the default value
|
||||
if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL {
|
||||
modelConfig.UnloadAfter = config.GlobalTTL
|
||||
}
|
||||
|
||||
if modelConfig.UnloadAfter < 0 {
|
||||
return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter)
|
||||
}
|
||||
|
||||
// Validate model macros
|
||||
for _, macro := range modelConfig.Macros {
|
||||
if err = validateMacro(macro.Name, macro.Value); err != nil {
|
||||
|
||||
@@ -848,6 +848,71 @@ func TestConfig_APIKeys_EnvMacros(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_GlobalTTL(t *testing.T) {
|
||||
t.Run("globalTTL sets default for models", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: 300
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 300, config.GlobalTTL)
|
||||
assert.Equal(t, 300, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("model ttl=0 overrides globalTTL", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: 300
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
ttl: 0
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("model explicit ttl overrides globalTTL", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: 300
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
ttl: 600
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 600, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("globalTTL defaults to 0", func(t *testing.T) {
|
||||
content := `
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
`
|
||||
config, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, config.GlobalTTL)
|
||||
assert.Equal(t, 0, config.Models["model1"].UnloadAfter)
|
||||
})
|
||||
|
||||
t.Run("negative globalTTL rejected", func(t *testing.T) {
|
||||
content := `
|
||||
globalTTL: -1
|
||||
models:
|
||||
model1:
|
||||
cmd: server --port ${PORT}
|
||||
`
|
||||
_, err := LoadConfigFromReader(strings.NewReader(content))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "globalTTL must be >= 0")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_EnvMacros(t *testing.T) {
|
||||
t.Run("basic env substitution in cmd", func(t *testing.T) {
|
||||
t.Setenv("TEST_MODEL_PATH", "/opt/models")
|
||||
|
||||
@@ -5,6 +5,10 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
MODEL_CONFIG_DEFAULT_TTL = -1
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
@@ -47,7 +51,7 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
Aliases: []string{},
|
||||
Env: []string{},
|
||||
CheckEndpoint: "/health",
|
||||
UnloadAfter: 0,
|
||||
UnloadAfter: MODEL_CONFIG_DEFAULT_TTL, // use GlobalTTL
|
||||
Unlisted: false,
|
||||
UseModelName: "",
|
||||
ConcurrencyLimit: 0,
|
||||
|
||||
+10
-10
@@ -117,12 +117,12 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||
}
|
||||
|
||||
expectedMessage := "I_sense_imminent_danger"
|
||||
config := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, config.UnloadAfter)
|
||||
conf := getTestSimpleResponderConfig(expectedMessage)
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 3 // seconds
|
||||
assert.Equal(t, 3, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
// this should take 4 seconds
|
||||
@@ -159,12 +159,12 @@ func TestProcess_LowTTLValue(t *testing.T) {
|
||||
t.Skip("skipping test, edit process_test.go to run it ")
|
||||
}
|
||||
|
||||
config := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, 0, config.UnloadAfter)
|
||||
config.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, config.UnloadAfter)
|
||||
conf := getTestSimpleResponderConfig("fast_ttl")
|
||||
assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter)
|
||||
conf.UnloadAfter = 1 // second
|
||||
assert.Equal(t, 1, conf.UnloadAfter)
|
||||
|
||||
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
|
||||
process := NewProcess("ttl", 2, conf, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
|
||||
@@ -730,7 +730,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Verify extended fields are present
|
||||
assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated")
|
||||
assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated")
|
||||
assert.Equal(t, 0, response.Running[0].TTL, "ttl should default to 0")
|
||||
assert.Equal(t, -1, response.Running[0].TTL, "ttl should default to -1 (use globalTTL)")
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user