diff --git a/config-schema.json b/config-schema.json index 58613ced..87cde486 100644 --- a/config-schema.json +++ b/config-schema.json @@ -48,6 +48,12 @@ "default": 120, "description": "Number of seconds to wait for a model to be ready to serve requests." }, + "globalTTL": { + "type": "integer", + "minimum": 0, + "default": 0, + "description": "Default TTL for all models in seconds, 0 means no TTL and models will never be automatically unloaded" + }, "logLevel": { "type": "string", "enum": [ @@ -177,9 +183,9 @@ }, "ttl": { "type": "integer", - "minimum": 0, - "default": 0, - "description": "Automatically unload the model after ttl seconds. 0 disables unloading. Must be >0 to enable." + "minimum": -1, + "default": -1, + "description": "Automatically unload the model after ttl seconds. -1 uses the global TTL value, 0 disables unloading. Must be >0 to enable." }, "useModelName": { "type": "string", @@ -368,4 +374,4 @@ "description": "A dictionary of remote peers and models they provide. Peers can be another llama-swap or any server that provides the /v1/ generative API endpoints supported by llama-swap." } } -} +} \ No newline at end of file diff --git a/config.example.yaml b/config.example.yaml index cc076fa6..35f74c12 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -75,6 +75,11 @@ sendLoadingState: true # all fields except for Id so chat UIs can use the alias equivalent to the original. 
includeAliasesInList: false +# globalTTL: the default TTL in seconds before unloading a model +# - optional, default: 0 (never automatically unload) +# - must be >= 0 +globalTTL: 0 + # macros: a dictionary of string substitutions # - optional, default: empty dictionary # - macros are reusable snippets @@ -180,8 +185,10 @@ models: checkEndpoint: /custom-endpoint # ttl: automatically unload the model after ttl seconds - # - optional, default: 0 - # - ttl values must be a value greater than 0 + # - optional, default: -1 (use global default) + # - ttl values must be -1 or a value greater than or equal to 0 + # - a ttl of -1 will use the global TTL value as the default + # - a ttl of 0 means never unload, even when globalTTL is set # - a value of 0 disables automatic unloading of the model ttl: 60 diff --git a/proxy/config/config.go b/proxy/config/config.go index 4d1e6818..c474c089 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -124,6 +124,7 @@ type Config struct { LogToStdout string `yaml:"logToStdout"` MetricsMaxInMemory int `yaml:"metricsMaxInMemory"` CaptureBuffer int `yaml:"captureBuffer"` + GlobalTTL int `yaml:"globalTTL"` Models map[string]ModelConfig `yaml:"models"` /* key is model ID */ Profiles map[string][]string `yaml:"profiles"` Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */ @@ -203,6 +204,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { LogToStdout: LogToStdoutProxy, MetricsMaxInMemory: 1000, CaptureBuffer: 5, + GlobalTTL: 0, } if err = yaml.Unmarshal([]byte(yamlStr), &config); err != nil { return Config{}, err @@ -216,6 +218,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { return Config{}, fmt.Errorf("startPort must be greater than 1") } + if config.GlobalTTL < 0 { + return Config{}, fmt.Errorf("globalTTL must be >= 0") + } + switch config.LogToStdout { case LogToStdoutProxy, LogToStdoutUpstream, LogToStdoutBoth, LogToStdoutNone: default: @@ -255,6 +261,15 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { 
modelConfig.Cmd = StripComments(modelConfig.Cmd) modelConfig.CmdStop = StripComments(modelConfig.CmdStop) + // set model TTL to globalTTL if it is still the default value + if modelConfig.UnloadAfter == MODEL_CONFIG_DEFAULT_TTL { + modelConfig.UnloadAfter = config.GlobalTTL + } + + if modelConfig.UnloadAfter < 0 { + return Config{}, fmt.Errorf("model %s: invalid TTL value %d", modelId, modelConfig.UnloadAfter) + } + // Validate model macros for _, macro := range modelConfig.Macros { if err = validateMacro(macro.Name, macro.Value); err != nil { diff --git a/proxy/config/config_test.go b/proxy/config/config_test.go index 49bbdc9f..2ea8e460 100644 --- a/proxy/config/config_test.go +++ b/proxy/config/config_test.go @@ -848,6 +848,71 @@ func TestConfig_APIKeys_EnvMacros(t *testing.T) { }) } +func TestConfig_GlobalTTL(t *testing.T) { + t.Run("globalTTL sets default for models", func(t *testing.T) { + content := ` +globalTTL: 300 +models: + model1: + cmd: server --port ${PORT} +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, 300, config.GlobalTTL) + assert.Equal(t, 300, config.Models["model1"].UnloadAfter) + }) + + t.Run("model ttl=0 overrides globalTTL", func(t *testing.T) { + content := ` +globalTTL: 300 +models: + model1: + cmd: server --port ${PORT} + ttl: 0 +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, 0, config.Models["model1"].UnloadAfter) + }) + + t.Run("model explicit ttl overrides globalTTL", func(t *testing.T) { + content := ` +globalTTL: 300 +models: + model1: + cmd: server --port ${PORT} + ttl: 600 +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, 600, config.Models["model1"].UnloadAfter) + }) + + t.Run("globalTTL defaults to 0", func(t *testing.T) { + content := ` +models: + model1: + cmd: server --port ${PORT} +` + config, err := LoadConfigFromReader(strings.NewReader(content)) 
+ assert.NoError(t, err) + assert.Equal(t, 0, config.GlobalTTL) + assert.Equal(t, 0, config.Models["model1"].UnloadAfter) + }) + + t.Run("negative globalTTL rejected", func(t *testing.T) { + content := ` +globalTTL: -1 +models: + model1: + cmd: server --port ${PORT} +` + _, err := LoadConfigFromReader(strings.NewReader(content)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "globalTTL must be >= 0") + }) +} + func TestConfig_EnvMacros(t *testing.T) { t.Run("basic env substitution in cmd", func(t *testing.T) { t.Setenv("TEST_MODEL_PATH", "/opt/models") diff --git a/proxy/config/model_config.go b/proxy/config/model_config.go index 9dc37aea..685687ba 100644 --- a/proxy/config/model_config.go +++ b/proxy/config/model_config.go @@ -5,6 +5,10 @@ import ( "runtime" ) +const ( + MODEL_CONFIG_DEFAULT_TTL = -1 +) + type ModelConfig struct { Cmd string `yaml:"cmd"` CmdStop string `yaml:"cmdStop"` @@ -47,7 +51,7 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { Aliases: []string{}, Env: []string{}, CheckEndpoint: "/health", - UnloadAfter: 0, + UnloadAfter: MODEL_CONFIG_DEFAULT_TTL, // use GlobalTTL Unlisted: false, UseModelName: "", ConcurrencyLimit: 0, diff --git a/proxy/process_test.go b/proxy/process_test.go index 3881c3dd..dd9e9d8a 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -117,12 +117,12 @@ func TestProcess_UnloadAfterTTL(t *testing.T) { } expectedMessage := "I_sense_imminent_danger" - config := getTestSimpleResponderConfig(expectedMessage) - assert.Equal(t, 0, config.UnloadAfter) - config.UnloadAfter = 3 // seconds - assert.Equal(t, 3, config.UnloadAfter) + conf := getTestSimpleResponderConfig(expectedMessage) + assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter) + conf.UnloadAfter = 3 // seconds + assert.Equal(t, 3, conf.UnloadAfter) - process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) + process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger) defer 
process.Stop() // this should take 4 seconds @@ -159,12 +159,12 @@ func TestProcess_LowTTLValue(t *testing.T) { t.Skip("skipping test, edit process_test.go to run it ") } - config := getTestSimpleResponderConfig("fast_ttl") - assert.Equal(t, 0, config.UnloadAfter) - config.UnloadAfter = 1 // second - assert.Equal(t, 1, config.UnloadAfter) + conf := getTestSimpleResponderConfig("fast_ttl") + assert.Equal(t, config.MODEL_CONFIG_DEFAULT_TTL, conf.UnloadAfter) + conf.UnloadAfter = 1 // second + assert.Equal(t, 1, conf.UnloadAfter) - process := NewProcess("ttl", 2, config, debugLogger, debugLogger) + process := NewProcess("ttl", 2, conf, debugLogger, debugLogger) defer process.Stop() for i := 0; i < 100; i++ { diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index a147e5ea..e3a42bfe 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -730,7 +730,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { // Verify extended fields are present assert.NotEmpty(t, response.Running[0].Cmd, "cmd should be populated") assert.NotEmpty(t, response.Running[0].Proxy, "proxy should be populated") - assert.Equal(t, 0, response.Running[0].TTL, "ttl should default to 0") + assert.Equal(t, -1, response.Running[0].TTL, "ttl should default to -1 (use globalTTL)") }) }