diff --git a/.env.example b/.env.example index c2016b9..4801b72 100644 --- a/.env.example +++ b/.env.example @@ -17,6 +17,13 @@ FOREMAN_TOKEN=change-me-to-a-secret # Always-resident embedding model (pinned in slot 1) FOREMAN_EMBED_MODEL=nomic-embed-text +# How long the worker model stays resident on the target after a request. +# Accepts Ollama duration strings: "-1" (forever/pin), "0" (unload immediately), +# "15m", "1h", "3600" (seconds), etc. +# Does NOT affect the embedder, which is always pinned with keep_alive=-1. +# Default: -1 (pin forever — best for a dedicated box) +FOREMAN_KEEP_ALIVE=-1 + # === Persistence === # SQLite database path (default: foreman.db) diff --git a/client/client_test.go b/client/client_test.go index e2db6f2..90da0c3 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -48,7 +48,7 @@ func newTestForeman(t *testing.T, ollamaClient ollama.Client, webhookSecret stri notifier := worker.NewNotifier() dispatcher := webhook.NewDispatcher(webhookSecret, logger) - w := worker.New(st, ollamaClient, inv, notifier, dispatcher, logger) + w := worker.New(st, ollamaClient, inv, notifier, dispatcher, logger, "-1") cfg := config.Config{ OllamaURL: "http://localhost:11434", @@ -260,7 +260,7 @@ func TestSubmit_AuthToken(t *testing.T) { notifier := worker.NewNotifier() dispatcher := webhook.NewDispatcher("", logger) - w := worker.New(st, ollamaStub, inv, notifier, dispatcher, logger) + w := worker.New(st, ollamaStub, inv, notifier, dispatcher, logger, "-1") cfg := config.Config{ OllamaURL: "http://localhost:11434", diff --git a/cmd/foreman/main.go b/cmd/foreman/main.go index 93f8ce2..330a97f 100644 --- a/cmd/foreman/main.go +++ b/cmd/foreman/main.go @@ -79,6 +79,7 @@ func runServe(logger *slog.Logger) error { "auth_enabled", cfg.Token != "", "max_attempts", cfg.MaxAttempts, "job_ttl", cfg.JobTTL, + "keep_alive", cfg.KeepAlive, ) st, err := store.Open(cfg.DBPath) @@ -107,7 +108,7 @@ func runServe(logger *slog.Logger) error { // Create the notifier and worker. notifier := worker.NewNotifier() - w := worker.New(st, client, inventory, notifier, dispatcher, logger) + w := worker.New(st, client, inventory, notifier, dispatcher, logger, cfg.KeepAlive) // Start the worker loop in a goroutine. go w.Run(ctx) diff --git a/internal/config/config.go b/internal/config/config.go index bf82947..4a99f36 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,6 +48,13 @@ type Config struct { // JobTTL is how long terminal jobs are retained before the pruner deletes them // (default 24h). JobTTL time.Duration + + // KeepAlive is the keep_alive value sent in outbound /api/chat requests to the + // Ollama target. It controls how long the worker model stays resident after a + // request. Accepts Ollama duration strings like "15m", "1h", "-1" (forever), or + // "0" (unload immediately). Default is "-1" (pin forever). This does NOT affect + // the embedder, which is always pinned with keep_alive=-1. + KeepAlive string } // Load reads configuration from environment variables and returns a validated Config. @@ -64,6 +71,7 @@ func Load() (Config, error) { EmbedModel: os.Getenv("FOREMAN_EMBED_MODEL"), DBPath: envOr("FOREMAN_DB_PATH", "foreman.db"), WebhookSecret: os.Getenv("FOREMAN_WEBHOOK_SECRET"), + KeepAlive: envOr("FOREMAN_KEEP_ALIVE", "-1"), } pollStr := envOr("FOREMAN_POLL_INTERVAL", "30s") diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 694ee32..8facf17 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -31,6 +31,9 @@ func TestLoad_Defaults(t *testing.T) { if cfg.PollInterval != 30*time.Second { t.Errorf("PollInterval = %v, want %v", cfg.PollInterval, 30*time.Second) } + if cfg.KeepAlive != "-1" { + t.Errorf("KeepAlive = %q, want %q", cfg.KeepAlive, "-1") + } } func TestLoad_AllEnvVars(t *testing.T) { @@ -42,6 +45,7 @@ func TestLoad_AllEnvVars(t *testing.T) { t.Setenv("FOREMAN_DB_PATH", "/data/foreman.db") t.Setenv("FOREMAN_POLL_INTERVAL", "1m") t.Setenv("FOREMAN_WEBHOOK_SECRET", "hmac-key") + t.Setenv("FOREMAN_KEEP_ALIVE", "15m") cfg, err := Load() if err != nil { @@ -72,6 +76,9 @@ func TestLoad_AllEnvVars(t *testing.T) { if cfg.WebhookSecret != "hmac-key" { t.Errorf("WebhookSecret = %q", cfg.WebhookSecret) } + if cfg.KeepAlive != "15m" { + t.Errorf("KeepAlive = %q, want %q", cfg.KeepAlive, "15m") + } } func TestLoad_MissingOllamaURL(t *testing.T) { diff --git a/internal/server/jobs_test.go b/internal/server/jobs_test.go index 97c5d30..ccfbe24 100644 --- a/internal/server/jobs_test.go +++ b/internal/server/jobs_test.go @@ -44,7 +44,7 @@ func newJobTestServer(t *testing.T, client ollama.Client, webhookSecret string) notifier := worker.NewNotifier() dispatcher := webhook.NewDispatcher(webhookSecret, logger) - w := worker.New(st, client, inv, notifier, dispatcher, logger) + w := worker.New(st, client, inv, notifier, dispatcher, logger, "-1") cfg := config.Config{ OllamaURL: "http://localhost:11434", diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 80bb0a4..8815bb9 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -38,7 +38,7 @@ func newTestServer(t *testing.T, cfg config.Config, client ollama.Client) (*Serv inv := ollama.NewModelInventory(client, logger) notifier := worker.NewNotifier() dispatcher := webhook.NewDispatcher("", logger) - w := worker.New(st, client, inv, notifier, dispatcher, logger) + w := worker.New(st, client, inv, notifier, dispatcher, logger, "-1") srv := New(cfg, st, client, inv, notifier, w, dispatcher, logger) return srv, st } diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 985d13f..83081ea 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -16,6 +16,7 @@ import ( "errors" "fmt" "log/slog" + "strconv" "sync" "time" @@ -104,6 +105,11 @@ type Worker struct { dispatcher *webhook.Dispatcher logger *slog.Logger + // keepAlive is the JSON-encoded keep_alive value sent in outbound chat requests + // to control how long the worker model stays resident on the target. Derived from + // FOREMAN_KEEP_ALIVE config; does not affect the embedder. + keepAlive json.RawMessage + // wake is signaled when a new job is enqueued. wake chan struct{} } @@ -120,6 +126,7 @@ func New( notifier *Notifier, dispatcher *webhook.Dispatcher, logger *slog.Logger, + keepAlive string, ) *Worker { return &Worker{ store: st, @@ -128,10 +135,35 @@ func New( notifier: notifier, dispatcher: dispatcher, logger: logger, + keepAlive: encodeKeepAlive(keepAlive), wake: make(chan struct{}, 1), } } +// encodeKeepAlive converts a FOREMAN_KEEP_ALIVE config string to a json.RawMessage +// suitable for the Ollama ChatRequest KeepAlive field. +// +// Why: Ollama's keep_alive field accepts either a JSON number (seconds, or -1 for +// forever) or a JSON string duration ("15m", "1h"). Pure-numeric values and "-1" +// are encoded as JSON numbers; everything else is encoded as a JSON string. +// What: returns a json.RawMessage containing the appropriate JSON representation. +// Test: assert "-1" -> `-1`, "0" -> `0`, "15m" -> `"15m"`, "3600" -> `3600`. +func encodeKeepAlive(val string) json.RawMessage { + if val == "" { + val = "-1" + } + + // If the value parses as an integer, emit it as a JSON number. + // This covers "-1", "0", "3600", etc. + if _, err := strconv.Atoi(val); err == nil { + return json.RawMessage(val) + } + + // Otherwise, emit it as a JSON string (e.g. "15m", "1h"). + b, _ := json.Marshal(val) + return json.RawMessage(b) +} + // Wake signals the worker that a new job may be available. Non-blocking. // // Why: the HTTP handlers signal the worker to check for new work immediately @@ -241,6 +273,10 @@ func (w *Worker) executeJob(ctx context.Context, job store.Job) { streamFalse := false chatReq.Stream = &streamFalse + // Override keep_alive with the configured value so the worker model stays + // resident for the desired duration. The embedder is pinned separately. + chatReq.KeepAlive = w.keepAlive + // Execute the chat request. resp, _, err := w.client.Chat(ctx, chatReq, false) if err != nil { diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go index 50b6c57..7b48070 100644 --- a/internal/worker/worker_test.go +++ b/internal/worker/worker_test.go @@ -40,7 +40,7 @@ func newTestWorker(t *testing.T, client ollama.Client) (*Worker, *store.Store, * inv := ollama.NewModelInventory(client, logger) notifier := NewNotifier() dispatcher := webhook.NewDispatcher("", logger) - w := New(st, client, inv, notifier, dispatcher, logger) + w := New(st, client, inv, notifier, dispatcher, logger, "-1") return w, st, notifier } @@ -755,6 +755,74 @@ func TestStore_ResetInterruptedJobs(t *testing.T) { } } +func TestEncodeKeepAlive(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"-1", "-1"}, + {"0", "0"}, + {"3600", "3600"}, + {"15m", `"15m"`}, + {"1h", `"1h"`}, + {"", "-1"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := string(encodeKeepAlive(tt.input)) + if got != tt.want { + t.Errorf("encodeKeepAlive(%q) = %s, want %s", tt.input, got, tt.want) + } + }) + } +} + +func TestWorker_SetsKeepAliveOnChatRequest(t *testing.T) { + client := &stubOllamaClient{} + st := openTestDB(t) + logger := slog.New(slog.NewJSONHandler(io.Discard, nil)) + inv := ollama.NewModelInventory(client, logger) + notifier := NewNotifier() + dispatcher := webhook.NewDispatcher("", logger) + + // Use "15m" to verify non-default keep_alive propagates to outbound requests. + w := New(st, client, inv, notifier, dispatcher, logger, "15m") + + job := store.Job{ + ID: "01KEEPALIVE", + Model: "qwen3:30b", + Payload: json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}`), + MaxAttempts: 3, + } + if _, err := st.CreateJob(job); err != nil { + t.Fatalf("CreateJob: %v", err) + } + + waitCh := notifier.Register("01KEEPALIVE") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Run(ctx) + + select { + case <-waitCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for job to complete") + } + + // Verify the chat request had keep_alive set to "15m". + client.mu.Lock() + defer client.mu.Unlock() + if len(client.chatCalls) != 1 { + t.Fatalf("expected 1 chat call, got %d", len(client.chatCalls)) + } + gotKA := string(client.chatCalls[0].KeepAlive) + if gotKA != `"15m"` { + t.Errorf("keep_alive = %s, want %s", gotKA, `"15m"`) + } +} + func TestStore_DeleteTerminalJobsBefore(t *testing.T) { st := openTestDB(t)