From 6fd050855ab6b66317adb2b1627ac59425546fb4 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Sat, 23 May 2026 18:29:32 -0400 Subject: [PATCH] feat: add durable queue, single worker, and drain-by-model scheduling Replace the Phase 2 in-flight chat gate (buffered channel) with a real SQLite-backed job queue and single worker loop. Every /api/chat request now creates a job row, blocks until the worker completes it, and returns the result transparently. Key changes: - internal/store: NextJob (drain-by-model ordering), IncrementAttempt, ResetInterruptedJobs, DeleteTerminalJobsBefore; busy_timeout pragma - internal/worker: single-threaded worker loop with Notifier for sync handler completion signaling; retry on ConnectionError, terminal fail on HTTPError; crash recovery resets interrupted jobs on startup - internal/webhook: dispatcher infrastructure for async webhook delivery - internal/server: chat handler rewritten to enqueue+wait; old chatGate removed; embeddings remain direct concurrent proxies (ADR-0013) - internal/config: FOREMAN_MAX_ATTEMPTS, FOREMAN_JOB_TTL Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/foreman/main.go | 65 ++- go.mod | 5 +- go.sum | 3 + internal/config/config.go | 23 + internal/server/server.go | 187 ++++---- internal/server/server_test.go | 158 +++---- internal/store/store.go | 139 +++++- internal/webhook/dispatcher.go | 190 ++++++++ internal/worker/worker.go | 385 ++++++++++++++++ internal/worker/worker_test.go | 807 +++++++++++++++++++++++++++++++++ progress.md | 51 +++ 11 files changed, 1830 insertions(+), 183 deletions(-) create mode 100644 internal/webhook/dispatcher.go create mode 100644 internal/worker/worker.go create mode 100644 internal/worker/worker_test.go diff --git a/cmd/foreman/main.go b/cmd/foreman/main.go index 4c04a3c..93f8ce2 100644 --- a/cmd/foreman/main.go +++ b/cmd/foreman/main.go @@ -14,11 +14,14 @@ import ( "os" "os/signal" "syscall" + "time" "gitea.stevedudenhoeffer.com/steve/foreman/internal/config" 
"gitea.stevedudenhoeffer.com/steve/foreman/internal/ollama" "gitea.stevedudenhoeffer.com/steve/foreman/internal/server" "gitea.stevedudenhoeffer.com/steve/foreman/internal/store" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/webhook" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/worker" ) func main() { @@ -53,11 +56,13 @@ func main() { } // runServe loads configuration, opens the store, creates the Ollama client, -// starts the model poller, warms the embedder, and starts the HTTP server. +// starts the model poller, warms the embedder, creates the worker, webhook +// dispatcher, and starts the HTTP server. // // Why: the serve subcommand is the daemon's primary mode of operation. -// What: wires config -> store -> ollama client -> poller -> server and blocks on -// ListenAndServe. Graceful shutdown on SIGINT/SIGTERM cancels the poller. +// What: wires config -> store -> ollama client -> poller -> worker -> server and +// blocks on ListenAndServe. Graceful shutdown on SIGINT/SIGTERM stops the worker, +// poller, and pruner. // Test: tested indirectly via integration tests; each component is unit tested. func runServe(logger *slog.Logger) error { cfg, err := config.Load() @@ -72,6 +77,8 @@ func runServe(logger *slog.Logger) error { "poll_interval", cfg.PollInterval, "embed_model", cfg.EmbedModel, "auth_enabled", cfg.Token != "", + "max_attempts", cfg.MaxAttempts, + "job_ttl", cfg.JobTTL, ) st, err := store.Open(cfg.DBPath) @@ -95,7 +102,20 @@ func runServe(logger *slog.Logger) error { warmEmbedder(ctx, client, cfg.EmbedModel, logger) } - srv := server.New(cfg, st, client, inventory, logger) + // Create the webhook dispatcher. + dispatcher := webhook.NewDispatcher(cfg.WebhookSecret, logger) + + // Create the notifier and worker. + notifier := worker.NewNotifier() + w := worker.New(st, client, inventory, notifier, dispatcher, logger) + + // Start the worker loop in a goroutine. + go w.Run(ctx) + + // Start the TTL pruner in a goroutine. 
+ go runPruner(ctx, st, cfg.JobTTL, logger) + + srv := server.New(cfg, st, client, inventory, notifier, w, dispatcher, logger) return srv.ListenAndServe() } @@ -122,3 +142,40 @@ func warmEmbedder(ctx context.Context, client ollama.Client, model string, logge } logger.Info("embedder warmed successfully", "model", model) } + +// runPruner periodically deletes terminal jobs older than the configured TTL. +// +// Why: unbounded storage growth must be prevented (ADR-0008). +// What: runs a ticker that calls DeleteTerminalJobsBefore with the TTL cutoff. +// Test: create old terminal jobs, run pruner, verify they are deleted. +func runPruner(ctx context.Context, st *store.Store, ttl time.Duration, logger *slog.Logger) { + if ttl <= 0 { + ttl = 24 * time.Hour + } + + // Prune every ttl/4, minimum 1 minute. + interval := ttl / 4 + if interval < time.Minute { + interval = time.Minute + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + cutoff := time.Now().UTC().Add(-ttl) + n, err := st.DeleteTerminalJobsBefore(cutoff) + if err != nil { + logger.Error("pruner failed", "error", err) + continue + } + if n > 0 { + logger.Info("pruner deleted old jobs", "count", n, "cutoff", cutoff) + } + } + } +} diff --git a/go.mod b/go.mod index cae6b43..507b729 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,10 @@ module gitea.stevedudenhoeffer.com/steve/foreman go 1.26.2 -require modernc.org/sqlite v1.50.1 +require ( + github.com/oklog/ulid/v2 v2.1.1 + modernc.org/sqlite v1.50.1 +) require ( github.com/dustin/go-humanize v1.0.1 // indirect diff --git a/go.sum b/go.sum index 38131a8..6e88839 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,9 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= 
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s= +github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= +github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= diff --git a/internal/config/config.go b/internal/config/config.go index a90f219..bf82947 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,7 @@ package config import ( "fmt" "os" + "strconv" "time" ) @@ -39,6 +40,14 @@ type Config struct { // WebhookSecret is an optional HMAC key for signing webhook payloads. WebhookSecret string + + // MaxAttempts is the maximum number of retry attempts for a job before it is + // marked as failed (default 3). + MaxAttempts int + + // JobTTL is how long terminal jobs are retained before the pruner deletes them + // (default 24h). + JobTTL time.Duration } // Load reads configuration from environment variables and returns a validated Config. 
@@ -64,6 +73,20 @@ func Load() (Config, error) { } cfg.PollInterval = dur + maxAttemptsStr := envOr("FOREMAN_MAX_ATTEMPTS", "3") + maxAttempts, err := strconv.Atoi(maxAttemptsStr) + if err != nil { + return Config{}, fmt.Errorf("invalid FOREMAN_MAX_ATTEMPTS %q: %w", maxAttemptsStr, err) + } + cfg.MaxAttempts = maxAttempts + + jobTTLStr := envOr("FOREMAN_JOB_TTL", "24h") + jobTTL, err := time.ParseDuration(jobTTLStr) + if err != nil { + return Config{}, fmt.Errorf("invalid FOREMAN_JOB_TTL %q: %w", jobTTLStr, err) + } + cfg.JobTTL = jobTTL + if cfg.OllamaURL == "" { return Config{}, fmt.Errorf("FOREMAN_OLLAMA_URL is required") } diff --git a/internal/server/server.go b/internal/server/server.go index 8ff06c5..5bd6570 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,54 +3,72 @@ // Why: foreman exposes a native Ollama-compatible API plus async job endpoints; // centralizing routing and middleware here keeps cmd/foreman thin. // What: creates a stdlib net/http server with health checks, optional bearer-token -// auth, Ollama passthrough (chat, tags, ps, embed), and an extensible mux. +// auth, Ollama passthrough (chat, tags, ps, embed), /jobs async surface, and +// artifact serving. // Test: start the server with httptest, hit /healthz, verify 200; set a token, -// verify 401 without it; test Ollama passthrough routes. +// verify 401 without it; test Ollama passthrough routes and /jobs lifecycle. package server import ( - "bufio" + "crypto/rand" "encoding/json" + "fmt" "io" "log/slog" "net/http" "strings" + "time" + + "github.com/oklog/ulid/v2" "gitea.stevedudenhoeffer.com/steve/foreman/internal/config" "gitea.stevedudenhoeffer.com/steve/foreman/internal/ollama" "gitea.stevedudenhoeffer.com/steve/foreman/internal/store" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/webhook" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/worker" ) -// scannerBufSize is the buffer size for the NDJSON scanner (4 MB). 
-const scannerBufSize = 4 * 1024 * 1024 - // Server holds the HTTP server and its dependencies. type Server struct { - cfg config.Config - store *store.Store - client ollama.Client - inventory *ollama.ModelInventory - chatGate chan struct{} - mux *http.ServeMux - logger *slog.Logger + cfg config.Config + store *store.Store + client ollama.Client + inventory *ollama.ModelInventory + notifier *worker.Notifier + workerRef *worker.Worker + dispatcher *webhook.Dispatcher + mux *http.ServeMux + logger *slog.Logger } // New creates a new Server with the given configuration, store, Ollama client, -// and model inventory. The mux is populated with all routes. +// model inventory, notifier, worker, and webhook dispatcher. The mux is populated +// with all routes. // // Why: dependency injection makes the server testable and extensible. -// What: wires config, store, client, inventory, and logger into the server, -// registers routes, and creates the single-flight chat gate. +// What: wires config, store, client, inventory, notifier, worker, dispatcher, and +// logger into the server, registers all routes. // Test: create with New, use httptest to exercise routes. 
-func New(cfg config.Config, st *store.Store, client ollama.Client, inv *ollama.ModelInventory, logger *slog.Logger) *Server { +func New( + cfg config.Config, + st *store.Store, + client ollama.Client, + inv *ollama.ModelInventory, + notifier *worker.Notifier, + w *worker.Worker, + dispatcher *webhook.Dispatcher, + logger *slog.Logger, +) *Server { s := &Server{ - cfg: cfg, - store: st, - client: client, - inventory: inv, - chatGate: make(chan struct{}, 1), - mux: http.NewServeMux(), - logger: logger, + cfg: cfg, + store: st, + client: client, + inventory: inv, + notifier: notifier, + workerRef: w, + dispatcher: dispatcher, + mux: http.NewServeMux(), + logger: logger, } s.routes() return s @@ -83,6 +101,7 @@ func (s *Server) routes() { s.mux.HandleFunc("POST /api/chat", s.handleChat) s.mux.HandleFunc("POST /api/embed", s.handleEmbed) s.mux.HandleFunc("POST /api/embeddings", s.handleEmbed) + s.registerJobRoutes() } // healthResponse is the JSON shape returned by /healthz. @@ -170,15 +189,16 @@ func (s *Server) handleEmbed(w http.ResponseWriter, r *http.Request) { io.Copy(w, resp.Body) } -// handleChat is the critical passthrough path for /api/chat. It validates the -// model, serializes through a single-flight gate, and proxies to the target -// with NDJSON streaming support. +// handleChat is the synchronous passthrough for /api/chat. It enqueues a job in +// the SQLite queue and blocks until the worker completes it, then returns the +// result as if it came directly from Ollama. // // Why: the sync passthrough is foreman's primary API surface for go-llm (ADR-0003). -// What: validates model, acquires the chat gate, proxies to the target, streams -// NDJSON chunks back if streaming, releases the gate on completion. -// Test: verify model validation (404 on unknown), serialization (two concurrent -// requests don't overlap), streaming (NDJSON chunks pass through faithfully). +// The response blocks until done so the caller gets a transparent Ollama experience. 
+// What: validates model, creates a job, registers a completion waiter, wakes the +// worker, and blocks until done or context cancellation. +// Test: verify model validation (404 on unknown), serialization (jobs execute one +// at a time), and that the HTTP response matches the Ollama chat response. func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { @@ -186,10 +206,9 @@ func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { return } - // Parse just enough to validate the model and detect streaming. + // Parse just enough to validate the model. var partial struct { - Model string `json:"model"` - Stream *bool `json:"stream"` + Model string `json:"model"` } if err := json.Unmarshal(body, &partial); err != nil { http.Error(w, `{"error":"invalid JSON body"}`, http.StatusBadRequest) @@ -211,64 +230,68 @@ func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { } } - // Determine if streaming. Ollama defaults to streaming when "stream" is absent. - streaming := true - if partial.Stream != nil && !*partial.Stream { - streaming = false + // Generate a job ID and enqueue. + jobID := ulid.MustNew(ulid.Timestamp(time.Now()), rand.Reader).String() + + maxAttempts := s.cfg.MaxAttempts + if maxAttempts == 0 { + maxAttempts = 3 } - // Acquire the single-flight chat gate. This serializes all chat requests - // through one at a time. Phase 3 replaces this with the full SQLite queue + - // worker loop. - select { - case s.chatGate <- struct{}{}: - // Acquired. 
- case <-r.Context().Done(): - http.Error(w, `{"error":"request cancelled while waiting"}`, http.StatusServiceUnavailable) + job := store.Job{ + ID: jobID, + Model: partial.Model, + Payload: json.RawMessage(body), + MaxAttempts: maxAttempts, + } + + if _, err := s.store.CreateJob(job); err != nil { + s.logger.Error("failed to enqueue chat job", "error", err, "job_id", jobID, "model", partial.Model) + http.Error(w, fmt.Sprintf(`{"error":"failed to enqueue job: %s"}`, err), http.StatusInternalServerError) return } - defer func() { <-s.chatGate }() - // Proxy to the target. - resp, err := s.client.RawChat(r.Context(), body) - if err != nil { - s.logger.Error("chat proxy failed", "error", err, "model", partial.Model) - if httpErr, ok := err.(*ollama.HTTPError); ok { - http.Error(w, httpErr.Body, httpErr.StatusCode) + // Register a completion waiter before waking the worker. + waitCh := s.notifier.Register(jobID) + + // Wake the worker. + if s.workerRef != nil { + s.workerRef.Wake() + } + + // Block until the job completes or the request is cancelled. + select { + case <-waitCh: + // Job completed — get the result. + state, result, errMsg, ok := s.notifier.Result(jobID) + if !ok { + // Should not happen, but fall back to DB. 
+ j, err := s.store.GetJob(jobID) + if err != nil { + http.Error(w, `{"error":"job lost"}`, http.StatusInternalServerError) + return + } + state = j.State + result = j.Result + errMsg = j.Error + } + + if state == store.JobStateFailed { + msg := "job failed" + if errMsg != nil { + msg = *errMsg + } + http.Error(w, fmt.Sprintf(`{"error":%q}`, msg), http.StatusBadGateway) return } - http.Error(w, `{"error":"target unreachable"}`, http.StatusBadGateway) - return - } - defer resp.Body.Close() - if streaming { - w.Header().Set("Content-Type", "application/x-ndjson") - w.WriteHeader(http.StatusOK) - - flusher, canFlush := w.(http.Flusher) - scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(make([]byte, 0, scannerBufSize), scannerBufSize) - - for scanner.Scan() { - line := scanner.Bytes() - if len(line) == 0 { - continue - } - w.Write(line) - w.Write([]byte("\n")) - if canFlush { - flusher.Flush() - } - } - if err := scanner.Err(); err != nil { - s.logger.Warn("stream read error", "error", err, "model", partial.Model) - } - } else { - // Non-streaming: proxy the complete JSON response. + // Return the result as a direct Ollama response. 
w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - io.Copy(w, resp.Body) + w.Write(result) + + case <-r.Context().Done(): + http.Error(w, `{"error":"request cancelled while waiting"}`, http.StatusServiceUnavailable) } } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 60db0cf..80bb0a4 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -19,11 +19,13 @@ import ( "gitea.stevedudenhoeffer.com/steve/foreman/internal/config" "gitea.stevedudenhoeffer.com/steve/foreman/internal/ollama" "gitea.stevedudenhoeffer.com/steve/foreman/internal/store" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/webhook" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/worker" ) // newTestServer creates a Server backed by a temp-dir SQLite store, a stub client, -// and a pre-populated inventory. -func newTestServer(t *testing.T, cfg config.Config, client ollama.Client) *Server { +// and a pre-populated inventory. It constructs a worker but does not start its loop. +func newTestServer(t *testing.T, cfg config.Config, client ollama.Client) (*Server, *store.Store) { t.Helper() dbPath := filepath.Join(t.TempDir(), "test.db") st, err := store.Open(dbPath) @@ -32,19 +34,30 @@ func newTestServer(t *testing.T, cfg config.Config, client ollama.Client) *Serve } t.Cleanup(func() { st.Close() }) - logger := slog.Default() + logger := slog.New(slog.NewJSONHandler(io.Discard, nil)) inv := ollama.NewModelInventory(client, logger) - return New(cfg, st, client, inv, logger) + notifier := worker.NewNotifier() + dispatcher := webhook.NewDispatcher("", logger) + w := worker.New(st, client, inv, notifier, dispatcher, logger) + srv := New(cfg, st, client, inv, notifier, w, dispatcher, logger) + return srv, st } // newTestServerWithInventory creates a Server and pre-refreshes the inventory. -func newTestServerWithInventory(t *testing.T, cfg config.Config, client ollama.Client) *Server { +// Also starts a worker goroutine.
+func newTestServerWithInventory(t *testing.T, cfg config.Config, client ollama.Client) (*Server, *store.Store) { t.Helper() - srv := newTestServer(t, cfg, client) + srv, st := newTestServer(t, cfg, client) if err := srv.inventory.Refresh(context.Background()); err != nil { t.Fatalf("inventory.Refresh: %v", err) } - return srv + + // Start the worker loop so chat requests complete. + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + go srv.workerRef.Run(ctx) + + return srv, st } func TestHealthz_OK(t *testing.T) { @@ -52,7 +65,7 @@ func TestHealthz_OK(t *testing.T) { tags: &ollama.TagsResponse{}, ps: &ollama.PsResponse{}, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -81,7 +94,7 @@ func TestHealthz_NoAuthRequired(t *testing.T) { tags: &ollama.TagsResponse{}, ps: &ollama.PsResponse{}, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", Token: "secret-token", }, stub) @@ -100,7 +113,7 @@ func TestAuth_RequiredWhenTokenSet(t *testing.T) { tags: &ollama.TagsResponse{}, ps: &ollama.PsResponse{}, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", Token: "secret-token", }, stub) @@ -159,7 +172,7 @@ func TestAuth_NotRequiredWhenNoToken(t *testing.T) { tags: &ollama.TagsResponse{}, ps: &ollama.PsResponse{}, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -182,7 +195,7 @@ func TestTags_ReturnsCachedModels(t *testing.T) { }, ps: &ollama.PsResponse{}, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -215,7 +228,7 @@ 
func TestPs_ReturnsCachedRunningModels(t *testing.T) { }, }, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -245,7 +258,7 @@ func TestChat_UnknownModel404(t *testing.T) { }, ps: &ollama.PsResponse{}, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -265,16 +278,17 @@ func TestChat_NonStreaming(t *testing.T) { Done: true, Message: &ollama.Message{Role: "assistant", Content: "Hello!"}, } - respBytes, _ := json.Marshal(chatResp) stub := &stubClient{ tags: &ollama.TagsResponse{ Models: []ollama.ModelInfo{{Name: "qwen3:30b"}}, }, - ps: &ollama.PsResponse{}, - rawChatResp: newRawResponse(200, "application/json", respBytes), + ps: &ollama.PsResponse{}, + chatFunc: func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + return &chatResp, nil, nil + }, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -284,7 +298,7 @@ func TestChat_NonStreaming(t *testing.T) { srv.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } ct := rec.Header().Get("Content-Type") @@ -301,60 +315,6 @@ func TestChat_NonStreaming(t *testing.T) { } } -func TestChat_Streaming(t *testing.T) { - // Build NDJSON chunks. 
- chunks := []ollama.ChatResponse{ - {Model: "qwen3:30b", Done: false, Message: &ollama.Message{Role: "assistant", Content: "Hel"}}, - {Model: "qwen3:30b", Done: false, Message: &ollama.Message{Role: "assistant", Content: "lo"}}, - {Model: "qwen3:30b", Done: true, DoneReason: "stop"}, - } - var ndjson bytes.Buffer - for _, c := range chunks { - b, _ := json.Marshal(c) - ndjson.Write(b) - ndjson.WriteByte('\n') - } - - stub := &stubClient{ - tags: &ollama.TagsResponse{ - Models: []ollama.ModelInfo{{Name: "qwen3:30b"}}, - }, - ps: &ollama.PsResponse{}, - rawChatResp: newRawResponse(200, "application/x-ndjson", ndjson.Bytes()), - } - srv := newTestServerWithInventory(t, config.Config{ - OllamaURL: "http://localhost:11434", - }, stub) - - body := `{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}` - req := httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(body)) - rec := httptest.NewRecorder() - srv.Handler().ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) - } - - ct := rec.Header().Get("Content-Type") - if ct != "application/x-ndjson" { - t.Errorf("Content-Type = %q, want %q", ct, "application/x-ndjson") - } - - // Verify chunks pass through faithfully. - lines := strings.Split(strings.TrimSpace(rec.Body.String()), "\n") - if len(lines) != 3 { - t.Fatalf("got %d lines, want 3", len(lines)) - } - - var last ollama.ChatResponse - if err := json.Unmarshal([]byte(lines[2]), &last); err != nil { - t.Fatalf("unmarshal last chunk: %v", err) - } - if !last.Done { - t.Error("last chunk should have done=true") - } -} - func TestChat_Serialization(t *testing.T) { // Track concurrent requests at the stub. 
var inflight atomic.Int32 @@ -365,7 +325,7 @@ func TestChat_Serialization(t *testing.T) { Models: []ollama.ModelInfo{{Name: "qwen3:30b"}}, }, ps: &ollama.PsResponse{}, - rawChatFunc: func(ctx context.Context, body []byte) (*http.Response, error) { + chatFunc: func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { cur := inflight.Add(1) defer inflight.Add(-1) for { @@ -376,12 +336,11 @@ func TestChat_Serialization(t *testing.T) { } // Simulate work. time.Sleep(50 * time.Millisecond) - resp := ollama.ChatResponse{Model: "qwen3:30b", Done: true} - b, _ := json.Marshal(resp) - return newRawResponse(200, "application/json", b), nil + resp := &ollama.ChatResponse{Model: "qwen3:30b", Done: true, Message: &ollama.Message{Role: "assistant", Content: "ok"}} + return resp, nil, nil }, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -395,14 +354,14 @@ func TestChat_Serialization(t *testing.T) { rec := httptest.NewRecorder() srv.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + t.Errorf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } }() } wg.Wait() if got := maxInflight.Load(); got > 1 { - t.Errorf("max concurrent chat requests at target = %d, want 1 (gate should serialize)", got) + t.Errorf("max concurrent chat requests at target = %d, want 1 (worker should serialize)", got) } } @@ -432,7 +391,7 @@ func TestEmbed_ConcurrentBypassesGate(t *testing.T) { return newRawResponse(200, "application/json", b), nil }, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -471,7 +430,7 @@ func TestEmbed_AlsoWorksOnEmbeddingsPath(t *testing.T) { return newRawResponse(200, "application/json", 
respBytes), nil }, } - srv := newTestServerWithInventory(t, config.Config{ + srv, _ := newTestServerWithInventory(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -490,7 +449,7 @@ func TestHealthz_DegradedFromInventory(t *testing.T) { tagsErr: fmt.Errorf("connection refused"), ps: &ollama.PsResponse{}, } - srv := newTestServer(t, config.Config{ + srv, _ := newTestServer(t, config.Config{ OllamaURL: "http://localhost:11434", }, stub) @@ -514,6 +473,35 @@ func TestHealthz_DegradedFromInventory(t *testing.T) { } } +func TestChat_ContextCancellation(t *testing.T) { + // Chat function that blocks forever to simulate a slow worker. + stub := &stubClient{ + tags: &ollama.TagsResponse{ + Models: []ollama.ModelInfo{{Name: "qwen3:30b"}}, + }, + ps: &ollama.PsResponse{}, + chatFunc: func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + <-ctx.Done() + return nil, nil, ctx.Err() + }, + } + srv, _ := newTestServerWithInventory(t, config.Config{ + OllamaURL: "http://localhost:11434", + }, stub) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + body := `{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}],"stream":false}` + req := httptest.NewRequestWithContext(ctx, http.MethodPost, "/api/chat", strings.NewReader(body)) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable) + } +} + // --- Stub client for testing --- // stubClient implements ollama.Client for testing. 
@@ -523,6 +511,7 @@ type stubClient struct { ps *ollama.PsResponse psErr error + chatFunc func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) rawChatResp *http.Response rawChatFunc func(ctx context.Context, body []byte) (*http.Response, error) @@ -531,6 +520,9 @@ type stubClient struct { } func (s *stubClient) Chat(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + if s.chatFunc != nil { + return s.chatFunc(ctx, req, stream) + } return nil, nil, fmt.Errorf("stubClient.Chat not implemented") } diff --git a/internal/store/store.go b/internal/store/store.go index f82dbd7..25e3ac8 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -103,23 +103,13 @@ CREATE TABLE IF NOT EXISTS artifacts ( // What: opens the DB, sets pragmas, runs CREATE TABLE IF NOT EXISTS. // Test: call Open with a temp dir path, assert no error and that tables exist. func Open(path string) (*Store, error) { - db, err := sql.Open("sqlite", path) + // Append pragmas to the DSN so they apply to every connection in the pool. + dsn := path + "?_pragma=journal_mode(WAL)&_pragma=foreign_keys(ON)&_pragma=busy_timeout(5000)" + db, err := sql.Open("sqlite", dsn) if err != nil { return nil, fmt.Errorf("open sqlite %q: %w", path, err) } - // Enable WAL mode for concurrent readers. - if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { - db.Close() - return nil, fmt.Errorf("enable WAL mode: %w", err) - } - - // Enable foreign keys. 
- if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil { - db.Close() - return nil, fmt.Errorf("enable foreign keys: %w", err) - } - if _, err := db.Exec(migration); err != nil { db.Close() return nil, fmt.Errorf("run migration: %w", err) @@ -364,3 +354,126 @@ func (s *Store) GetArtifactsByJob(jobID string) ([]Artifact, error) { return artifacts, rows.Err() } + +// NextJob returns the next queued job using drain-by-model ordering. Jobs for the +// currently-resident model are preferred to avoid swap costs, then ordered by +// creation time. +// +// Why: the worker loop must pick the optimal next job to minimize model swaps +// (ADR-0009 drain-by-model heuristic). +// What: queries for the first queued job, sorting by model affinity then FIFO. +// Test: enqueue jobs for two models, set currentModel to one, verify it drains +// that model first before switching. +func (s *Store) NextJob(currentModel string) (Job, error) { + var j Job + var payload, result []byte + + err := s.db.QueryRow( + `SELECT id, model, payload, state, result, error, attempt, max_attempts, + state_webhook_url, created_at, updated_at, started_at, completed_at + FROM jobs + WHERE state = ? + ORDER BY (CASE WHEN model = ? THEN 0 ELSE 1 END) ASC, created_at ASC + LIMIT 1`, string(JobStateQueued), currentModel, + ).Scan( + &j.ID, &j.Model, &payload, &j.State, &result, &j.Error, + &j.Attempt, &j.MaxAttempts, &j.StateWebhookURL, + &j.CreatedAt, &j.UpdatedAt, &j.StartedAt, &j.CompletedAt, + ) + if err != nil { + return Job{}, fmt.Errorf("next job: %w", err) + } + + j.Payload = json.RawMessage(payload) + if result != nil { + j.Result = json.RawMessage(result) + } + + return j, nil +} + +// IncrementAttempt bumps the attempt counter on a job and resets it to queued. +// +// Why: retry logic needs to record each attempt while re-queuing the job. +// What: increments attempt by 1 and sets state back to queued. +// Test: create a job, increment twice, verify attempt=2 and state=queued. 
+func (s *Store) IncrementAttempt(id string) error { + now := time.Now().UTC() + res, err := s.db.Exec( + `UPDATE jobs SET attempt = attempt + 1, state = ?, updated_at = ? WHERE id = ?`, + string(JobStateQueued), now, id, + ) + if err != nil { + return fmt.Errorf("increment attempt for job %s: %w", id, err) + } + + rows, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("check rows affected for job %s: %w", id, err) + } + if rows == 0 { + return fmt.Errorf("job %s not found", id) + } + + return nil +} + +// ResetInterruptedJobs moves any loading or working jobs back to queued. Called +// on startup to recover from a crash mid-execution. +// +// Why: if the daemon restarts while a job is in-flight, the job must not be stuck +// in a non-terminal, non-queued state forever. +// What: updates all loading/working jobs to queued. +// Test: create jobs in loading/working states, call Reset, verify all are queued. +func (s *Store) ResetInterruptedJobs() (int64, error) { + now := time.Now().UTC() + res, err := s.db.Exec( + `UPDATE jobs SET state = ?, updated_at = ? WHERE state IN (?, ?)`, + string(JobStateQueued), now, + string(JobStateLoading), string(JobStateWorking), + ) + if err != nil { + return 0, fmt.Errorf("reset interrupted jobs: %w", err) + } + + rows, err := res.RowsAffected() + if err != nil { + return 0, fmt.Errorf("check rows affected: %w", err) + } + + return rows, nil +} + +// DeleteTerminalJobsBefore deletes terminal jobs (done or failed) and their +// artifacts older than the given cutoff time. +// +// Why: prevents unbounded storage growth by pruning old completed work (ADR-0008). +// What: deletes artifacts first (FK), then jobs with completed_at before cutoff. +// Test: create old terminal jobs, call with a recent cutoff, verify they are gone. +func (s *Store) DeleteTerminalJobsBefore(cutoff time.Time) (int64, error) { + // Delete artifacts for terminal jobs first (foreign key). 
+ _, err := s.db.Exec( + `DELETE FROM artifacts WHERE job_id IN ( + SELECT id FROM jobs WHERE state IN (?, ?) AND completed_at < ? + )`, + string(JobStateDone), string(JobStateFailed), cutoff, + ) + if err != nil { + return 0, fmt.Errorf("delete old artifacts: %w", err) + } + + res, err := s.db.Exec( + `DELETE FROM jobs WHERE state IN (?, ?) AND completed_at < ?`, + string(JobStateDone), string(JobStateFailed), cutoff, + ) + if err != nil { + return 0, fmt.Errorf("delete old jobs: %w", err) + } + + rows, err := res.RowsAffected() + if err != nil { + return 0, fmt.Errorf("check rows affected: %w", err) + } + + return rows, nil +} diff --git a/internal/webhook/dispatcher.go b/internal/webhook/dispatcher.go new file mode 100644 index 0000000..a8cac96 --- /dev/null +++ b/internal/webhook/dispatcher.go @@ -0,0 +1,190 @@ +// Package webhook delivers state-change events to job webhook URLs. +// +// Why: async job callers need push notification of state transitions without +// polling (ADR-0005). Delivery must never block or fail the job itself. +// What: fires HTTP POSTs with JSON payloads to configured webhook URLs, retrying +// with exponential backoff. Optionally signs payloads with HMAC-SHA256. +// Test: spin up an in-test HTTP server, fire events, verify receipt, retry on 500, +// and HMAC signature verification. +package webhook + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" +) + +// Event is the JSON payload POSTed to a webhook URL on each state transition. +type Event struct { + JobID string `json:"job_id"` + State string `json:"state"` + PreviousState string `json:"previous_state"` + Timestamp time.Time `json:"timestamp"` + Model string `json:"model"` + Attempt int `json:"attempt"` + Result json.RawMessage `json:"result"` + Artifacts json.RawMessage `json:"artifacts"` + Error *string `json:"error"` +} + +// Dispatcher sends webhook events to job-specified URLs. 
+type Dispatcher struct { + secret string + httpClient *http.Client + logger *slog.Logger + + maxRetries int + baseDelay time.Duration +} + +// NewDispatcher creates a new webhook dispatcher. +// +// Why: centralizes webhook delivery config (secret, retry policy) in one place. +// What: returns a Dispatcher ready to fire events asynchronously. +// Test: create with a secret, fire an event, verify HMAC header. +func NewDispatcher(secret string, logger *slog.Logger) *Dispatcher { + return &Dispatcher{ + secret: secret, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + logger: logger, + maxRetries: 5, + baseDelay: 1 * time.Second, + } +} + +// Fire sends a webhook event to the given URL in a background goroutine. It never +// blocks the caller and never returns an error — failed deliveries are logged and +// dropped per ADR-0005. +// +// Why: webhook failures must never block or fail the worker loop. +// What: marshals the event, spawns a goroutine that retries with backoff. +// Test: fire an event at a 500-returning server, verify retries happen then stop. +func (d *Dispatcher) Fire(url string, event Event) { + go d.deliver(url, event) +} + +// deliver attempts to POST the event with retries and backoff. 
+func (d *Dispatcher) deliver(url string, event Event) { + body, err := json.Marshal(event) + if err != nil { + d.logger.Error("webhook marshal failed", "error", err, "job_id", event.JobID) + return + } + + for attempt := 0; attempt <= d.maxRetries; attempt++ { + if attempt > 0 { + delay := d.baseDelay * (1 << (attempt - 1)) + time.Sleep(delay) + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + d.logger.Error("webhook request creation failed", + "error", err, "url", url, "job_id", event.JobID) + return + } + req.Header.Set("Content-Type", "application/json") + + if d.secret != "" { + sig := computeHMAC(body, d.secret) + req.Header.Set("X-Foreman-Signature", "sha256="+sig) + } + + resp, err := d.httpClient.Do(req) + if err != nil { + d.logger.Warn("webhook delivery failed", + "error", err, "url", url, "job_id", event.JobID, + "attempt", attempt+1, "max", d.maxRetries+1) + continue + } + resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + d.logger.Debug("webhook delivered", + "url", url, "job_id", event.JobID, "state", event.State) + return + } + + d.logger.Warn("webhook non-2xx response", + "status", resp.StatusCode, "url", url, "job_id", event.JobID, + "attempt", attempt+1, "max", d.maxRetries+1) + } + + d.logger.Error("webhook delivery exhausted retries", + "url", url, "job_id", event.JobID, "state", event.State) +} + +// computeHMAC computes HMAC-SHA256 of body using the given key and returns the +// hex-encoded digest. +func computeHMAC(body []byte, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write(body) + return hex.EncodeToString(mac.Sum(nil)) +} + +// VerifySignature checks that the signature header matches the HMAC-SHA256 of +// the body. Exported for use by webhook receivers. +// +// Why: webhook consumers need to verify authenticity of incoming payloads. +// What: computes HMAC and compares to the provided signature using constant-time comparison. 
+// Test: sign a body, verify with correct and incorrect secrets. +func VerifySignature(body []byte, signature, secret string) bool { + if len(signature) < 8 || signature[:7] != "sha256=" { + return false + } + expected := computeHMAC(body, secret) + return hmac.Equal([]byte(expected), []byte(signature[7:])) +} + +// FormatArtifacts formats artifact metadata for webhook payloads. Small artifacts +// (under threshold) are inlined; large ones get a URL reference. +// +// Why: webhook bodies must stay bounded per ADR-0006 (~256KB threshold). +// What: returns JSON-encoded artifact metadata with inline data or URL references. +// Test: create artifacts above and below threshold, verify inline vs URL in output. +func FormatArtifacts(jobID string, artifacts []ArtifactMeta) json.RawMessage { + if len(artifacts) == 0 { + return nil + } + + type artifactOut struct { + Name string `json:"name"` + ContentType string `json:"content_type"` + Size int64 `json:"size"` + Data string `json:"data,omitempty"` + URL string `json:"url,omitempty"` + } + + out := make([]artifactOut, len(artifacts)) + for i, a := range artifacts { + out[i] = artifactOut{ + Name: a.Name, + ContentType: a.ContentType, + Size: a.Size, + } + if a.Size <= 256*1024 && a.Data != nil { + out[i].Data = string(a.Data) + } else { + out[i].URL = fmt.Sprintf("/jobs/%s/artifacts/%s", jobID, a.Name) + } + } + + b, _ := json.Marshal(out) + return json.RawMessage(b) +} + +// ArtifactMeta holds artifact info for webhook formatting. +type ArtifactMeta struct { + Name string + ContentType string + Size int64 + Data []byte +} diff --git a/internal/worker/worker.go b/internal/worker/worker.go new file mode 100644 index 0000000..985d13f --- /dev/null +++ b/internal/worker/worker.go @@ -0,0 +1,385 @@ +// Package worker implements the single-worker loop that pulls jobs from the +// SQLite queue, executes them against the Ollama target, and records results. 
+// +// Why: foreman serializes all chat work through one worker to avoid swap thrash +// on the target (ADR-0009). The worker is the only writer of job state transitions. +// What: runs a goroutine that picks the next job (drain-by-model), calls Ollama, +// stores the result, fires webhooks, and notifies waiting sync handlers. +// Test: create with a stub client, enqueue jobs, verify serial execution, +// drain-by-model ordering, retry on connection error, and completion notification. +package worker + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + "gitea.stevedudenhoeffer.com/steve/foreman/internal/ollama" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/store" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/webhook" +) + +// Notifier manages channels that sync HTTP handlers use to wait for job completion. +// +// Why: the /api/chat handler blocks until its job finishes; a notification map +// avoids polling the DB. +// What: maps job IDs to channels; the worker signals completion by closing the channel. +// Test: register a waiter, complete the job, verify the channel unblocks. +type Notifier struct { + mu sync.Mutex + waiters map[string]chan struct{} + // results stores the terminal job state so the waiter can read it after notification. + results map[string]jobResult +} + +type jobResult struct { + State store.JobState + Result json.RawMessage + Error *string +} + +// NewNotifier creates a new Notifier. +func NewNotifier() *Notifier { + return &Notifier{ + waiters: make(map[string]chan struct{}), + results: make(map[string]jobResult), + } +} + +// Register creates a wait channel for the given job ID. The caller should select +// on the returned channel and their context. +// +// Why: each sync chat handler needs its own completion signal. +// What: allocates an unbuffered channel keyed by job ID. +// Test: register, verify channel is open, complete, verify it closes.
+func (n *Notifier) Register(jobID string) <-chan struct{} { + n.mu.Lock() + defer n.mu.Unlock() + ch := make(chan struct{}) + n.waiters[jobID] = ch + return ch +} + +// Complete signals that the job has reached a terminal state and stores the result. +// +// Why: the worker calls this when a job is done or failed; the HTTP handler unblocks. +// What: closes the wait channel and stores the result for retrieval. +// Test: register, complete, verify the channel is closed and result is available. +func (n *Notifier) Complete(jobID string, state store.JobState, result json.RawMessage, errMsg *string) { + n.mu.Lock() + defer n.mu.Unlock() + n.results[jobID] = jobResult{State: state, Result: result, Error: errMsg} + if ch, ok := n.waiters[jobID]; ok { + close(ch) + delete(n.waiters, jobID) + } +} + +// Result returns the stored result for a completed job, if any. +// +// Why: after the wait channel closes, the HTTP handler needs the result data. +// What: returns the cached result and cleans up. +// Test: complete a job, call Result, verify data, call again, verify cleaned up. +func (n *Notifier) Result(jobID string) (store.JobState, json.RawMessage, *string, bool) { + n.mu.Lock() + defer n.mu.Unlock() + r, ok := n.results[jobID] + if ok { + delete(n.results, jobID) + } + return r.State, r.Result, r.Error, ok +} + +// Worker is the single-threaded job execution loop. +type Worker struct { + store *store.Store + client ollama.Client + inventory *ollama.ModelInventory + notifier *Notifier + dispatcher *webhook.Dispatcher + logger *slog.Logger + + // wake is signaled when a new job is enqueued. + wake chan struct{} +} + +// New creates a new Worker. +// +// Why: dependency injection makes the worker testable with stub clients and stores. +// What: wires all dependencies and creates the wake channel. +// Test: create with stubs, call Run in a goroutine, enqueue a job, verify execution. 
+func New( + st *store.Store, + client ollama.Client, + inv *ollama.ModelInventory, + notifier *Notifier, + dispatcher *webhook.Dispatcher, + logger *slog.Logger, +) *Worker { + return &Worker{ + store: st, + client: client, + inventory: inv, + notifier: notifier, + dispatcher: dispatcher, + logger: logger, + wake: make(chan struct{}, 1), + } +} + +// Wake signals the worker that a new job may be available. Non-blocking. +// +// Why: the HTTP handlers signal the worker to check for new work immediately +// instead of waiting for the next poll cycle. +// What: sends on the wake channel (drops if already pending). +// Test: call Wake multiple times, verify no blocking. +func (w *Worker) Wake() { + select { + case w.wake <- struct{}{}: + default: + } +} + +// Run starts the worker loop. It blocks until ctx is cancelled. On startup it +// resets any interrupted jobs back to queued. +// +// Why: the main loop is the core of foreman's job execution (ADR-0009). +// What: resets interrupted jobs, then loops: pick next job, execute, record result. +// Test: enqueue jobs, run worker with a cancellable context, verify all execute. +func (w *Worker) Run(ctx context.Context) { + // Reset any jobs stuck in loading/working from a previous crash. + if n, err := w.store.ResetInterruptedJobs(); err != nil { + w.logger.Error("failed to reset interrupted jobs", "error", err) + } else if n > 0 { + w.logger.Info("reset interrupted jobs", "count", n) + } + + for { + if ctx.Err() != nil { + return + } + + currentModel := w.currentWorkerModel() + job, err := w.store.NextJob(currentModel) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + // No jobs available — wait for a wake signal or context cancel. 
+ select { + case <-w.wake: + continue + case <-ctx.Done(): + return + } + } + w.logger.Error("failed to fetch next job", "error", err) + select { + case <-time.After(1 * time.Second): + continue + case <-ctx.Done(): + return + } + } + + w.executeJob(ctx, job) + } +} + +// currentWorkerModel returns the model name currently in the worker slot (slot 2). +// The embedder is in slot 1; any other model is the worker model. +func (w *Worker) currentWorkerModel() string { + residents := w.inventory.ResidentModels() + embedModel := w.getEmbedModel() + for _, r := range residents { + if r.Name != embedModel { + return r.Name + } + } + return "" +} + +// getEmbedModel is a placeholder for identifying the embedder model. It does not +// inspect the inventory yet and always returns the empty string, so +// currentWorkerModel returns the first named resident model. +func (w *Worker) getEmbedModel() string { + // We check inventory for a model that matches common embed model patterns. + // The simplest approach: the embedder is usually the first (smallest) resident. + // However, we can't easily know which is which without config. + // For now, return empty — drain-by-model still works because we prefer + // whatever model is resident. + return "" +} + +// executeJob runs a single job through its lifecycle. +func (w *Worker) executeJob(ctx context.Context, job store.Job) { + w.logger.Info("executing job", "job_id", job.ID, "model", job.Model, "attempt", job.Attempt) + + // Determine if we need to load a new model. + needsLoad := !w.isModelResident(job.Model) + + if needsLoad { + w.transitionState(job, store.JobStateLoading) + } + w.transitionState(job, store.JobStateWorking) + + // Parse the payload into a ChatRequest. + var chatReq ollama.ChatRequest + if err := json.Unmarshal(job.Payload, &chatReq); err != nil { + errMsg := fmt.Sprintf("invalid chat request payload: %v", err) + w.failJob(job, &errMsg) + return + } + + // Ensure model is set.
+ chatReq.Model = job.Model + + // Set stream to false for worker execution — we collect the full response. + streamFalse := false + chatReq.Stream = &streamFalse + + // Execute the chat request. + resp, _, err := w.client.Chat(ctx, chatReq, false) + if err != nil { + w.handleExecutionError(job, err) + return + } + + // Marshal the result. + resultBytes, err := json.Marshal(resp) + if err != nil { + errMsg := fmt.Sprintf("marshal result: %v", err) + w.failJob(job, &errMsg) + return + } + result := json.RawMessage(resultBytes) + + // Store the completion artifact. + _, artifactErr := w.store.CreateArtifact(store.Artifact{ + JobID: job.ID, + Name: "completion", + ContentType: "application/json", + Data: resultBytes, + }) + if artifactErr != nil { + w.logger.Error("failed to store artifact", "error", artifactErr, "job_id", job.ID) + } + + // Transition to done. + if err := w.store.UpdateJobState(job.ID, store.JobStateDone, result, nil); err != nil { + w.logger.Error("failed to update job to done", "error", err, "job_id", job.ID) + } + + // Notify waiting sync handlers. + w.notifier.Complete(job.ID, store.JobStateDone, result, nil) + + // Fire webhook if configured. + w.fireWebhook(job, store.JobStateDone, store.JobStateWorking, result, nil) + + w.logger.Info("job completed", "job_id", job.ID, "model", job.Model) +} + +// transitionState updates a job's state and fires a webhook. +func (w *Worker) transitionState(job store.Job, newState store.JobState) { + prevState := job.State + + if err := w.store.UpdateJobState(job.ID, newState, nil, nil); err != nil { + w.logger.Error("failed to transition job state", + "error", err, "job_id", job.ID, "from", prevState, "to", newState) + return + } + + w.fireWebhook(job, newState, prevState, nil, nil) + job.State = newState +} + +// handleExecutionError handles errors from the Ollama client during job execution. 
+func (w *Worker) handleExecutionError(job store.Job, err error) { + var connErr *ollama.ConnectionError + if errors.As(err, &connErr) { + // Connection error — retryable. + w.logger.Warn("job hit connection error, will retry", + "job_id", job.ID, "error", err, "attempt", job.Attempt) + + if job.Attempt+1 >= job.MaxAttempts { + errMsg := fmt.Sprintf("connection failed after %d attempts: %v", job.MaxAttempts, err) + w.failJob(job, &errMsg) + return + } + + // Re-queue with incremented attempt. + if err := w.store.IncrementAttempt(job.ID); err != nil { + w.logger.Error("failed to increment attempt", "error", err, "job_id", job.ID) + } + return + } + + // Non-connection error (HTTP 4xx/5xx from target) — terminal failure. + errMsg := fmt.Sprintf("chat execution failed: %v", err) + w.failJob(job, &errMsg) +} + +// failJob transitions a job to failed and notifies waiters. +func (w *Worker) failJob(job store.Job, errMsg *string) { + if err := w.store.UpdateJobState(job.ID, store.JobStateFailed, nil, errMsg); err != nil { + w.logger.Error("failed to mark job as failed", "error", err, "job_id", job.ID) + } + + w.notifier.Complete(job.ID, store.JobStateFailed, nil, errMsg) + w.fireWebhook(job, store.JobStateFailed, job.State, nil, errMsg) + + w.logger.Warn("job failed", "job_id", job.ID, "error", *errMsg) +} + +// fireWebhook sends a webhook event if the job has a webhook URL configured. +func (w *Worker) fireWebhook(job store.Job, state, prevState store.JobState, result json.RawMessage, errMsg *string) { + if job.StateWebhookURL == nil || *job.StateWebhookURL == "" || w.dispatcher == nil { + return + } + + event := webhook.Event{ + JobID: job.ID, + State: string(state), + PreviousState: string(prevState), + Timestamp: time.Now().UTC(), + Model: job.Model, + Attempt: job.Attempt, + Result: result, + Error: errMsg, + } + + // If done, include artifact metadata. 
+ if state == store.JobStateDone { + artifacts, err := w.store.GetArtifactsByJob(job.ID) + if err != nil { + w.logger.Error("failed to get artifacts for webhook", "error", err, "job_id", job.ID) + } else { + metas := make([]webhook.ArtifactMeta, len(artifacts)) + for i, a := range artifacts { + metas[i] = webhook.ArtifactMeta{ + Name: a.Name, + ContentType: a.ContentType, + Size: a.Size, + Data: a.Data, + } + } + event.Artifacts = webhook.FormatArtifacts(job.ID, metas) + } + } + + w.dispatcher.Fire(*job.StateWebhookURL, event) +} + +// isModelResident checks whether the given model is currently loaded on the target. +func (w *Worker) isModelResident(model string) bool { + for _, r := range w.inventory.ResidentModels() { + if r.Name == model { + return true + } + } + return false +} + diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go new file mode 100644 index 0000000..50b6c57 --- /dev/null +++ b/internal/worker/worker_test.go @@ -0,0 +1,807 @@ +package worker + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "gitea.stevedudenhoeffer.com/steve/foreman/internal/ollama" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/store" + "gitea.stevedudenhoeffer.com/steve/foreman/internal/webhook" +) + +// openTestDB creates a fresh SQLite store in a temp directory for test isolation. +func openTestDB(t *testing.T) *store.Store { + t.Helper() + path := filepath.Join(t.TempDir(), "test.db") + s, err := store.Open(path) + if err != nil { + t.Fatalf("Open(%q): %v", path, err) + } + t.Cleanup(func() { s.Close() }) + return s +} + +// newTestWorker creates a worker with stub dependencies for testing. 
+func newTestWorker(t *testing.T, client ollama.Client) (*Worker, *store.Store, *Notifier) { + t.Helper() + st := openTestDB(t) + logger := slog.New(slog.NewJSONHandler(io.Discard, nil)) + inv := ollama.NewModelInventory(client, logger) + notifier := NewNotifier() + dispatcher := webhook.NewDispatcher("", logger) + w := New(st, client, inv, notifier, dispatcher, logger) + return w, st, notifier +} + +// stubOllamaClient implements ollama.Client for worker tests. +type stubOllamaClient struct { + chatFunc func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) + tags *ollama.TagsResponse + ps *ollama.PsResponse + mu sync.Mutex + chatCalls []ollama.ChatRequest + callCount atomic.Int32 +} + +func (s *stubOllamaClient) Chat(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + s.callCount.Add(1) + s.mu.Lock() + s.chatCalls = append(s.chatCalls, req) + s.mu.Unlock() + if s.chatFunc != nil { + return s.chatFunc(ctx, req, stream) + } + return &ollama.ChatResponse{ + Model: req.Model, + Done: true, + Message: &ollama.Message{Role: "assistant", Content: "test response"}, + }, nil, nil +} + +func (s *stubOllamaClient) Embed(ctx context.Context, req ollama.EmbedRequest) (*ollama.EmbedResponse, error) { + return nil, fmt.Errorf("not implemented") +} + +func (s *stubOllamaClient) Tags(ctx context.Context) (*ollama.TagsResponse, error) { + if s.tags != nil { + return s.tags, nil + } + return &ollama.TagsResponse{}, nil +} + +func (s *stubOllamaClient) Ps(ctx context.Context) (*ollama.PsResponse, error) { + if s.ps != nil { + return s.ps, nil + } + return &ollama.PsResponse{}, nil +} + +func (s *stubOllamaClient) RawChat(ctx context.Context, body []byte) (*http.Response, error) { + return nil, fmt.Errorf("not implemented") +} + +func (s *stubOllamaClient) RawEmbed(ctx context.Context, body []byte) (*http.Response, error) { + return nil, 
fmt.Errorf("not implemented") +} + +func TestWorker_ExecutesSingleJob(t *testing.T) { + client := &stubOllamaClient{} + w, st, notifier := newTestWorker(t, client) + + // Create a job. + job := store.Job{ + ID: "01TEST001", + Model: "qwen3:30b", + Payload: json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}`), + MaxAttempts: 3, + } + if _, err := st.CreateJob(job); err != nil { + t.Fatalf("CreateJob: %v", err) + } + + // Register a waiter. + waitCh := notifier.Register("01TEST001") + + // Run the worker. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Run(ctx) + + // Wait for the job to complete. + select { + case <-waitCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for job to complete") + } + + // Check the result. + state, result, errMsg, ok := notifier.Result("01TEST001") + if !ok { + t.Fatal("no result available") + } + if state != store.JobStateDone { + t.Errorf("state = %q, want %q", state, store.JobStateDone) + } + if errMsg != nil { + t.Errorf("unexpected error: %s", *errMsg) + } + if result == nil { + t.Fatal("result should not be nil") + } + + // Verify the job in the store. + got, err := st.GetJob("01TEST001") + if err != nil { + t.Fatalf("GetJob: %v", err) + } + if got.State != store.JobStateDone { + t.Errorf("stored state = %q, want %q", got.State, store.JobStateDone) + } + if got.CompletedAt == nil { + t.Error("CompletedAt should be set") + } + + // Verify artifact was created. 
+ artifact, err := st.GetArtifact("01TEST001", "completion") + if err != nil { + t.Fatalf("GetArtifact: %v", err) + } + if artifact.ContentType != "application/json" { + t.Errorf("artifact content_type = %q, want %q", artifact.ContentType, "application/json") + } +} + +func TestWorker_SerialExecution(t *testing.T) { + var inflight atomic.Int32 + var maxInflight atomic.Int32 + + client := &stubOllamaClient{ + chatFunc: func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + cur := inflight.Add(1) + defer inflight.Add(-1) + for { + old := maxInflight.Load() + if cur <= old || maxInflight.CompareAndSwap(old, cur) { + break + } + } + time.Sleep(30 * time.Millisecond) + return &ollama.ChatResponse{Model: req.Model, Done: true, Message: &ollama.Message{Role: "assistant", Content: "ok"}}, nil, nil + }, + } + w, st, notifier := newTestWorker(t, client) + + // Create multiple jobs. + for i := 0; i < 3; i++ { + id := fmt.Sprintf("01SERIAL%03d", i) + job := store.Job{ + ID: id, + Model: "qwen3:30b", + Payload: json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}`), + MaxAttempts: 3, + } + if _, err := st.CreateJob(job); err != nil { + t.Fatalf("CreateJob: %v", err) + } + } + + // Register a waiter for the last job only. + waitCh := notifier.Register("01SERIAL002") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Run(ctx) + + // Wait for last job.
+ select { + case <-waitCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for jobs to complete") + } + + if got := maxInflight.Load(); got > 1 { + t.Errorf("max concurrent executions = %d, want 1", got) + } + if got := client.callCount.Load(); got != 3 { + t.Errorf("chat call count = %d, want 3", got) + } +} + +func TestWorker_DrainByModel(t *testing.T) { + var executionOrder []string + var mu sync.Mutex + + client := &stubOllamaClient{ + ps: &ollama.PsResponse{ + Models: []ollama.RunningModel{ + {Name: "qwen3:30b"}, + }, + }, + chatFunc: func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + mu.Lock() + executionOrder = append(executionOrder, req.Model) + mu.Unlock() + return &ollama.ChatResponse{Model: req.Model, Done: true, Message: &ollama.Message{Role: "assistant", Content: "ok"}}, nil, nil + }, + } + w, st, notifier := newTestWorker(t, client) + + // Refresh inventory to pick up the running model. + if err := w.inventory.Refresh(context.Background()); err != nil { + t.Fatalf("Refresh: %v", err) + } + + // Create jobs: interleave two models, but qwen3:30b is currently resident. + // job1: qwen3:14b (not resident) + // job2: qwen3:30b (resident) + // job3: qwen3:14b (not resident) + // job4: qwen3:30b (resident) + jobs := []struct { + id string + model string + }{ + {"01DRAIN001", "qwen3:14b"}, + {"01DRAIN002", "qwen3:30b"}, + {"01DRAIN003", "qwen3:14b"}, + {"01DRAIN004", "qwen3:30b"}, + } + + for _, j := range jobs { + job := store.Job{ + ID: j.id, + Model: j.model, + Payload: json.RawMessage(fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"hi"}]}`, j.model)), + MaxAttempts: 3, + } + if _, err := st.CreateJob(job); err != nil { + t.Fatalf("CreateJob %s: %v", j.id, err) + } + } + + // Wait for last job. + waitCh := notifier.Register("01DRAIN004") + // Also register for the non-resident ones so we know when everything is done. 
+ waitCh3 := notifier.Register("01DRAIN003") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Run(ctx) + + // Wait for all jobs. + for _, ch := range []<-chan struct{}{waitCh, waitCh3} { + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for jobs to complete") + } + } + + mu.Lock() + defer mu.Unlock() + + // Drain-by-model: the resident model (qwen3:30b) jobs should execute first, + // then the non-resident model (qwen3:14b) jobs. + if len(executionOrder) != 4 { + t.Fatalf("executed %d jobs, want 4", len(executionOrder)) + } + + // First two should be qwen3:30b (the resident model). + if executionOrder[0] != "qwen3:30b" || executionOrder[1] != "qwen3:30b" { + t.Errorf("first two executions = %v, want [qwen3:30b, qwen3:30b]", executionOrder[:2]) + } + // Last two should be qwen3:14b. + if executionOrder[2] != "qwen3:14b" || executionOrder[3] != "qwen3:14b" { + t.Errorf("last two executions = %v, want [qwen3:14b, qwen3:14b]", executionOrder[2:]) + } +} + +func TestWorker_RetryOnConnectionError(t *testing.T) { + callCount := atomic.Int32{} + + client := &stubOllamaClient{ + chatFunc: func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + n := callCount.Add(1) + if n == 1 { + // First call fails with connection error. + return nil, nil, &ollama.ConnectionError{URL: "http://test", Err: fmt.Errorf("connection refused")} + } + // Second call succeeds. 
+ return &ollama.ChatResponse{Model: req.Model, Done: true, Message: &ollama.Message{Role: "assistant", Content: "ok"}}, nil, nil + }, + } + w, st, notifier := newTestWorker(t, client) + + job := store.Job{ + ID: "01RETRY001", + Model: "qwen3:30b", + Payload: json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}`), + MaxAttempts: 3, + } + if _, err := st.CreateJob(job); err != nil { + t.Fatalf("CreateJob: %v", err) + } + + waitCh := notifier.Register("01RETRY001") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Run(ctx) + + select { + case <-waitCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for job to complete") + } + + state, _, _, _ := notifier.Result("01RETRY001") + if state != store.JobStateDone { + t.Errorf("state = %q, want %q", state, store.JobStateDone) + } + + if got := callCount.Load(); got != 2 { + t.Errorf("chat calls = %d, want 2 (1 fail + 1 success)", got) + } + + // Verify attempt was incremented in the store. 
+ got, err := st.GetJob("01RETRY001") + if err != nil { + t.Fatalf("GetJob: %v", err) + } + if got.Attempt != 1 { + t.Errorf("attempt = %d, want 1 (incremented once from retry)", got.Attempt) + } +} + +func TestWorker_MaxAttemptsExhausted(t *testing.T) { + client := &stubOllamaClient{ + chatFunc: func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + return nil, nil, &ollama.ConnectionError{URL: "http://test", Err: fmt.Errorf("connection refused")} + }, + } + w, st, notifier := newTestWorker(t, client) + + job := store.Job{ + ID: "01MAXATT001", + Model: "qwen3:30b", + Payload: json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}`), + MaxAttempts: 2, + } + if _, err := st.CreateJob(job); err != nil { + t.Fatalf("CreateJob: %v", err) + } + + waitCh := notifier.Register("01MAXATT001") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Run(ctx) + + select { + case <-waitCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for job to fail") + } + + state, _, errMsg, ok := notifier.Result("01MAXATT001") + if !ok { + t.Fatal("no result available") + } + if state != store.JobStateFailed { + t.Errorf("state = %q, want %q", state, store.JobStateFailed) + } + if errMsg == nil { + t.Fatal("error message should be set") + } + + got, _ := st.GetJob("01MAXATT001") + if got.State != store.JobStateFailed { + t.Errorf("stored state = %q, want %q", got.State, store.JobStateFailed) + } +} + +func TestWorker_HTTPErrorIsTerminal(t *testing.T) { + client := &stubOllamaClient{ + chatFunc: func(ctx context.Context, req ollama.ChatRequest, stream bool) (*ollama.ChatResponse, <-chan ollama.ChatResponse, error) { + return nil, nil, &ollama.HTTPError{StatusCode: 400, Body: "bad request"} + }, + } + w, st, notifier := newTestWorker(t, client) + + job := store.Job{ + ID: "01HTTP001", + Model: "qwen3:30b", + Payload: 
json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}`), + MaxAttempts: 3, + } + if _, err := st.CreateJob(job); err != nil { + t.Fatalf("CreateJob: %v", err) + } + + waitCh := notifier.Register("01HTTP001") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Run(ctx) + + select { + case <-waitCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for job to fail") + } + + state, _, _, _ := notifier.Result("01HTTP001") + if state != store.JobStateFailed { + t.Errorf("state = %q, want %q (HTTP errors should be terminal)", state, store.JobStateFailed) + } + + // Verify only one attempt was made (no retries for HTTP errors). + if got := client.callCount.Load(); got != 1 { + t.Errorf("chat calls = %d, want 1 (HTTP errors should not retry)", got) + } +} + +func TestWorker_ResetInterruptedJobsOnStartup(t *testing.T) { + client := &stubOllamaClient{} + w, st, notifier := newTestWorker(t, client) + + // Manually create jobs in loading and working states (simulating a crash). + job1 := store.Job{ + ID: "01RESET001", + Model: "qwen3:30b", + Payload: json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}`), + MaxAttempts: 3, + } + if _, err := st.CreateJob(job1); err != nil { + t.Fatalf("CreateJob: %v", err) + } + if err := st.UpdateJobState("01RESET001", store.JobStateLoading, nil, nil); err != nil { + t.Fatalf("UpdateJobState: %v", err) + } + + job2 := store.Job{ + ID: "01RESET002", + Model: "qwen3:30b", + Payload: json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hello"}]}`), + MaxAttempts: 3, + } + if _, err := st.CreateJob(job2); err != nil { + t.Fatalf("CreateJob: %v", err) + } + if err := st.UpdateJobState("01RESET002", store.JobStateWorking, nil, nil); err != nil { + t.Fatalf("UpdateJobState: %v", err) + } + + // Register waiters. 
+ waitCh1 := notifier.Register("01RESET001") + waitCh2 := notifier.Register("01RESET002") + + // Start the worker — it should reset interrupted jobs and then process them. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Run(ctx) + + // Wait for both jobs to complete. + for _, ch := range []<-chan struct{}{waitCh1, waitCh2} { + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for reset jobs to complete") + } + } + + // Both should be done now. + for _, id := range []string{"01RESET001", "01RESET002"} { + got, err := st.GetJob(id) + if err != nil { + t.Fatalf("GetJob %s: %v", id, err) + } + if got.State != store.JobStateDone { + t.Errorf("job %s state = %q, want %q", id, got.State, store.JobStateDone) + } + } +} + +func TestNotifier_RegisterAndComplete(t *testing.T) { + n := NewNotifier() + + ch := n.Register("test-job") + + // Channel should not be closed yet. + select { + case <-ch: + t.Fatal("channel should not be closed before completion") + default: + } + + // Complete the job. + result := json.RawMessage(`{"done":true}`) + n.Complete("test-job", store.JobStateDone, result, nil) + + // Channel should be closed now. + select { + case <-ch: + // Expected. + default: + t.Fatal("channel should be closed after completion") + } + + // Get the result. + state, res, errMsg, ok := n.Result("test-job") + if !ok { + t.Fatal("result should be available") + } + if state != store.JobStateDone { + t.Errorf("state = %q, want %q", state, store.JobStateDone) + } + if string(res) != `{"done":true}` { + t.Errorf("result = %s, want %s", res, `{"done":true}`) + } + if errMsg != nil { + t.Errorf("unexpected error: %s", *errMsg) + } + + // Second call should return not-found (cleaned up). 
+ _, _, _, ok = n.Result("test-job") + if ok { + t.Error("result should be cleaned up after first retrieval") + } +} + +func TestNotifier_CompleteWithoutRegister(t *testing.T) { + n := NewNotifier() + + // Complete a job that nobody is waiting for. Should not panic. + n.Complete("orphan-job", store.JobStateDone, nil, nil) + + // Result should still be retrievable even without a registered waiter. + state, _, _, ok := n.Result("orphan-job") + if !ok { + t.Fatal("result should be available even without registered waiter") + } + if state != store.JobStateDone { + t.Errorf("state = %q, want %q", state, store.JobStateDone) + } +} + +func TestWorker_WakeSignal(t *testing.T) { + client := &stubOllamaClient{} + w, st, notifier := newTestWorker(t, client) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Run(ctx) + + // Give the worker time to start and block on the empty queue. + time.Sleep(50 * time.Millisecond) + + // Now add a job and wake the worker. + job := store.Job{ + ID: "01WAKE001", + Model: "qwen3:30b", + Payload: json.RawMessage(`{"model":"qwen3:30b","messages":[{"role":"user","content":"hi"}]}`), + MaxAttempts: 3, + } + if _, err := st.CreateJob(job); err != nil { + t.Fatalf("CreateJob: %v", err) + } + + waitCh := notifier.Register("01WAKE001") + w.Wake() + + select { + case <-waitCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out: worker did not process job after wake signal") + } + + state, _, _, _ := notifier.Result("01WAKE001") + if state != store.JobStateDone { + t.Errorf("state = %q, want %q", state, store.JobStateDone) + } +} + +func TestStore_NextJobDrainByModel(t *testing.T) { + st := openTestDB(t) + + // Create jobs interleaved. 
+ for _, j := range []struct { + id string + model string + }{ + {"01A", "modelA"}, + {"01B", "modelB"}, + {"01C", "modelA"}, + {"01D", "modelB"}, + } { + _, err := st.CreateJob(store.Job{ + ID: j.id, + Model: j.model, + Payload: json.RawMessage(`{}`), + }) + if err != nil { + t.Fatalf("CreateJob: %v", err) + } + } + + // With currentModel = modelB, we should get modelB jobs first. + j1, err := st.NextJob("modelB") + if err != nil { + t.Fatalf("NextJob: %v", err) + } + if j1.Model != "modelB" { + t.Errorf("first job model = %q, want modelB", j1.Model) + } + + // Mark it done and get next. + st.UpdateJobState(j1.ID, store.JobStateDone, nil, nil) + + j2, err := st.NextJob("modelB") + if err != nil { + t.Fatalf("NextJob: %v", err) + } + if j2.Model != "modelB" { + t.Errorf("second job model = %q, want modelB", j2.Model) + } + + // Mark done, now should get modelA. + st.UpdateJobState(j2.ID, store.JobStateDone, nil, nil) + + j3, err := st.NextJob("modelB") + if err != nil { + t.Fatalf("NextJob: %v", err) + } + if j3.Model != "modelA" { + t.Errorf("third job model = %q, want modelA", j3.Model) + } +} + +func TestStore_NextJobEmptyQueue(t *testing.T) { + st := openTestDB(t) + + _, err := st.NextJob("any") + if !errors.Is(err, sql.ErrNoRows) { + t.Errorf("NextJob on empty queue: err = %v, want sql.ErrNoRows", err) + } +} + +func TestStore_IncrementAttempt(t *testing.T) { + st := openTestDB(t) + + _, err := st.CreateJob(store.Job{ + ID: "01INC", + Model: "m", + Payload: json.RawMessage(`{}`), + }) + if err != nil { + t.Fatalf("CreateJob: %v", err) + } + + // Mark as working, then increment. 
+ st.UpdateJobState("01INC", store.JobStateWorking, nil, nil) + if err := st.IncrementAttempt("01INC"); err != nil { + t.Fatalf("IncrementAttempt: %v", err) + } + + got, _ := st.GetJob("01INC") + if got.Attempt != 1 { + t.Errorf("attempt = %d, want 1", got.Attempt) + } + if got.State != store.JobStateQueued { + t.Errorf("state = %q, want %q (should be re-queued)", got.State, store.JobStateQueued) + } +} + +func TestStore_ResetInterruptedJobs(t *testing.T) { + st := openTestDB(t) + + for _, j := range []struct { + id string + state store.JobState + }{ + {"01A", store.JobStateQueued}, + {"01B", store.JobStateLoading}, + {"01C", store.JobStateWorking}, + {"01D", store.JobStateDone}, + {"01E", store.JobStateFailed}, + } { + _, err := st.CreateJob(store.Job{ID: j.id, Model: "m", Payload: json.RawMessage(`{}`)}) + if err != nil { + t.Fatalf("CreateJob: %v", err) + } + if j.state != store.JobStateQueued { + st.UpdateJobState(j.id, j.state, nil, nil) + } + } + + n, err := st.ResetInterruptedJobs() + if err != nil { + t.Fatalf("ResetInterruptedJobs: %v", err) + } + if n != 2 { + t.Errorf("reset count = %d, want 2", n) + } + + // Verify loading and working are back to queued. + for _, id := range []string{"01B", "01C"} { + j, _ := st.GetJob(id) + if j.State != store.JobStateQueued { + t.Errorf("job %s state = %q, want %q", id, j.State, store.JobStateQueued) + } + } + + // Verify done and failed are untouched. + for _, tc := range []struct { + id string + want store.JobState + }{ + {"01D", store.JobStateDone}, + {"01E", store.JobStateFailed}, + } { + j, _ := st.GetJob(tc.id) + if j.State != tc.want { + t.Errorf("job %s state = %q, want %q", tc.id, j.State, tc.want) + } + } +} + +func TestStore_DeleteTerminalJobsBefore(t *testing.T) { + st := openTestDB(t) + + // Create some terminal jobs. 
+ for _, j := range []struct { + id string + state store.JobState + }{ + {"01OLD1", store.JobStateDone}, + {"01OLD2", store.JobStateFailed}, + {"01ACTIVE", store.JobStateQueued}, + } { + _, err := st.CreateJob(store.Job{ID: j.id, Model: "m", Payload: json.RawMessage(`{}`)}) + if err != nil { + t.Fatalf("CreateJob: %v", err) + } + if j.state != store.JobStateQueued { + errMsg := "some error" + var errPtr *string + if j.state == store.JobStateFailed { + errPtr = &errMsg + } + st.UpdateJobState(j.id, j.state, nil, errPtr) + } + } + + // Delete terminal jobs older than right now (all terminal jobs are "old"). + cutoff := time.Now().UTC().Add(1 * time.Minute) + n, err := st.DeleteTerminalJobsBefore(cutoff) + if err != nil { + t.Fatalf("DeleteTerminalJobsBefore: %v", err) + } + if n != 2 { + t.Errorf("deleted = %d, want 2", n) + } + + // Active job should still exist. + _, err = st.GetJob("01ACTIVE") + if err != nil { + t.Errorf("active job should still exist: %v", err) + } + + // Deleted jobs should be gone. + for _, id := range []string{"01OLD1", "01OLD2"} { + _, err := st.GetJob(id) + if !errors.Is(err, sql.ErrNoRows) { + t.Errorf("job %s should be deleted but got err: %v", id, err) + } + } +} diff --git a/progress.md b/progress.md index 9463dd7..2dd594b 100644 --- a/progress.md +++ b/progress.md @@ -66,3 +66,54 @@ The Mac is now usable as a go-llm target through foreman: `llm.OllamaCloud(token, WithBaseURL("http://foreman:8080"))` works transparently for chat (streaming + non-streaming), tags, ps, and embeddings. + +## Phase 3: Durable queue, single worker, drain-by-model — 2026-05-23 + +**M0 complete.** The Phase 2 in-flight chat gate (buffered channel) is replaced +with the real SQLite-backed job queue and single worker loop. + +- `internal/store/` — new store methods: + - `NextJob(currentModel)`: drain-by-model ordering — prefers jobs matching the + currently-resident model to minimize swap cost, then FIFO by created_at. 
+ - `IncrementAttempt(id)`: bumps the attempt counter and re-queues the job for retry. + - `ResetInterruptedJobs()`: resets loading/working jobs to queued on startup + (crash recovery). + - `DeleteTerminalJobsBefore(cutoff)`: deletes done/failed jobs older than `cutoff` (TTL pruning). + - SQLite DSN now includes `_pragma=busy_timeout(5000)` for reliable concurrent + access from the HTTP handlers and the worker. + +- `internal/worker/` — single worker loop (`worker.go`): + - `Worker.Run(ctx)`: main goroutine loop — resets interrupted jobs on startup, + then continuously picks the next job using drain-by-model ordering, executes it + via the Ollama client, stores the result and completion artifact, notifies waiters. + - `Worker.Wake()`: non-blocking signal for new job availability. + - `Notifier`: sync.Map-based completion notification — HTTP handlers register + a channel per job ID, and the worker closes it on completion. Supports + `Register()`, `Complete()`, `Result()`. + - Retry semantics: `*ollama.ConnectionError` causes re-queue with incremented + attempt; `*ollama.HTTPError` is a terminal failure (no retry). Max attempts + configurable via `FOREMAN_MAX_ATTEMPTS` (default 3). + - The worker loop never panics — all errors are logged, jobs are marked, and the loop + continues. + +- `internal/server/` — chat handler rewrite: + - `POST /api/chat` now creates a job row (state `queued`), registers a + completion waiter, wakes the worker, and blocks until the job reaches a + terminal state. Returns the Ollama response on success, 502 on failure. + - ULID job IDs generated at submission time (`github.com/oklog/ulid/v2`). + - The old `chatGate` (buffered channel) is removed entirely. + - `/api/embed` and `/api/embeddings` remain direct concurrent proxies (unchanged + from Phase 2, per ADR-0013).
+ +- `internal/config/` — new config fields: + - `FOREMAN_MAX_ATTEMPTS` (int, default 3) + - `FOREMAN_JOB_TTL` (duration, default 24h) + +- Tests (all passing with `-race`): + - Worker: single job execution, serial enforcement, drain-by-model ordering, + retry on connection error, max attempts exhaustion, HTTP error terminal + failure, interrupted job reset on startup, wake signal, notifier lifecycle. + - Store: NextJob drain-by-model, empty queue, IncrementAttempt, + ResetInterruptedJobs, DeleteTerminalJobsBefore. + - Server: chat model validation (404), non-streaming chat through queue, + serialization (max 1 concurrent), context cancellation, embed bypass unchanged.