feat: add FOREMAN_KEEP_ALIVE config for worker model residency
Allow configuring how long the worker model stays resident on the Ollama
target after a request via the FOREMAN_KEEP_ALIVE env var. Accepts Ollama
keep_alive values: integer seconds ("-1" forever, "0" unload) or duration strings ("15m", "1h"). Defaults
to "-1" (pin forever). The embedder warm-up is unaffected and always
uses keep_alive=-1.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -17,6 +17,13 @@ FOREMAN_TOKEN=change-me-to-a-secret
|
|||||||
# Always-resident embedding model (pinned in slot 1)
|
# Always-resident embedding model (pinned in slot 1)
|
||||||
FOREMAN_EMBED_MODEL=nomic-embed-text
|
FOREMAN_EMBED_MODEL=nomic-embed-text
|
||||||
|
|
||||||
|
# How long the worker model stays resident on the target after a request.
|
||||||
|
# Accepts Ollama duration strings: "-1" (forever/pin), "0" (unload immediately),
|
||||||
|
# "15m", "1h", "3600" (seconds), etc.
|
||||||
|
# Does NOT affect the embedder, which is always pinned with keep_alive=-1.
|
||||||
|
# Default: -1 (pin forever — best for a dedicated box)
|
||||||
|
FOREMAN_KEEP_ALIVE=-1
|
||||||
|
|
||||||
# === Persistence ===
|
# === Persistence ===
|
||||||
|
|
||||||
# SQLite database path (default: foreman.db)
|
# SQLite database path (default: foreman.db)
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func newTestForeman(t *testing.T, ollamaClient ollama.Client, webhookSecret stri
|
|||||||
|
|
||||||
notifier := worker.NewNotifier()
|
notifier := worker.NewNotifier()
|
||||||
dispatcher := webhook.NewDispatcher(webhookSecret, logger)
|
dispatcher := webhook.NewDispatcher(webhookSecret, logger)
|
||||||
w := worker.New(st, ollamaClient, inv, notifier, dispatcher, logger)
|
w := worker.New(st, ollamaClient, inv, notifier, dispatcher, logger, "-1")
|
||||||
|
|
||||||
cfg := config.Config{
|
cfg := config.Config{
|
||||||
OllamaURL: "http://localhost:11434",
|
OllamaURL: "http://localhost:11434",
|
||||||
@@ -260,7 +260,7 @@ func TestSubmit_AuthToken(t *testing.T) {
|
|||||||
|
|
||||||
notifier := worker.NewNotifier()
|
notifier := worker.NewNotifier()
|
||||||
dispatcher := webhook.NewDispatcher("", logger)
|
dispatcher := webhook.NewDispatcher("", logger)
|
||||||
w := worker.New(st, ollamaStub, inv, notifier, dispatcher, logger)
|
w := worker.New(st, ollamaStub, inv, notifier, dispatcher, logger, "-1")
|
||||||
|
|
||||||
cfg := config.Config{
|
cfg := config.Config{
|
||||||
OllamaURL: "http://localhost:11434",
|
OllamaURL: "http://localhost:11434",
|
||||||
|
|||||||
+2
-1
@@ -79,6 +79,7 @@ func runServe(logger *slog.Logger) error {
|
|||||||
"auth_enabled", cfg.Token != "",
|
"auth_enabled", cfg.Token != "",
|
||||||
"max_attempts", cfg.MaxAttempts,
|
"max_attempts", cfg.MaxAttempts,
|
||||||
"job_ttl", cfg.JobTTL,
|
"job_ttl", cfg.JobTTL,
|
||||||
|
"keep_alive", cfg.KeepAlive,
|
||||||
)
|
)
|
||||||
|
|
||||||
st, err := store.Open(cfg.DBPath)
|
st, err := store.Open(cfg.DBPath)
|
||||||
@@ -107,7 +108,7 @@ func runServe(logger *slog.Logger) error {
|
|||||||
|
|
||||||
// Create the notifier and worker.
|
// Create the notifier and worker.
|
||||||
notifier := worker.NewNotifier()
|
notifier := worker.NewNotifier()
|
||||||
w := worker.New(st, client, inventory, notifier, dispatcher, logger)
|
w := worker.New(st, client, inventory, notifier, dispatcher, logger, cfg.KeepAlive)
|
||||||
|
|
||||||
// Start the worker loop in a goroutine.
|
// Start the worker loop in a goroutine.
|
||||||
go w.Run(ctx)
|
go w.Run(ctx)
|
||||||
|
|||||||
@@ -48,6 +48,13 @@ type Config struct {
|
|||||||
// JobTTL is how long terminal jobs are retained before the pruner deletes them
|
// JobTTL is how long terminal jobs are retained before the pruner deletes them
|
||||||
// (default 24h).
|
// (default 24h).
|
||||||
JobTTL time.Duration
|
JobTTL time.Duration
|
||||||
|
|
||||||
|
// KeepAlive is the keep_alive value sent in outbound /api/chat requests to the
|
||||||
|
// Ollama target. It controls how long the worker model stays resident after a
|
||||||
|
// request. Accepts Ollama duration strings like "15m", "1h", "-1" (forever), or
|
||||||
|
// "0" (unload immediately). Default is "-1" (pin forever). This does NOT affect
|
||||||
|
// the embedder, which is always pinned with keep_alive=-1.
|
||||||
|
KeepAlive string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load reads configuration from environment variables and returns a validated Config.
|
// Load reads configuration from environment variables and returns a validated Config.
|
||||||
@@ -64,6 +71,7 @@ func Load() (Config, error) {
|
|||||||
EmbedModel: os.Getenv("FOREMAN_EMBED_MODEL"),
|
EmbedModel: os.Getenv("FOREMAN_EMBED_MODEL"),
|
||||||
DBPath: envOr("FOREMAN_DB_PATH", "foreman.db"),
|
DBPath: envOr("FOREMAN_DB_PATH", "foreman.db"),
|
||||||
WebhookSecret: os.Getenv("FOREMAN_WEBHOOK_SECRET"),
|
WebhookSecret: os.Getenv("FOREMAN_WEBHOOK_SECRET"),
|
||||||
|
KeepAlive: envOr("FOREMAN_KEEP_ALIVE", "-1"),
|
||||||
}
|
}
|
||||||
|
|
||||||
pollStr := envOr("FOREMAN_POLL_INTERVAL", "30s")
|
pollStr := envOr("FOREMAN_POLL_INTERVAL", "30s")
|
||||||
|
|||||||
@@ -31,6 +31,9 @@ func TestLoad_Defaults(t *testing.T) {
|
|||||||
if cfg.PollInterval != 30*time.Second {
|
if cfg.PollInterval != 30*time.Second {
|
||||||
t.Errorf("PollInterval = %v, want %v", cfg.PollInterval, 30*time.Second)
|
t.Errorf("PollInterval = %v, want %v", cfg.PollInterval, 30*time.Second)
|
||||||
}
|
}
|
||||||
|
if cfg.KeepAlive != "-1" {
|
||||||
|
t.Errorf("KeepAlive = %q, want %q", cfg.KeepAlive, "-1")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoad_AllEnvVars(t *testing.T) {
|
func TestLoad_AllEnvVars(t *testing.T) {
|
||||||
@@ -42,6 +45,7 @@ func TestLoad_AllEnvVars(t *testing.T) {
|
|||||||
t.Setenv("FOREMAN_DB_PATH", "/data/foreman.db")
|
t.Setenv("FOREMAN_DB_PATH", "/data/foreman.db")
|
||||||
t.Setenv("FOREMAN_POLL_INTERVAL", "1m")
|
t.Setenv("FOREMAN_POLL_INTERVAL", "1m")
|
||||||
t.Setenv("FOREMAN_WEBHOOK_SECRET", "hmac-key")
|
t.Setenv("FOREMAN_WEBHOOK_SECRET", "hmac-key")
|
||||||
|
t.Setenv("FOREMAN_KEEP_ALIVE", "15m")
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -72,6 +76,9 @@ func TestLoad_AllEnvVars(t *testing.T) {
|
|||||||
if cfg.WebhookSecret != "hmac-key" {
|
if cfg.WebhookSecret != "hmac-key" {
|
||||||
t.Errorf("WebhookSecret = %q", cfg.WebhookSecret)
|
t.Errorf("WebhookSecret = %q", cfg.WebhookSecret)
|
||||||
}
|
}
|
||||||
|
if cfg.KeepAlive != "15m" {
|
||||||
|
t.Errorf("KeepAlive = %q, want %q", cfg.KeepAlive, "15m")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoad_MissingOllamaURL(t *testing.T) {
|
func TestLoad_MissingOllamaURL(t *testing.T) {
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func newJobTestServer(t *testing.T, client ollama.Client, webhookSecret string)
|
|||||||
|
|
||||||
notifier := worker.NewNotifier()
|
notifier := worker.NewNotifier()
|
||||||
dispatcher := webhook.NewDispatcher(webhookSecret, logger)
|
dispatcher := webhook.NewDispatcher(webhookSecret, logger)
|
||||||
w := worker.New(st, client, inv, notifier, dispatcher, logger)
|
w := worker.New(st, client, inv, notifier, dispatcher, logger, "-1")
|
||||||
|
|
||||||
cfg := config.Config{
|
cfg := config.Config{
|
||||||
OllamaURL: "http://localhost:11434",
|
OllamaURL: "http://localhost:11434",
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func newTestServer(t *testing.T, cfg config.Config, client ollama.Client) (*Serv
|
|||||||
inv := ollama.NewModelInventory(client, logger)
|
inv := ollama.NewModelInventory(client, logger)
|
||||||
notifier := worker.NewNotifier()
|
notifier := worker.NewNotifier()
|
||||||
dispatcher := webhook.NewDispatcher("", logger)
|
dispatcher := webhook.NewDispatcher("", logger)
|
||||||
w := worker.New(st, client, inv, notifier, dispatcher, logger)
|
w := worker.New(st, client, inv, notifier, dispatcher, logger, "-1")
|
||||||
srv := New(cfg, st, client, inv, notifier, w, dispatcher, logger)
|
srv := New(cfg, st, client, inv, notifier, w, dispatcher, logger)
|
||||||
return srv, st
|
return srv, st
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -104,6 +105,11 @@ type Worker struct {
|
|||||||
dispatcher *webhook.Dispatcher
|
dispatcher *webhook.Dispatcher
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
|
|
||||||
|
// keepAlive is the JSON-encoded keep_alive value sent in outbound chat requests
|
||||||
|
// to control how long the worker model stays resident on the target. Derived from
|
||||||
|
// FOREMAN_KEEP_ALIVE config; does not affect the embedder.
|
||||||
|
keepAlive json.RawMessage
|
||||||
|
|
||||||
// wake is signaled when a new job is enqueued.
|
// wake is signaled when a new job is enqueued.
|
||||||
wake chan struct{}
|
wake chan struct{}
|
||||||
}
|
}
|
||||||
@@ -120,6 +126,7 @@ func New(
|
|||||||
notifier *Notifier,
|
notifier *Notifier,
|
||||||
dispatcher *webhook.Dispatcher,
|
dispatcher *webhook.Dispatcher,
|
||||||
logger *slog.Logger,
|
logger *slog.Logger,
|
||||||
|
keepAlive string,
|
||||||
) *Worker {
|
) *Worker {
|
||||||
return &Worker{
|
return &Worker{
|
||||||
store: st,
|
store: st,
|
||||||
@@ -128,10 +135,35 @@ func New(
|
|||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
dispatcher: dispatcher,
|
dispatcher: dispatcher,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
keepAlive: encodeKeepAlive(keepAlive),
|
||||||
wake: make(chan struct{}, 1),
|
wake: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// encodeKeepAlive converts a FOREMAN_KEEP_ALIVE config string to a json.RawMessage
// suitable for the Ollama ChatRequest KeepAlive field.
//
// Why: Ollama's keep_alive field accepts either a JSON number (seconds, or -1 for
// forever) or a JSON string duration ("15m", "1h"). Values that parse as a base-10
// integer are encoded as JSON numbers; everything else is encoded as a JSON string.
// What: returns a json.RawMessage that is always valid JSON. Integers are
// re-rendered in canonical form because strconv.Atoi accepts spellings such as
// "+5" or "007" that are not valid JSON numbers, and encoding/json refuses to
// marshal an invalid json.RawMessage. An empty value falls back to "-1".
// Test: assert "-1" -> `-1`, "0" -> `0`, "15m" -> `"15m"`, "3600" -> `3600`,
// "+5" -> `5`, "007" -> `7`.
func encodeKeepAlive(val string) json.RawMessage {
	if val == "" {
		val = "-1"
	}

	// Integers ("-1", "0", "3600", ...) become JSON numbers. Format the parsed
	// value rather than echoing the input so a leading "+" or leading zeros
	// never leak into the JSON output.
	if n, err := strconv.Atoi(val); err == nil {
		return json.RawMessage(strconv.Itoa(n))
	}

	// Otherwise, emit it as a JSON string (e.g. "15m", "1h"). The error is
	// ignored deliberately: json.Marshal does not fail for a plain string
	// (invalid UTF-8 is replaced rather than rejected).
	b, _ := json.Marshal(val)
	return json.RawMessage(b)
}
|
||||||
|
|
||||||
// Wake signals the worker that a new job may be available. Non-blocking.
|
// Wake signals the worker that a new job may be available. Non-blocking.
|
||||||
//
|
//
|
||||||
// Why: the HTTP handlers signal the worker to check for new work immediately
|
// Why: the HTTP handlers signal the worker to check for new work immediately
|
||||||
@@ -241,6 +273,10 @@ func (w *Worker) executeJob(ctx context.Context, job store.Job) {
|
|||||||
streamFalse := false
|
streamFalse := false
|
||||||
chatReq.Stream = &streamFalse
|
chatReq.Stream = &streamFalse
|
||||||
|
|
||||||
|
// Override keep_alive with the configured value so the worker model stays
|
||||||
|
// resident for the desired duration. The embedder is pinned separately.
|
||||||
|
chatReq.KeepAlive = w.keepAlive
|
||||||
|
|
||||||
// Execute the chat request.
|
// Execute the chat request.
|
||||||
resp, _, err := w.client.Chat(ctx, chatReq, false)
|
resp, _, err := w.client.Chat(ctx, chatReq, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func newTestWorker(t *testing.T, client ollama.Client) (*Worker, *store.Store, *
|
|||||||
inv := ollama.NewModelInventory(client, logger)
|
inv := ollama.NewModelInventory(client, logger)
|
||||||
notifier := NewNotifier()
|
notifier := NewNotifier()
|
||||||
dispatcher := webhook.NewDispatcher("", logger)
|
dispatcher := webhook.NewDispatcher("", logger)
|
||||||
w := New(st, client, inv, notifier, dispatcher, logger)
|
w := New(st, client, inv, notifier, dispatcher, logger, "-1")
|
||||||
return w, st, notifier
|
return w, st, notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -755,6 +755,74 @@ func TestStore_ResetInterruptedJobs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEncodeKeepAlive(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"-1", "-1"},
|
||||||
|
{"0", "0"},
|
||||||
|
{"3600", "3600"},
|
||||||
|
{"15m", `"15m"`},
|
||||||
|
{"1h", `"1h"`},
|
||||||
|
{"", "-1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := string(encodeKeepAlive(tt.input))
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("encodeKeepAlive(%q) = %s, want %s", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorker_SetsKeepAliveOnChatRequest(t *testing.T) {
|
||||||
|
client := &stubOllamaClient{}
|
||||||
|
st := openTestDB(t)
|
||||||
|
logger := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||||
|
inv := ollama.NewModelInventory(client, logger)
|
||||||
|
notifier := NewNotifier()
|
||||||
|
dispatcher := webhook.NewDispatcher("", logger)
|
||||||
|
|
||||||
|
// Use "15m" to verify non-default keep_alive propagates to outbound requests.
|
||||||
|
w := New(st, client, inv, notifier, dispatcher, logger, "15m")
|
||||||
|
|
||||||
|
job := store.Job{
|
||||||
|
ID: "01KEEPALIVE",
|
||||||
|
Model: "qwen3:30b",
|
||||||
|
Payload: json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}`),
|
||||||
|
MaxAttempts: 3,
|
||||||
|
}
|
||||||
|
if _, err := st.CreateJob(job); err != nil {
|
||||||
|
t.Fatalf("CreateJob: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
waitCh := notifier.Register("01KEEPALIVE")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go w.Run(ctx)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-waitCh:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for job to complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the chat request had keep_alive set to "15m".
|
||||||
|
client.mu.Lock()
|
||||||
|
defer client.mu.Unlock()
|
||||||
|
if len(client.chatCalls) != 1 {
|
||||||
|
t.Fatalf("expected 1 chat call, got %d", len(client.chatCalls))
|
||||||
|
}
|
||||||
|
gotKA := string(client.chatCalls[0].KeepAlive)
|
||||||
|
if gotKA != `"15m"` {
|
||||||
|
t.Errorf("keep_alive = %s, want %s", gotKA, `"15m"`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStore_DeleteTerminalJobsBefore(t *testing.T) {
|
func TestStore_DeleteTerminalJobsBefore(t *testing.T) {
|
||||||
st := openTestDB(t)
|
st := openTestDB(t)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user