diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml new file mode 100644 index 0000000..6179aa2 --- /dev/null +++ b/.gitea/workflows/ci.yml @@ -0,0 +1,82 @@ +name: executus CI + +# Go library CI: build, vet, race-tested, tidy-clean, plus the executus +# invariant that the CORE module never pulls a host/DB dependency. Mirrors +# majordomo's gates; private-module access (the private majordomo dependency) +# uses the same Gitea credentials gadfly's CI uses. +# +# Required repo secrets: +# REGISTRY_USER / REGISTRY_PASSWORD Gitea creds with read access to the +# private majordomo module. + +on: + push: + branches: [main] + tags: ["v*"] + pull_request: + types: [opened, synchronize, reopened] + workflow_dispatch: {} + +concurrency: + group: executus-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - name: Checkout + run: | + REPO_URL="https://token:${{ github.token }}@gitea.stevedudenhoeffer.com/${{ github.repository }}.git" + if [ "${{ github.event_name }}" = "pull_request" ]; then + git clone --depth=1 --branch "${{ github.head_ref }}" "$REPO_URL" . + else + git clone --depth=1 --branch "${{ github.ref_name }}" "$REPO_URL" . 
+ fi + + - name: Set up Go + run: | + GO_VERSION=$(grep '^go ' go.mod | awk '{print $2}') + curl -sL "https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz" | tar -C /usr/local -xzf - + echo "/usr/local/go/bin" >> $GITHUB_PATH + echo "GOPATH=${HOME}/go" >> $GITHUB_ENV + echo "${HOME}/go/bin" >> $GITHUB_PATH + + - name: Configure private module access + env: + REGISTRY_USER: ${{ secrets.REGISTRY_USER }} + REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} + run: | + git config --global url."https://${REGISTRY_USER}:${REGISTRY_PASSWORD}@gitea.stevedudenhoeffer.com/".insteadOf "https://gitea.stevedudenhoeffer.com/" + echo "GOFLAGS=-mod=mod" >> $GITHUB_ENV + echo "GONOSUMCHECK=gitea.stevedudenhoeffer.com/*" >> $GITHUB_ENV + echo "GONOSUMDB=gitea.stevedudenhoeffer.com/*" >> $GITHUB_ENV + echo "GOPRIVATE=gitea.stevedudenhoeffer.com/*" >> $GITHUB_ENV + + - name: Build + run: go build ./... + + - name: Vet + run: go vet ./... + + - name: Test (race) + run: go test -race -count=1 -timeout 5m ./... + + - name: go mod tidy is clean + run: | + go mod tidy + git diff --exit-code -- go.mod go.sum + + - name: Core stays majordomo+stdlib only + run: | + # The core module must never pull a host/DB dependency. If any of these + # appear in go.sum, a battery leaked into the core import graph. + [ -f go.sum ] || { echo "OK: no external dependencies yet."; exit 0; } + FORBIDDEN='gorm.io|go-redis|redis/go-redis|bwmarrin/discordgo|modernc.org/sqlite|mattn/go-sqlite3|gin-gonic/gin' + if grep -qE "$FORBIDDEN" go.sum; then + echo "ERROR: forbidden dependency in core go.sum:" + grep -E "$FORBIDDEN" go.sum + exit 1 + fi + echo "OK: core go.sum is free of host/DB dependencies." 
diff --git a/.gitignore b/.gitignore index 5b90e79..fb8687f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,27 +1,20 @@ -# ---> Go -# If you prefer the allow list template instead of the deny list, see community template: -# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore -# -# Binaries for programs and plugins +# Binaries *.exe -*.exe~ -*.dll -*.so -*.dylib - -# Test binary, built with `go test -c` *.test - -# Output of the go coverage tool, specifically when used with LiteIDE *.out +/bin/ +/dist/ -# Dependency directories (remove the comment below to include it) -# vendor/ - -# Go workspace file +# Go workspace / local overrides go.work go.work.sum -# env file +# Env / secrets (never commit) .env +.env.* +!.env.example +# Editor / OS +.idea/ +.vscode/ +.DS_Store diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..be5d25b --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,99 @@ +# executus — developer & agent guide + +> ⚠️ **This project is vibe-coded** (AI-authored, human-steered). See `README.md`. + +executus is a **batteries-included base for LLM agent harnesses**, layered +strictly above [majordomo]. majordomo is the lean substrate (agent loop, `llm` +types, providers, media, parse/failover/tiering). executus is the opinionated +layer majordomo deliberately omits. **executus requires no majordomo changes** — +it decorates `llm.Model` and wraps `majordomo/agent.Agent`. + +[majordomo]: https://gitea.stevedudenhoeffer.com/steve/majordomo + +## North star + +A brand-new project imports executus, does a little setup, and is most of the way +to agentic capabilities. The mechanism is **one shipped default per seam**: +`executus.New()` (once the runtime lands) is agentic with zero host wiring; the +same builder lets a serious host swap each default for its own implementation and +register its own tools. 
+ +Two consumers define the envelope: + +- **mort** (heavy) — Discord, mortbux, media, MySQL/GORM, DB-backed convar config, + saved skills, audit, scheduling, run-critic. +- **gadfly** (light) — a CI PR-reviewer Docker image, env-var configured, running + an N-models × M-lenses structured-output swarm. Needs model fleet, lanes, + bounded runs, structured output, fan-out, a few read tools — and **none** of the + batteries. + +That spread is why executus is **tiered**: a light host imports core only; a heavy +host opts into batteries. + +## Module & layering + +One module `gitea.stevedudenhoeffer.com/steve/executus`, `go.mod` = **majordomo + +stdlib only** (no gorm/redis/discordgo/cgo). A second nested module +`contrib/store` carries the SQLite dependency so the core never inherits it. + +``` +CORE (majordomo + stdlib): + config/ ConfigSource seam (+ env default) [P0 ✓] + lane/ bounded fair-share worker pool [P0 ✓] + fanout/ programmatic N×M swarm [P0 ✓] + deliver/ output egress seam (+ Discard/Stdout) [P0 ✓] + identity/ caller identity seams [P0 ✓] + run/ progress bridge now; the executor kernel + [P0 partial] + nil-safe Ports + RunnableAgent later [P2] + dispatchguard/ loop/depth/fan-out caps [P0 ✓] + pendingattach/ attachment dedupe [P0 ✓] + tool/ registry + 3-stage permissions + ssrf/llmmeta [P1] + model/ config-driven tier resolution over majordomo [P1] + compact/ context compactor (WithCompactor hook) [P2] + tools/{web,net,store,compose,meta,comms} generic tools [P3] + structured/ Generate[T] convenience over majordomo [P1] + +BATTERIES (opt-in siblings, each nil-safe + a default): + persona/ Agent noun + AgentStore seam + yml loader [P4] + skill/ rich Skill + SkillStore seam + toml loader [P4] + audit/ run-trace Sink (+ Noop/Slog) [P4] + critic/ two-tier timeout state machine + Escalator [P4] + schedule/ cron runner cores [P4] + checkpoint/ durable resume seam [P4] + budget/ rolling-window tracker (+ NoOp) [P4] + +contrib/store/ SECOND module (+ 
modernc.org/sqlite): [P4] + in-memory + pure-Go SQLite impls of every *Store seam +``` + +### The one architectural move + +The kernel must import **no battery**. In mort today, `agentexec` imports +`agents`, `agentcritic`, and `skillaudit` directly — those three up-pointing edges +get inverted into nil-safe `run.Ports` interfaces (`PaletteSource`, `Critic`, +`Audit`) plus a `RunnableAgent` DTO. Everything else is wide-but-shallow +repackaging. + +## Invariants (enforced in CI) + +- The core module builds with **majordomo + stdlib only**. `go.sum` must not + contain gorm/redis/discordgo/sqlite/gin. +- No `core/*` package imports a `battery/*` package. +- Standard Go gates: `go build`, `go vet`, `go test -race`, `go mod tidy` clean. + +## Extraction roadmap + +P0 module + zero-coupling moves + core seams (this) → P1 tool registry + model → +P2 run kernel + Ports inversion → P3 generic tools + defaults → P4 persona/skill +redesign + batteries + SQLite store → P5 gadfly on core (light-tier canary) → P6 +rewire mort + tag v0.1.0. The mort-side rewrite reuses mort's existing +`mort_*_adapters.go` wall as the host adapter layer. + +## Conventions + +- Mirror majordomo's house style: gofmt; check errors immediately and wrap with + `fmt.Errorf("...: %w", err)`; `// Why:` comments where rationale isn't obvious; + hermetic tests (majordomo's fake provider; no network in the default suite). +- Every seam is an interface with a nil-safe accessor and a shipped default. +- Keep the core seam surface small and stable — push churn into tools and host + adapters, not core interfaces. diff --git a/README.md b/README.md index 04e36a3..367d652 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,61 @@ # executus +> ⚠️ **This project is vibe-coded.** +> executus is written almost entirely by an AI coding agent (Claude), with a +> human steering at the design and review level rather than typing the code. +> That's a deliberate choice, stated up front — the same way [gadfly] is. 
Read +> the code before you depend on it, pin a version, and file issues if something +> looks off. It is offered as-is. + +[gadfly]: https://gitea.stevedudenhoeffer.com/steve/gadfly + +A **batteries-included base for building LLM agent harnesses in Go.** Import it, +do a little wiring, and you have agentic capabilities: a bounded run loop, a tool +registry with a suite of common tools, context compaction, config-driven model +tiering and failover, structured output, and parallel fan-out — with sensible +defaults so a brand-new project is agentic with almost no setup, and pluggable +seams so a serious host can swap in its own storage, config, delivery, and tools. + +executus sits **strictly above** [majordomo] — the lean LLM substrate (agent +loop, canonical `llm` types, providers, media normalization, model parsing / +failover / tiering). majordomo stays the substrate; executus is the opinionated, +batteries-included layer on top. executus requires **no changes to majordomo**. + +[majordomo]: https://gitea.stevedudenhoeffer.com/steve/majordomo + +## Status + +Early. Being extracted, phase by phase, from the agent layer of [mort] (a Discord +bot) — mort and gadfly are the first two consumers (heavy and light). See +`CLAUDE.md` for the architecture and the extraction roadmap (P0–P6). + +[mort]: https://gitea.stevedudenhoeffer.com/steve/mort + +**Available today (P0):** + +- `lane/` — bounded worker pool with fair-share queueing (run- and + provider-concurrency). +- `fanout/` — programmatic N×M swarm with bounded global + per-key concurrency. +- `config/` — the host config seam (`Source`) with an env-var default. +- `deliver/` — the output-egress seam with `Discard`/`Stdout` defaults. +- `identity/` — caller-identity seams (`AdminPolicy`, `MemberResolver`). +- `dispatchguard/`, `pendingattach/`, `run/progress.go` — run-safety primitives. 
+ +## Design + +Two tiers in one module (`go.mod` = majordomo + stdlib only): + +- **Core** — everything a light host needs to be agentic: run loop, tool + registry + common tools, model resolution, compaction, lanes, fan-out, + structured output. No persistence, no scheduling. +- **Batteries** (opt-in sibling packages) — persona/agent nouns, saved skills, + audit, run-critic, scheduling, budgets, checkpointing. Each is nil-safe and + ships a default, so you add only what you use. + +Persistence that needs a real database lives in a **separate** nested module +(`contrib/store`, pure-Go SQLite) so the core never drags in a DB driver — a +static-binary host (gadfly) stays static. + +## License + +TBD. diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..5c3c020 --- /dev/null +++ b/config/config.go @@ -0,0 +1,137 @@ +// Package config is executus's runtime-configuration seam. +// +// A host supplies a Source so the harness can read tunable knobs (model tiers, +// caps, thresholds, lane widths) without depending on any particular config +// backend. Mort adapts its DB-backed convar.Manager; Gadfly adapts environment +// variables; a brand-new project can use Env (or pass a nil Source and rely on +// the code defaults every reader provides). +// +// Design rules: +// - Every accessor takes a code default. A Source is NEVER required to know a +// key — readers degrade to the default, so the harness runs with zero config. +// - Reads are LIVE: callers read on every use so a host whose backend mutates +// at runtime (e.g. convar) propagates without a restart. Sources that cache +// (mort's 5-minute convar cache) may additionally implement Reloader to +// signal invalidation. +package config + +import ( + "os" + "strconv" + "strings" +) + +// Source is the host configuration seam. All methods take a default and must be +// safe for concurrent use. 
+type Source interface { + String(key, def string) string + Int(key string, def int) int + Float(key string, def float64) float64 + Bool(key string, def bool) bool +} + +// Reloader is an optional capability for Sources whose values can change at +// runtime and that can notify watchers (e.g. a tier-reload or cache +// invalidation). Sources that do not implement it are simply read live on every +// access. Watch returns a cancel func; a nil-safe no-op is acceptable. +type Reloader interface { + Watch(prefix string, fn func(key string)) (cancel func()) +} + +// Nil-safe package helpers: callers that may hold a nil Source use these instead +// of dereferencing. They let every battery treat config as optional. + +func String(s Source, key, def string) string { + if s == nil { + return def + } + return s.String(key, def) +} + +func Int(s Source, key string, def int) int { + if s == nil { + return def + } + return s.Int(key, def) +} + +func Float(s Source, key string, def float64) float64 { + if s == nil { + return def + } + return s.Float(key, def) +} + +func Bool(s Source, key string, def bool) bool { + if s == nil { + return def + } + return s.Bool(key, def) +} + +// Env is the default Source: it reads process environment variables. A key is +// mapped to an env var name by uppercasing it and replacing every rune outside +// [A-Za-z0-9] with '_', then prefixing. So Env("GADFLY_").String("models", "") +// reads GADFLY_MODELS, and Env("").Int("model.tier.fast.max_steps", 8) reads +// MODEL_TIER_FAST_MAX_STEPS. An unset or unparseable value yields the default. 
+func Env(prefix string) Source { return envSource{prefix: prefix} } + +type envSource struct{ prefix string } + +func (e envSource) envName(key string) string { + var b strings.Builder + b.WriteString(e.prefix) + for _, r := range key { + switch { + case r >= 'a' && r <= 'z': + b.WriteRune(r - 32) + case (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9'): + b.WriteRune(r) + default: + b.WriteByte('_') + } + } + return b.String() +} + +func (e envSource) raw(key string) (string, bool) { + v, ok := os.LookupEnv(e.envName(key)) + if !ok { + return "", false + } + return strings.TrimSpace(v), true +} + +func (e envSource) String(key, def string) string { + if v, ok := e.raw(key); ok && v != "" { + return v + } + return def +} + +func (e envSource) Int(key string, def int) int { + if v, ok := e.raw(key); ok { + if n, err := strconv.Atoi(v); err == nil { + return n + } + } + return def +} + +func (e envSource) Float(key string, def float64) float64 { + if v, ok := e.raw(key); ok { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f + } + } + return def +} + +func (e envSource) Bool(key string, def bool) bool { + if v, ok := e.raw(key); ok { + if b, err := strconv.ParseBool(v); err == nil { + return b + } + } + return def +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..df23aa8 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,56 @@ +package config + +import "testing" + +func TestEnvNameMapping(t *testing.T) { + e := envSource{prefix: "GADFLY_"} + cases := map[string]string{ + "models": "GADFLY_MODELS", + "model.tier.fast.max_steps": "GADFLY_MODEL_TIER_FAST_MAX_STEPS", + "provider-concurrency": "GADFLY_PROVIDER_CONCURRENCY", + "a/b.c": "GADFLY_A_B_C", + } + for key, want := range cases { + if got := e.envName(key); got != want { + t.Errorf("envName(%q) = %q, want %q", key, got, want) + } + } +} + +func TestEnvReadsAndDefaults(t *testing.T) { + t.Setenv("EX_MODELS", "a,b,c") + t.Setenv("EX_MAX", "12") + 
t.Setenv("EX_RATIO", "0.7") + t.Setenv("EX_ON", "true") + t.Setenv("EX_BLANK", "") + s := Env("EX_") + + if got := s.String("models", "def"); got != "a,b,c" { + t.Errorf("String = %q", got) + } + if got := s.String("blank", "def"); got != "def" { + t.Errorf("blank String should fall back to default, got %q", got) + } + if got := s.String("missing", "def"); got != "def" { + t.Errorf("missing String = %q", got) + } + if got := s.Int("max", 1); got != 12 { + t.Errorf("Int = %d", got) + } + if got := s.Int("models", 1); got != 1 { + t.Errorf("unparseable Int should default, got %d", got) + } + if got := s.Float("ratio", 1); got != 0.7 { + t.Errorf("Float = %v", got) + } + if got := s.Bool("on", false); got != true { + t.Errorf("Bool = %v", got) + } +} + +func TestNilSafeHelpers(t *testing.T) { + var s Source // nil + if String(s, "k", "d") != "d" || Int(s, "k", 7) != 7 || Float(s, "k", 1.5) != 1.5 || !Bool(s, "k", true) { + t.Fatal("nil Source helpers must return defaults") + } +} diff --git a/deliver/deliver.go b/deliver/deliver.go new file mode 100644 index 0000000..65192af --- /dev/null +++ b/deliver/deliver.go @@ -0,0 +1,78 @@ +// Package deliver is executus's output-egress seam. +// +// Where a run's final output and any generated artifacts go is host-specific: +// Mort posts a Discord embed (with a paste fallback and state-react emoji), +// Gadfly consolidates findings into one PR comment, a CLI host prints to stdout. +// The harness depends only on the Delivery interface and ships two defaults so a +// new host needs no wiring: Discard (return output to the caller only) and +// Stdout. +package deliver + +import ( + "context" + "fmt" + "io" + "os" +) + +// Target names where output should land. Its fields are host-interpreted (a +// Discord channel ID, a PR number, etc.); the harness never parses them. +type Target struct { + Kind string // host-defined: "channel", "dm", "thread", "stdout", "comment", ... 
+ ID string +} + +// Artifact is a generated file accompanying a run's output (an image, a report, +// a STL, ...). Bytes are owned by the caller; a Delivery must not retain them +// past the call without copying. +type Artifact struct { + Name string + MIME string + Bytes []byte +} + +// Delivery is the output seam. Deliver returns a host-defined id for the posted +// output when one exists (a message ID, a paste URL); an empty id is fine. +// Implementations must be nil-safe for a nil/empty artifacts slice. +type Delivery interface { + Deliver(ctx context.Context, t Target, output string, artifacts []Artifact) (id string, err error) + DeliverError(ctx context.Context, t Target, runErr error) error +} + +// Discard is the light-host default: it drops the output (the caller already has +// it as the run Result). Gadfly uses this — it reads results in-process and does +// its own consolidation. +type Discard struct{} + +func (Discard) Deliver(context.Context, Target, string, []Artifact) (string, error) { + return "", nil +} +func (Discard) DeliverError(context.Context, Target, error) error { return nil } + +// Stdout writes output (and an artifact manifest) to an io.Writer, defaulting to +// os.Stdout. Handy for local/dev and example hosts. 
+type Stdout struct{ W io.Writer } + +func (s Stdout) w() io.Writer { + if s.W != nil { + return s.W + } + return os.Stdout +} + +func (s Stdout) Deliver(_ context.Context, t Target, output string, artifacts []Artifact) (string, error) { + w := s.w() + if t.Kind != "" || t.ID != "" { + fmt.Fprintf(w, "[%s:%s]\n", t.Kind, t.ID) + } + fmt.Fprintln(w, output) + for _, a := range artifacts { + fmt.Fprintf(w, "  artifact: %s (%s, %d bytes)\n", a.Name, a.MIME, len(a.Bytes)) + } + return "", nil +} + +func (s Stdout) DeliverError(_ context.Context, _ Target, runErr error) error { + fmt.Fprintf(s.w(), "ERROR: %v\n", runErr) + return nil +} diff --git a/dispatchguard/dispatchguard.go b/dispatchguard/dispatchguard.go new file mode 100644 index 0000000..7828d14 --- /dev/null +++ b/dispatchguard/dispatchguard.go @@ -0,0 +1,159 @@ +// Package dispatchguard is the single chokepoint that bounds agent/skill +// composition: it stops a run from invoking one of its own ancestors +// (loop), from nesting past a depth cap, and from spawning more than a +// budget of descendant runs under one root. +// +// Why a standalone package: "invoke a skill/agent from inside a run" has +// three dispatch surfaces — the agent_invoke/skill_invoke TOOLS, the +// palette skill__/agent__ wrappers, and the agent-as-chatbot tool. The +// guards historically lived only in the TOOLS, so the other two paths +// recursed unbounded (the 2026-06-09 general-agent self-recursion +// incident was exactly this) and the DB-walk guard the tools used failed +// OPEN when the audit store was nil or a run's parent_run_id was empty +// (the chatbot-tool path produces parentless runs). Every dispatch path +// ultimately funnels through Executor.Run -> runInner, so enforcing the +// guard there — against an in-memory ancestor chain carried on the +// context — covers all three paths at once, synchronously, with no +// dependency on the audit table. 
+// +// The chain + descendant budget ride on context.Context, so propagation +// is automatic: a run stamps itself onto the ctx it hands to its agent +// loop, every tool handler inherits that ctx, and any sub-invocation's +// Executor.Run receives it — so the child sees its full ancestry without +// anyone threading it explicitly. The budget counter is a shared pointer +// created at the root and seen by every descendant; it is garbage +// collected with the context, so there is no global map to clean up. +// +// This package deliberately imports nothing from skillexec/agentexec/ +// agents (only the standard library) so it can be called from both +// executors — and the future merged engine — without an import cycle. +package dispatchguard + +import ( + "context" + "fmt" + "slices" + "sync/atomic" +) + +// Default limits, used when Enter is called with a non-positive value. +// Chosen to allow real fan-out (general -> researcher -> sub-task) while +// still capping a runaway recursion or fan-out tree well before it can +// exhaust a lane or the model budget. +const ( + DefaultMaxDepth = 5 + DefaultMaxDescendant = 64 +) + +// AncestorRef identifies one run in the current dispatch chain. +type AncestorRef struct { + Kind string // "agent" | "skill" + ID string // the noun's stable UUID + RunID string // this run's audit id (for diagnostics) +} + +type chainKeyT struct{} +type budgetKeyT struct{} + +var chainKey chainKeyT +var budgetKey budgetKeyT + +type descendantBudget struct { + count atomic.Int64 + cap int64 +} + +// Chain returns the ancestor refs carried on ctx, oldest-first (the root +// run is index 0). Never nil-panics; returns nil when ctx carries none. +func Chain(ctx context.Context) []AncestorRef { + if v, ok := ctx.Value(chainKey).([]AncestorRef); ok { + return v + } + return nil +} + +// Rejection describes why a run must not proceed. 
It is intentionally a +// soft outcome: the caller records an audit row and returns the Message +// as the run's output so a delegating parent agent sees a clear, +// actionable refusal rather than a hard error. +type Rejection struct { + // Kind is one of "loop", "depth", "budget" — used for the audit + // status (status_for) and for tests. + Kind string + Detail string +} + +// Status maps a rejection to the audit row status. +func (r *Rejection) Status() string { return "rejected_" + r.Kind } + +// Message is the human/LLM-facing refusal text returned as the run's +// output. +func (r *Rejection) Message() string { + return "⚠️ delegation refused (" + r.Kind + "): " + r.Detail + + ". Synthesize an answer from what you already have instead of re-delegating." +} + +// Enter is called once at the top of every run, BEFORE the agent loop +// starts. It checks the loop / depth / descendant-budget guards against +// the ancestor chain on ctx and returns: +// +// - a child context with `ref` appended to the chain (and, at the root, +// a fresh descendant budget) — use this as the base context for the +// run's agent loop so sub-invocations inherit the ancestry; and +// - a non-nil *Rejection when the run must NOT proceed (in which case +// the returned context equals the input and should be ignored). +// +// maxDepth / maxDescendant <= 0 fall back to the package defaults. +func Enter(ctx context.Context, ref AncestorRef, maxDepth, maxDescendant int) (context.Context, *Rejection) { + if maxDepth <= 0 { + maxDepth = DefaultMaxDepth + } + if maxDescendant <= 0 { + maxDescendant = DefaultMaxDescendant + } + + chain := Chain(ctx) + + // 1) Loop: refuse to invoke a noun already executing in this chain. 
+ for _, a := range chain { + if a.Kind == ref.Kind && a.ID == ref.ID { + return ctx, &Rejection{ + Kind: "loop", + Detail: fmt.Sprintf("%s %q is already running higher in this dispatch chain (depth %d)", + ref.Kind, ref.ID, len(chain)), + } + } + } + + // 2) Depth: refuse to nest past the cap. + if len(chain) >= maxDepth { + return ctx, &Rejection{ + Kind: "depth", + Detail: fmt.Sprintf("dispatch chain depth %d reached the cap of %d", len(chain), maxDepth), + } + } + + // 3) Descendant budget: only descendants (non-root) count against the + // per-root budget. The root creates the shared counter below. + if len(chain) > 0 { + if b, ok := ctx.Value(budgetKey).(*descendantBudget); ok && b != nil && b.cap > 0 { + if b.count.Add(1) > b.cap { + return ctx, &Rejection{ + Kind: "budget", + Detail: fmt.Sprintf("this run tree already spawned its budget of %d descendant runs", b.cap), + } + } + } + } + + // Stamp the child context. slices.Clone keeps each branch's chain + // independent so sibling sub-invocations don't see each other. + newChain := append(slices.Clone(chain), ref) + out := context.WithValue(ctx, chainKey, newChain) + if len(chain) == 0 { + // Root: install the shared descendant budget every descendant + // will increment. 
+ out = context.WithValue(out, budgetKey, &descendantBudget{cap: int64(maxDescendant)}) + } + return out, nil +} diff --git a/dispatchguard/dispatchguard_test.go b/dispatchguard/dispatchguard_test.go new file mode 100644 index 0000000..cd33246 --- /dev/null +++ b/dispatchguard/dispatchguard_test.go @@ -0,0 +1,110 @@ +package dispatchguard + +import ( + "context" + "testing" +) + +func TestEnter_RootHasEmptyChainAndStampsItself(t *testing.T) { + ctx := context.Background() + if got := Chain(ctx); got != nil { + t.Fatalf("root chain should be nil, got %v", got) + } + out, rej := Enter(ctx, AncestorRef{Kind: "agent", ID: "A", RunID: "r1"}, 5, 64) + if rej != nil { + t.Fatalf("root run must not be rejected: %+v", rej) + } + chain := Chain(out) + if len(chain) != 1 || chain[0].ID != "A" { + t.Fatalf("child ctx chain = %v, want [A]", chain) + } +} + +func TestEnter_DirectSelfInvocationIsLoop(t *testing.T) { + ctx := context.Background() + ctx, rej := Enter(ctx, AncestorRef{Kind: "agent", ID: "A", RunID: "r1"}, 5, 64) + if rej != nil { + t.Fatal("first enter should pass") + } + _, rej = Enter(ctx, AncestorRef{Kind: "agent", ID: "A", RunID: "r2"}, 5, 64) + if rej == nil || rej.Kind != "loop" { + t.Fatalf("re-entering A should be a loop rejection, got %+v", rej) + } + if rej.Status() != "rejected_loop" { + t.Fatalf("status = %q", rej.Status()) + } +} + +func TestEnter_IndirectLoopAcrossNouns(t *testing.T) { + // A -> B -> A must be caught even though B is a different noun, and + // even though no parent_run_id DB row exists (this is the chatbot-tool + // parentless-run hole the in-memory chain closes). 
+ ctx := context.Background() + ctx, _ = Enter(ctx, AncestorRef{Kind: "agent", ID: "A", RunID: "r1"}, 5, 64) + ctx, _ = Enter(ctx, AncestorRef{Kind: "skill", ID: "B", RunID: "r2"}, 5, 64) + _, rej := Enter(ctx, AncestorRef{Kind: "agent", ID: "A", RunID: "r3"}, 5, 64) + if rej == nil || rej.Kind != "loop" { + t.Fatalf("A->B->A should be a loop, got %+v", rej) + } +} + +func TestEnter_DifferentNounSameIDIsNotLoop(t *testing.T) { + // A skill and an agent can legitimately share an ID space; the guard + // keys on (Kind, ID), so skill "X" inside agent "X" is allowed. + ctx := context.Background() + ctx, _ = Enter(ctx, AncestorRef{Kind: "agent", ID: "X", RunID: "r1"}, 5, 64) + _, rej := Enter(ctx, AncestorRef{Kind: "skill", ID: "X", RunID: "r2"}, 5, 64) + if rej != nil { + t.Fatalf("skill X under agent X should be allowed, got %+v", rej) + } +} + +func TestEnter_DepthCap(t *testing.T) { + ctx := context.Background() + // maxDepth=3: chains of length 0,1,2 may enter; length 3 is rejected. + var rej *Rejection + for i, id := range []string{"A", "B", "C", "D"} { + ctx, rej = Enter(ctx, AncestorRef{Kind: "agent", ID: id}, 3, 64) + if i < 3 && rej != nil { + t.Fatalf("enter %d (%s) should pass, got %+v", i, id, rej) + } + if i == 3 { + if rej == nil || rej.Kind != "depth" { + t.Fatalf("4th enter at maxDepth=3 should be depth rejection, got %+v", rej) + } + } + } +} + +func TestEnter_DescendantBudgetSharedAcrossTree(t *testing.T) { + // Root installs a budget of 2 descendants. The root itself doesn't + // count; the first two children pass, the third is rejected. 
+ root := context.Background() + root, rej := Enter(root, AncestorRef{Kind: "agent", ID: "root", RunID: "r0"}, 5, 2) + if rej != nil { + t.Fatal("root must pass") + } + _, rej = Enter(root, AncestorRef{Kind: "agent", ID: "c1"}, 5, 2) + if rej != nil { + t.Fatalf("child 1 should pass, got %+v", rej) + } + _, rej = Enter(root, AncestorRef{Kind: "agent", ID: "c2"}, 5, 2) + if rej != nil { + t.Fatalf("child 2 should pass, got %+v", rej) + } + _, rej = Enter(root, AncestorRef{Kind: "agent", ID: "c3"}, 5, 2) + if rej == nil || rej.Kind != "budget" { + t.Fatalf("child 3 should exhaust the descendant budget, got %+v", rej) + } +} + +func TestEnter_SiblingChainsAreIndependent(t *testing.T) { + // Appending to a parent chain must not mutate a sibling's slice. + root := context.Background() + root, _ = Enter(root, AncestorRef{Kind: "agent", ID: "root"}, 5, 64) + branch1, _ := Enter(root, AncestorRef{Kind: "agent", ID: "b1"}, 5, 64) + branch2, _ := Enter(root, AncestorRef{Kind: "agent", ID: "b2"}, 5, 64) + if c1, c2 := Chain(branch1), Chain(branch2); c1[len(c1)-1].ID == c2[len(c2)-1].ID { + t.Fatalf("sibling branches share a tail: %v / %v", c1, c2) + } +} diff --git a/examples/minimal/main.go b/examples/minimal/main.go new file mode 100644 index 0000000..fdc8976 --- /dev/null +++ b/examples/minimal/main.go @@ -0,0 +1,27 @@ +// Command minimal demonstrates executus's standalone core primitives available +// today (P0): the config seam + bounded fan-out. The full zero-config "agentic +// in ~12 lines" example arrives once the model, tool, and run packages land +// (P1–P3). +package main + +import ( + "context" + "fmt" + + "gitea.stevedudenhoeffer.com/steve/executus/config" + "gitea.stevedudenhoeffer.com/steve/executus/fanout" +) + +func main() { + cfg := config.Env("EXECUTUS_") // e.g. 
EXECUTUS_FANOUT_MAX_CONCURRENT=8 + max := cfg.Int("fanout.max_concurrent", 4) + + items := []string{"alpha", "beta", "gamma", "delta"} + results := fanout.Run(context.Background(), items, + fanout.Options[string]{MaxConcurrent: max}, + func(_ context.Context, s string) (int, error) { return len(s), nil }) + + for _, r := range results { + fmt.Printf("%-6s -> %d (err=%v)\n", items[r.Index], r.Value, r.Err) + } +} diff --git a/fanout/fanout.go b/fanout/fanout.go new file mode 100644 index 0000000..87a19c7 --- /dev/null +++ b/fanout/fanout.go @@ -0,0 +1,124 @@ +// Package fanout is executus's programmatic swarm primitive: run a function over +// many items concurrently with bounded global and per-key concurrency, returning +// one result per item in input order. +// +// This is distinct from the LLM-callable agent_spawn_parallel tool. fanout is a +// plain Go API a host drives directly — it is what Gadfly uses to run an +// N-models × M-lenses review fleet (flatten the matrix into items, key each by +// its provider, cap per-provider concurrency) and what any host uses to scatter +// bounded agent runs and gather structured results for consolidation. +// +// fanout has no dependency beyond the stdlib; a caller wires per-provider caps +// from config (Mort: convar; Gadfly: GADFLY_PROVIDER_CONCURRENCY). +package fanout + +import ( + "context" + "sync" +) + +// Result pairs a task's output with its error and original index. fn errors are +// captured here, not propagated — one failing task never aborts the batch. +type Result[T any] struct { + Index int + Value T + Err error +} + +// Options bound a fan-out. +// +// MaxConcurrent — cap on total in-flight tasks (0 = unbounded). +// PerKey — cap on in-flight tasks sharing a key bucket; a key absent +// from the map (or mapped to <=0) is uncapped beyond +// MaxConcurrent. Used for per-provider concurrency. +// Key — maps an item to its bucket; nil means all items are unkeyed. 
+type Options[A any] struct { + MaxConcurrent int + PerKey map[string]int + Key func(A) string +} + +// Run executes fn over items concurrently under opts and returns one Result per +// item, in input order. Context cancellation stops un-started tasks (their +// Result carries ctx.Err()); already-running tasks observe ctx through fn. +func Run[A any, T any](ctx context.Context, items []A, opts Options[A], fn func(ctx context.Context, item A) (T, error)) []Result[T] { + results := make([]Result[T], len(items)) + + var global chan struct{} + if opts.MaxConcurrent > 0 { + global = make(chan struct{}, opts.MaxConcurrent) + } + // Build per-key semaphores up front; the map is read-only during the run so + // concurrent reads are safe. + keySems := make(map[string]chan struct{}, len(opts.PerKey)) + for k, n := range opts.PerKey { + if n > 0 { + keySems[k] = make(chan struct{}, n) + } + } + + var wg sync.WaitGroup + for i, it := range items { + wg.Add(1) + go func(i int, it A) { + defer wg.Done() + results[i].Index = i + + if err := ctx.Err(); err != nil { + results[i].Err = err + return + } + + // Acquire global then key (consistent order avoids deadlock). + if global != nil { + select { + case global <- struct{}{}: + defer func() { <-global }() + case <-ctx.Done(): + results[i].Err = ctx.Err() + return + } + } + if opts.Key != nil { + if ks := keySems[opts.Key(it)]; ks != nil { + select { + case ks <- struct{}{}: + defer func() { <-ks }() + case <-ctx.Done(): + results[i].Err = ctx.Err() + return + } + } + } + + v, err := fn(ctx, it) + results[i].Value = v + results[i].Err = err + }(i, it) + } + wg.Wait() + return results +} + +// Values returns the successful values (Err == nil) from a result slice, in +// order. Convenience for consolidation steps that ignore failures. 
+func Values[T any](rs []Result[T]) []T { + out := make([]T, 0, len(rs)) + for _, r := range rs { + if r.Err == nil { + out = append(out, r.Value) + } + } + return out +} + +// Errors returns the non-nil errors from a result slice, in order. +func Errors[T any](rs []Result[T]) []error { + var out []error + for _, r := range rs { + if r.Err != nil { + out = append(out, r.Err) + } + } + return out +} diff --git a/fanout/fanout_test.go b/fanout/fanout_test.go new file mode 100644 index 0000000..722d7a6 --- /dev/null +++ b/fanout/fanout_test.go @@ -0,0 +1,106 @@ +package fanout + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" +) + +func TestRunPreservesOrderAndCapturesErrors(t *testing.T) { + items := []int{0, 1, 2, 3, 4} + got := Run(context.Background(), items, Options[int]{MaxConcurrent: 2}, + func(_ context.Context, n int) (int, error) { + if n == 2 { + return 0, errors.New("boom") + } + return n * 10, nil + }) + + if len(got) != len(items) { + t.Fatalf("len = %d", len(got)) + } + for i, r := range got { + if r.Index != i { + t.Errorf("result[%d].Index = %d", i, r.Index) + } + if i == 2 { + if r.Err == nil { + t.Errorf("expected error at index 2") + } + } else if r.Value != i*10 { + t.Errorf("result[%d].Value = %d, want %d", i, r.Value, i*10) + } + } + if vals := Values(got); len(vals) != 4 { + t.Errorf("Values len = %d, want 4", len(vals)) + } + if errs := Errors(got); len(errs) != 1 { + t.Errorf("Errors len = %d, want 1", len(errs)) + } +} + +func TestMaxConcurrentBound(t *testing.T) { + const max = 3 + var inflight, peak int32 + items := make([]int, 30) + Run(context.Background(), items, Options[int]{MaxConcurrent: max}, + func(_ context.Context, _ int) (int, error) { + n := atomic.AddInt32(&inflight, 1) + for { + p := atomic.LoadInt32(&peak) + if n <= p || atomic.CompareAndSwapInt32(&peak, p, n) { + break + } + } + time.Sleep(2 * time.Millisecond) + atomic.AddInt32(&inflight, -1) + return 0, nil + }) + if peak > max { + t.Errorf("peak 
concurrency %d exceeded MaxConcurrent %d", peak, max) + } +} + +func TestPerKeyCap(t *testing.T) { + // Two providers; provider "slow" capped at 1, so its peak must be 1 even + // though MaxConcurrent allows more. + var slowInflight, slowPeak int32 + type job struct{ provider string } + items := make([]job, 12) + for i := range items { + items[i] = job{provider: "slow"} + } + Run(context.Background(), items, Options[job]{ + MaxConcurrent: 8, + PerKey: map[string]int{"slow": 1}, + Key: func(j job) string { return j.provider }, + }, func(_ context.Context, _ job) (int, error) { + n := atomic.AddInt32(&slowInflight, 1) + for { + p := atomic.LoadInt32(&slowPeak) + if n <= p || atomic.CompareAndSwapInt32(&slowPeak, p, n) { + break + } + } + time.Sleep(time.Millisecond) + atomic.AddInt32(&slowInflight, -1) + return 0, nil + }) + if slowPeak != 1 { + t.Errorf("per-key cap not honored: slow peak = %d, want 1", slowPeak) + } +} + +func TestContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + got := Run(ctx, make([]int, 5), Options[int]{MaxConcurrent: 2}, + func(ctx context.Context, _ int) (int, error) { return 1, nil }) + for i, r := range got { + if r.Err == nil { + t.Errorf("result[%d] expected ctx error after cancel", i) + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..814dfd4 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module gitea.stevedudenhoeffer.com/steve/executus + +go 1.26.2 diff --git a/identity/identity.go b/identity/identity.go new file mode 100644 index 0000000..22b571e --- /dev/null +++ b/identity/identity.go @@ -0,0 +1,54 @@ +// Package identity is executus's caller-identity seam. +// +// A CallerID is an opaque string the host assigns (a Discord snowflake, an OAuth +// subject, a CI principal). 
Two optional capabilities hang off it: AdminPolicy +// gates authoring-class actions, and MemberResolver supplies per-caller +// enrichment (timezone, display name) for tools that want it. Both are nil-safe +// so a host that has no notion of "members" or "admins" wires nothing — the +// defaults treat everyone as a non-admin unknown. +package identity + +import "context" + +// Member is optional per-caller enrichment. Attrs carries host-specific extras +// (a seerr user id, a persona blurb) without widening this struct per host. +type Member struct { + ID string + DisplayName string + Timezone string + Attrs map[string]string +} + +// AdminPolicy decides whether a caller may perform authoring-class actions +// (saving a shared skill, registering an agent). Default: NonAdmin. +type AdminPolicy interface { + IsAdmin(ctx context.Context, callerID string) bool +} + +// NonAdmin is the default policy: nobody is an admin. A single-principal host +// (a CI job) typically overrides with a constant-true policy for its principal. +type NonAdmin struct{} + +func (NonAdmin) IsAdmin(context.Context, string) bool { return false } + +// MemberResolver supplies optional enrichment for a CallerID. ok=false means the +// member is unknown (the harness then proceeds without enrichment). +type MemberResolver interface { + Resolve(ctx context.Context, callerID string) (Member, bool) +} + +// IsAdmin is the nil-safe accessor: a nil AdminPolicy denies. +func IsAdmin(p AdminPolicy, ctx context.Context, callerID string) bool { + if p == nil { + return false + } + return p.IsAdmin(ctx, callerID) +} + +// Resolve is the nil-safe accessor: a nil resolver returns an unknown member. 
+func Resolve(r MemberResolver, ctx context.Context, callerID string) (Member, bool) { + if r == nil { + return Member{}, false + } + return r.Resolve(ctx, callerID) +} diff --git a/lane/lane.go b/lane/lane.go new file mode 100644 index 0000000..b993921 --- /dev/null +++ b/lane/lane.go @@ -0,0 +1,183 @@ +// Package lane provides a bounded worker pool primitive with +// priority-aware fair-share queueing. Used by mort to bound concurrent +// access to constrained resources (LLM provider connection limits, +// skill execution slots, etc). +// +// Key design constraints: +// - Submit is non-blocking past the dispatch decision. If a slot is +// available the job is dispatched immediately; otherwise it is +// enqueued and Submit returns the queue position. Callers that +// want "wait until done" semantics use SubmitWait. +// - Fair-share-by-user prevents one heavy user from starving others +// (see policy_fair_share.go). +// - Priority is a tie-breaker within a user's queue (higher first). +// - Cancel must work for queued jobs; running jobs are owned by the +// caller's Run goroutine and not killable from here — the caller +// is expected to wire ctx cancellation if desired. +// - Stats are sampled cheaply; ETA is best-effort. +// +// Persistence (DB-backed restart recovery) is layered ON TOP of the +// in-memory primitives via pkg/lane/persistence.go. +package lane + +import ( + "context" + "errors" + "time" +) + +// Job is what callers submit to a Lane. Implementations carry whatever +// state Run needs. +// +// Why: keeping Job a tiny interface lets multiple subsystems (LLM +// transport wrapper, skill executor, future runners) define their own +// concrete job types without leaking implementation details into the +// lane primitives. Persistence is layered on top via the optional +// MetadataProvider interface in persistence.go. +// +// Test: see pool_test.go for end-to-end submit/run/cancel coverage. 
+type Job interface { + // ID is unique per submission; used by Cancel and by the + // persistence layer to correlate DB rows with in-memory queue + // entries. + ID() string + + // CallerID is the user identity for fair-share queueing. Empty + // string is allowed but lumps every empty-caller job into a + // single bucket; production callers should always populate this. + CallerID() string + + // Priority is the tie-breaker within a single caller's sub-queue. + // Higher numbers run first. Default 0. + Priority() int + + // Run executes the job. The lane calls Run inside a worker + // goroutine when a slot is available. Errors are returned to the + // SubmitWait caller (or logged and dropped for fire-and-forget + // Submit). The provided context is the lane's worker context; + // callers SHOULD respect cancellation but the lane does not kill + // long-running Runs that ignore it. + Run(ctx context.Context) error +} + +// Lane is the bounded worker pool surface. +// +// Why an interface: lets tests substitute a fake lane, and lets the +// persistence wrapper compose around the in-memory implementation +// without having to extend it. +// +// Test: pool_test.go covers the in-memory pool implementation; +// persistence_test.go covers the persistence wrapper. +type Lane interface { + // Name returns the lane's stable identifier (e.g. "ollama"). + Name() string + + // Submit enqueues the job. If a slot is available, the job is + // dispatched immediately and Submit returns (0, 0, nil). If the + // lane is full, Submit returns (queuePos, eta, nil) — the job + // runs later when a slot frees. Submit does NOT block beyond the + // dispatch decision; for "wait until done" semantics use + // SubmitWait. + // + // queuePos is the 1-based position in the queue at submission + // time (1 = next to run). eta is a best-effort estimate based on + // recent throughput; zero when running immediately. 
+ Submit(ctx context.Context, job Job) (queuePos int, eta time.Duration, err error) + + // SubmitWait submits the job and blocks until Run completes or + // ctx is cancelled. Returns Run's error (or ctx.Err on cancel). + // When ctx is cancelled while the job is queued, the job is + // removed from the queue and never runs. When ctx is cancelled + // while the job is running, SubmitWait still waits for Run to + // return — Run's own respect for the context is the caller's + // responsibility. + SubmitWait(ctx context.Context, job Job) error + + // Cancel removes a queued job by ID. Returns ErrNotQueued if the + // job isn't in the queue (already running, finished, or + // unknown). + Cancel(jobID string) error + + // Stats returns a snapshot of the lane's current state. + Stats() LaneStats + + // SetMaxConcurrent updates the lane's concurrency cap. Existing + // running jobs continue to run; new dispatches respect the new + // cap. Calling this with n <= 0 is a no-op (lanes need at least + // one slot to make progress). + SetMaxConcurrent(n int) +} + +// LaneStats is a snapshot of a lane's current state. All fields are +// captured under the lane's mutex so the snapshot is internally +// consistent. +type LaneStats struct { + Name string + MaxConcurrent int + Running int + Queued int + OldestQueuedAt *time.Time + Throughput1m int // jobs completed in the last 60s +} + +// Sentinels. +// +// Why exported sentinels: callers compare with errors.Is so tests and +// production handlers can distinguish lane-internal failures from +// caller errors. +var ( + // ErrNotQueued is returned by Cancel when the job isn't in the + // queue (already running, finished, or unknown). + ErrNotQueued = errors.New("lane: job not queued") + + // ErrLaneClosed is returned by Submit/SubmitWait after Close has + // been called. 
+ ErrLaneClosed = errors.New("lane: closed") + + // ErrCancelled is returned by SubmitWait when the job is + // cancelled while queued (either via Cancel or by ctx.Done). + ErrCancelled = errors.New("lane: job cancelled") + + // ErrPreempted is delivered to a SubmitWait caller when the job's + // running goroutine was cancelled mid-run because a higher-priority + // queued job arrived at a full lane and this job was marked + // preemptible. v9. + ErrPreempted = errors.New("lane: preempted by higher priority job") + + // ErrLaneBusy is returned by SubmitWithMaxWait when the estimated + // queue wait would exceed the caller's maxWait. The job is NOT + // enqueued — caller may retry or degrade. v9. + ErrLaneBusy = errors.New("lane: estimated wait exceeds max") +) + +// Preemptible is an optional Job extension. A Job that returns true is +// eligible for preemption: when a higher-priority job arrives at a +// full lane, the lane scheduler may cancel this job's worker context +// mid-run. The job's Run method MUST honour ctx.Done for the +// cancellation to take effect. +// +// Why an interface (vs a flag on the Job): keeps the base Job +// interface tiny and lets each subsystem decide its preemption +// semantics. Skill jobs implement this by reading +// `skills.Skill.Preemptible`; LLM-transport jobs leave it +// unimplemented (they're never preemptible — cancelling an in-flight +// LLM call costs more than it saves). +// +// v9. +type Preemptible interface { + IsPreemptible() bool +} + +// PreemptionPolicy reports whether a running job should be preempted +// by an arriving higher-priority queued job. Optional registry-level +// surface: when nil, the default policy is "preempt the oldest +// preemptible running job whose runtime exceeds the min-runtime +// guard". v9. +type PreemptionPolicy interface { + // MinRuntime returns the minimum elapsed wall-clock time before a + // preemptible job may be preempted. Default 30s when nil. 
+ MinRuntime() time.Duration + // Enabled reports whether preemption is enabled at all on this + // lane. Default true when nil. + Enabled() bool +} diff --git a/lane/persistence.go b/lane/persistence.go new file mode 100644 index 0000000..1ca2160 --- /dev/null +++ b/lane/persistence.go @@ -0,0 +1,375 @@ +package lane + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" +) + +// PersistenceStore is the narrow surface PersistedLane needs to +// persist and recover lane jobs across process restarts. +// +// Why an interface here vs reaching into pkg/logic/skills directly: +// keeps the lane primitive generic — anyone with a job-row table that +// satisfies these six methods can plug in. pkg/logic/skills.Storage +// satisfies it via a thin adapter (PersistedSkillsStore). +// +// Test: persistence_test.go covers the round-trip + restart recovery +// flow using an in-memory fake store. +type PersistenceStore interface { + // EnqueueJob writes a row in state=queued. lane is the lane + // name; metadata is opaque payload preserved verbatim across + // restart for reconstruct paths. + EnqueueJob(ctx context.Context, jobID, lane, callerID string, priority int, metadata []byte) error + + // UpdateJobState transitions the row to a new state. The state + // strings are the QueueJobState values from + // pkg/logic/skills/skill_queue_job.go ("queued", "running", + // "finished", "cancelled", "failed"). Stamps the matching + // timestamp column. + UpdateJobState(ctx context.Context, jobID string, state string, at time.Time) error + + // ListQueuedJobs returns rows in state=queued for the given + // lane. Used by Recover to re-submit pending work. + ListQueuedJobs(ctx context.Context, lane string) ([]QueuedJobRef, error) + + // ListRunningJobs returns rows in state=running for the given + // lane. After a process restart these are unrecoverable (the + // worker goroutine is gone) and Recover marks them failed. 
+ ListRunningJobs(ctx context.Context, lane string) ([]QueuedJobRef, error) + + // PurgeFinishedJobs deletes terminal-state rows older than the + // cutoff. Returns count deleted. + PurgeFinishedJobs(ctx context.Context, olderThan time.Time) (int64, error) +} + +// QueuedJobRef is a thin row reference returned by List* methods. +// Carries enough state for Recover to reconstruct or mark a job. +// +// Why a separate type from the skills.QueueJob domain: the lane +// package doesn't import the skills package (and would create an +// import cycle if it did). The narrow ref type keeps the contract +// flat. +type QueuedJobRef struct { + JobID string + Lane string + CallerID string + Priority int + Metadata []byte + EnqueuedAt time.Time +} + +// MetadataProvider is the optional interface a Job can implement to +// supply its restart-recovery payload. +// +// Why optional: not every job needs to be reconstructed (raw LLM +// transport jobs are issued ad-hoc by callers; a restart just drops +// the in-flight ones). Skills set Metadata so the executor can +// rehydrate the original Invocation. +type MetadataProvider interface { + Metadata() []byte +} + +// PersistedLane wraps a Lane with DB persistence. Submit writes a +// row before delegating to the inner lane; Run state transitions +// update the row in place. +// +// Why a wrapper vs baking persistence into the pool: keeps the +// in-memory primitives test-friendly (pool_test.go runs without a +// DB). Production wires a PersistedLane around each named lane that +// needs restart recovery; lanes that don't (e.g. transient +// LLM-transport lanes used by anonymous callers) can stay +// in-memory only. +type PersistedLane struct { + inner Lane + store PersistenceStore +} + +// NewPersistedLane wraps an existing Lane with a persistence store. +// The inner lane keeps doing all the in-memory queueing; the +// PersistedLane writes a DB row for each Submit and updates state on +// transitions. 
+func NewPersistedLane(inner Lane, store PersistenceStore) *PersistedLane { + return &PersistedLane{inner: inner, store: store} +} + +// Inner returns the wrapped lane. Used by Recover to bypass the +// persistence path on re-submission (the row already exists). +func (p *PersistedLane) Inner() Lane { return p.inner } + +// Name delegates to the inner lane. +func (p *PersistedLane) Name() string { return p.inner.Name() } + +// Submit writes the queued row, then delegates to the inner lane. +// The job is wrapped so Run-time state transitions update the row. +// +// On enqueue-row write failure: returns the error WITHOUT submitting +// to the inner lane. We don't want to dispatch a job that we couldn't +// persist — admin visibility (and restart recovery) would then be +// inconsistent with the running set. +func (p *PersistedLane) Submit(ctx context.Context, job Job) (int, time.Duration, error) { + var meta []byte + if mp, ok := job.(MetadataProvider); ok { + meta = mp.Metadata() + } + if err := p.store.EnqueueJob(ctx, job.ID(), p.inner.Name(), + job.CallerID(), job.Priority(), meta); err != nil { + return 0, 0, fmt.Errorf("persist enqueue: %w", err) + } + wrapped := &persistedJob{inner: job, store: p.store} + return p.inner.Submit(ctx, wrapped) +} + +// SubmitWait writes the queued row and blocks until Run completes +// (or ctx is cancelled). Same persistence semantics as Submit. +func (p *PersistedLane) SubmitWait(ctx context.Context, job Job) error { + var meta []byte + if mp, ok := job.(MetadataProvider); ok { + meta = mp.Metadata() + } + if err := p.store.EnqueueJob(ctx, job.ID(), p.inner.Name(), + job.CallerID(), job.Priority(), meta); err != nil { + return fmt.Errorf("persist enqueue: %w", err) + } + wrapped := &persistedJob{inner: job, store: p.store} + return p.inner.SubmitWait(ctx, wrapped) +} + +// Cancel removes the job from the inner queue and writes +// state=cancelled to the persistence store. 
If Cancel returns +// ErrNotQueued (already running, etc.) the row state is NOT touched — +// the caller knows the job is past the queue stage. +func (p *PersistedLane) Cancel(jobID string) error { + if err := p.inner.Cancel(jobID); err != nil { + return err + } + // Inner cancel succeeded — update DB. + if uerr := p.store.UpdateJobState(context.Background(), jobID, + string(stateCancelled), time.Now()); uerr != nil { + // Best-effort: log; return nil because the in-memory + // cancellation already happened. + slog.Warn("lane persist: cancel state update failed", + "job", jobID, "error", uerr) + } + return nil +} + +// Stats delegates to the inner lane. +func (p *PersistedLane) Stats() LaneStats { return p.inner.Stats() } + +// SetMaxConcurrent delegates to the inner lane. +func (p *PersistedLane) SetMaxConcurrent(n int) { p.inner.SetMaxConcurrent(n) } + +// Recover reconciles the persistence store with the in-memory lane +// after a process restart. +// +// - Rows in state=running at restart correspond to jobs whose +// worker goroutine was killed. They are marked failed (no +// auto-retry — skills can have side effects, see v6 spec +// "Restart amnesia"). +// - Rows in state=queued are re-submitted to the inner lane via +// reconstructFn(ref) → Job. If reconstructFn returns nil the row +// is marked failed with reason "lost on restart" — the caller +// could not reconstruct the original work. +// +// Recover bypasses the PersistedLane.Submit path (which would write a +// duplicate row). The row already exists in state=queued; we just +// re-submit to the in-memory queue and let normal Run-time +// transitions take over from there. +func (p *PersistedLane) Recover(ctx context.Context, reconstructFn func(QueuedJobRef) Job) error { + // 1. Mark running rows as failed. 
+ running, err := p.store.ListRunningJobs(ctx, p.inner.Name()) + if err != nil { + return fmt.Errorf("list running: %w", err) + } + for _, ref := range running { + if uerr := p.store.UpdateJobState(ctx, ref.JobID, + string(stateFailed), time.Now()); uerr != nil { + slog.Warn("lane recover: failed to mark lost-on-restart", + "lane", p.inner.Name(), "job", ref.JobID, "error", uerr) + continue + } + slog.Warn("lane recover: job lost on restart", + "lane", p.inner.Name(), "job", ref.JobID) + } + + // 2. Re-submit queued rows. + queued, err := p.store.ListQueuedJobs(ctx, p.inner.Name()) + if err != nil { + return fmt.Errorf("list queued: %w", err) + } + for _, ref := range queued { + var job Job + if reconstructFn != nil { + job = reconstructFn(ref) + } + if job == nil { + if uerr := p.store.UpdateJobState(ctx, ref.JobID, + string(stateFailed), time.Now()); uerr != nil { + slog.Warn("lane recover: cannot reconstruct, mark-failed errored", + "lane", p.inner.Name(), "job", ref.JobID, "error", uerr) + } else { + slog.Warn("lane recover: cannot reconstruct, marked failed", + "lane", p.inner.Name(), "job", ref.JobID) + } + continue + } + // Wrap the reconstructed job so Run-time state transitions + // still update the existing row (no fresh enqueue). + wrapped := &persistedJob{inner: job, store: p.store} + if _, _, serr := p.inner.Submit(ctx, wrapped); serr != nil { + slog.Warn("lane recover: re-submit failed", + "lane", p.inner.Name(), "job", ref.JobID, "error", serr) + // Mark failed — job is in DB as queued but in-memory + // dispatch never happened. + if uerr := p.store.UpdateJobState(ctx, ref.JobID, + string(stateFailed), time.Now()); uerr != nil { + slog.Warn("lane recover: post-resubmit-failure mark errored", + "job", ref.JobID, "error", uerr) + } + } + } + return nil +} + +// persistedJob wraps an inner Job to write state transitions on +// Run() entry and exit. 
+type persistedJob struct { + inner Job + store PersistenceStore +} + +func (p *persistedJob) ID() string { return p.inner.ID() } +func (p *persistedJob) CallerID() string { return p.inner.CallerID() } +func (p *persistedJob) Priority() int { return p.inner.Priority() } +func (p *persistedJob) Metadata() []byte { + if mp, ok := p.inner.(MetadataProvider); ok { + return mp.Metadata() + } + return nil +} + +func (p *persistedJob) Run(ctx context.Context) error { + // Mark running. + if uerr := p.store.UpdateJobState(ctx, p.inner.ID(), + string(stateRunning), time.Now()); uerr != nil { + // Don't abort the run if the audit write fails — the + // inner work is what the caller asked for. Log and continue. + slog.Warn("lane persist: state=running update failed", + "job", p.inner.ID(), "error", uerr) + } + err := p.inner.Run(ctx) + terminal := stateFinished + if err != nil { + // Cancellation surfaced as ErrCancelled (queued cancel) is + // already written by PersistedLane.Cancel; if it bubbles up + // here that means Run was called and Run returned with the + // cancellation error — record as cancelled. + if errors.Is(err, ErrCancelled) { + terminal = stateCancelled + } else { + terminal = stateFailed + } + } + if uerr := p.store.UpdateJobState(ctx, p.inner.ID(), + string(terminal), time.Now()); uerr != nil { + slog.Warn("lane persist: terminal state update failed", + "job", p.inner.ID(), "state", terminal, "error", uerr) + } + return err +} + +// Internal copies of the QueueJobState string constants. Why duplicate +// them here vs importing skills: pkg/lane is generic and cannot +// import skills (would create a cycle). Production callers wire the +// PersistedLane via an adapter that satisfies PersistenceStore — +// the strings are part of the contract. 
+const ( + stateRunning = "running" + stateFinished = "finished" + stateCancelled = "cancelled" + stateFailed = "failed" +) + +// Sweeper periodically purges finished/cancelled/failed rows older +// than the configured retention window. +// +// Why a separate goroutine struct vs reusing +// pkg/logic/skills.StorageSweeper: the queue rows are owned by the +// lane primitive; keeping the sweeper in pkg/lane lets future lane +// users (LLM transport, GPU lanes) share it without pulling in skills +// concerns. +// +// Test: persistence_test.go drives Sweep synchronously. +type Sweeper struct { + store PersistenceStore + clock func() time.Time + interval time.Duration + // retention is computed at Sweep call time so a runtime convar + // change takes effect without restart. + retention func() time.Duration +} + +// NewSweeper constructs the sweeper. retention may be nil → defaults +// to 24h. clock may be nil → time.Now. +func NewSweeper(store PersistenceStore, retention func() time.Duration, clock func() time.Time) *Sweeper { + if clock == nil { + clock = time.Now + } + if retention == nil { + retention = func() time.Duration { return 24 * time.Hour } + } + return &Sweeper{ + store: store, + clock: clock, + retention: retention, + interval: time.Hour, + } +} + +// SetInterval overrides the loop cadence. interval <= 0 is a no-op. +func (s *Sweeper) SetInterval(d time.Duration) { + if d > 0 { + s.interval = d + } +} + +// Start launches the sweeper loop. Returns immediately; cancellation +// via ctx. +func (s *Sweeper) Start(ctx context.Context) { + go s.loop(ctx) +} + +// Sweep runs one purge pass synchronously. Public for tests. 
+func (s *Sweeper) Sweep(ctx context.Context) { + cutoff := s.clock().Add(-s.retention()) + n, err := s.store.PurgeFinishedJobs(ctx, cutoff) + if err != nil { + slog.Warn("lane sweeper: purge failed", "error", err) + return + } + if n > 0 { + slog.Info("lane sweeper: purged finished jobs", "deleted", n) + } +} + +func (s *Sweeper) loop(ctx context.Context) { + tick := time.NewTicker(s.interval) + defer tick.Stop() + // Startup delay so cold-start load doesn't stack everything in + // the first second. 90s is a reasonable spread. + startup := time.NewTimer(90 * time.Second) + defer startup.Stop() + for { + select { + case <-ctx.Done(): + return + case <-startup.C: + s.Sweep(ctx) + case <-tick.C: + s.Sweep(ctx) + } + } +} diff --git a/lane/persistence_test.go b/lane/persistence_test.go new file mode 100644 index 0000000..7deb1a2 --- /dev/null +++ b/lane/persistence_test.go @@ -0,0 +1,380 @@ +package lane + +import ( + "context" + "errors" + "sync" + "testing" + "time" +) + +// fakeStore is an in-memory PersistenceStore used by persistence +// tests. Records every method call so tests can assert ordering. 
+type fakeStore struct { + mu sync.Mutex + rows map[string]*storeRow + enqErr error + updErr error + purgeFn func(time.Time) (int64, error) +} + +type storeRow struct { + jobID, lane, callerID string + priority int + metadata []byte + state string + enqueuedAt time.Time + startedAt *time.Time + finishedAt *time.Time +} + +func newFakeStore() *fakeStore { return &fakeStore{rows: map[string]*storeRow{}} } + +func (f *fakeStore) EnqueueJob(_ context.Context, jobID, lane, callerID string, priority int, metadata []byte) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.enqErr != nil { + return f.enqErr + } + if _, exists := f.rows[jobID]; exists { + return errors.New("duplicate enqueue") + } + f.rows[jobID] = &storeRow{ + jobID: jobID, lane: lane, callerID: callerID, + priority: priority, metadata: metadata, + state: "queued", enqueuedAt: time.Now(), + } + return nil +} + +func (f *fakeStore) UpdateJobState(_ context.Context, jobID, state string, at time.Time) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.updErr != nil { + return f.updErr + } + r, ok := f.rows[jobID] + if !ok { + return errors.New("not found") + } + r.state = state + t := at + switch state { + case "running": + r.startedAt = &t + case "finished", "cancelled", "failed": + r.finishedAt = &t + } + return nil +} + +func (f *fakeStore) ListQueuedJobs(_ context.Context, lane string) ([]QueuedJobRef, error) { + return f.list(lane, "queued"), nil +} + +func (f *fakeStore) ListRunningJobs(_ context.Context, lane string) ([]QueuedJobRef, error) { + return f.list(lane, "running"), nil +} + +func (f *fakeStore) list(lane, state string) []QueuedJobRef { + f.mu.Lock() + defer f.mu.Unlock() + var out []QueuedJobRef + for _, r := range f.rows { + if r.lane == lane && r.state == state { + out = append(out, QueuedJobRef{ + JobID: r.jobID, Lane: r.lane, + CallerID: r.callerID, Priority: r.priority, + Metadata: r.metadata, EnqueuedAt: r.enqueuedAt, + }) + } + } + return out +} + +func (f *fakeStore) 
PurgeFinishedJobs(_ context.Context, olderThan time.Time) (int64, error) { + if f.purgeFn != nil { + return f.purgeFn(olderThan) + } + f.mu.Lock() + defer f.mu.Unlock() + var deleted int64 + for id, r := range f.rows { + if (r.state == "finished" || r.state == "cancelled" || r.state == "failed") && + r.finishedAt != nil && r.finishedAt.Before(olderThan) { + delete(f.rows, id) + deleted++ + } + } + return deleted, nil +} + +func (f *fakeStore) state(jobID string) string { + f.mu.Lock() + defer f.mu.Unlock() + if r, ok := f.rows[jobID]; ok { + return r.state + } + return "" +} + +// metaJob is a Job impl that exposes Metadata. Used in persistence +// tests that assert metadata round-trip. +type metaJob struct { + id, caller string + priority int + meta []byte + run func(ctx context.Context) error +} + +func (m *metaJob) ID() string { return m.id } +func (m *metaJob) CallerID() string { return m.caller } +func (m *metaJob) Priority() int { return m.priority } +func (m *metaJob) Metadata() []byte { return m.meta } +func (m *metaJob) Run(ctx context.Context) error { return m.run(ctx) } + +// TestPersistedLane_Submit_WritesRow verifies Submit writes a queued +// row with the right fields, then on Run completes transitions to +// finished. +func TestPersistedLane_Submit_WritesRow(t *testing.T) { + store := newFakeStore() + inner := New("ollama", 1) + pl := NewPersistedLane(inner, store) + + done := make(chan struct{}) + job := &metaJob{ + id: "j1", caller: "alice", priority: 3, + meta: []byte(`{"prompt":"hi"}`), + run: func(ctx context.Context) error { + close(done) + return nil + }, + } + if _, _, err := pl.Submit(context.Background(), job); err != nil { + t.Fatal(err) + } + <-done + + // Wait for state-update goroutine to land "finished". 
+ waitFor(t, func() bool { return store.state("j1") == "finished" }) + + r := store.rows["j1"] + if r.lane != "ollama" || r.callerID != "alice" || r.priority != 3 { + t.Fatalf("row identity mismatch: %+v", r) + } + if string(r.metadata) != `{"prompt":"hi"}` { + t.Fatalf("metadata mismatch: %s", r.metadata) + } + if r.startedAt == nil || r.finishedAt == nil { + t.Fatalf("expected started_at + finished_at set; row=%+v", r) + } +} + +// TestPersistedLane_Submit_RunErrorMarksFailed verifies a failing Run +// transitions to state=failed. +func TestPersistedLane_Submit_RunErrorMarksFailed(t *testing.T) { + store := newFakeStore() + inner := New("test", 1) + pl := NewPersistedLane(inner, store) + + job := &metaJob{ + id: "j1", caller: "alice", + run: func(ctx context.Context) error { + return errors.New("boom") + }, + } + if err := pl.SubmitWait(context.Background(), job); err == nil { + t.Fatal("expected error from Run") + } + if got := store.state("j1"); got != "failed" { + t.Fatalf("expected state=failed, got %s", got) + } +} + +// TestPersistedLane_EnqueueErrorAborts verifies that if EnqueueJob +// errors, the inner lane never sees the job. +func TestPersistedLane_EnqueueErrorAborts(t *testing.T) { + store := newFakeStore() + store.enqErr = errors.New("disk full") + inner := New("test", 1).(*pool) + pl := NewPersistedLane(inner, store) + + job := &funcJob{ + id: "j1", caller: "alice", + run: func(ctx context.Context) error { + t.Fatal("Run should not be called when persist enqueue fails") + return nil + }, + } + _, _, err := pl.Submit(context.Background(), job) + if err == nil { + t.Fatal("expected Submit to fail") + } + // Inner lane should be empty. + if got := inner.Stats().Running + inner.Stats().Queued; got != 0 { + t.Fatalf("expected inner lane empty, got running+queued=%d", got) + } +} + +// TestPersistedLane_Cancel_QueuedWritesCancelled verifies cancelling a +// queued job writes state=cancelled. 
+func TestPersistedLane_Cancel_QueuedWritesCancelled(t *testing.T) { + store := newFakeStore() + inner := New("test", 1).(*pool) + pl := NewPersistedLane(inner, store) + + blocker := newTestJob("blocker") + if _, _, err := pl.Submit(context.Background(), blocker); err != nil { + t.Fatal(err) + } + <-blocker.started + + target := newTestJob("target") + if _, _, err := pl.Submit(context.Background(), target); err != nil { + t.Fatal(err) + } + if got := inner.Stats().Queued; got != 1 { + t.Fatalf("expected queued=1, got %d", got) + } + + if err := pl.Cancel("target"); err != nil { + t.Fatal(err) + } + waitFor(t, func() bool { return store.state("target") == "cancelled" }) + + close(blocker.release) +} + +// TestRecover_RunningMarkedFailed verifies that running rows at +// recovery time are marked failed (lost-on-restart). +func TestRecover_RunningMarkedFailed(t *testing.T) { + store := newFakeStore() + now := time.Now() + store.rows["r1"] = &storeRow{ + jobID: "r1", lane: "ollama", callerID: "alice", + state: "running", enqueuedAt: now.Add(-1 * time.Hour), + startedAt: &now, + } + + inner := New("ollama", 1) + pl := NewPersistedLane(inner, store) + + // reconstructFn never called for running rows. + if err := pl.Recover(context.Background(), nil); err != nil { + t.Fatal(err) + } + if got := store.state("r1"); got != "failed" { + t.Fatalf("expected r1 → failed, got %s", got) + } +} + +// TestRecover_QueuedReSubmitted verifies queued rows are re-submitted +// to the inner lane via reconstructFn. 
+func TestRecover_QueuedReSubmitted(t *testing.T) { + store := newFakeStore() + store.rows["q1"] = &storeRow{ + jobID: "q1", lane: "ollama", callerID: "alice", + state: "queued", enqueuedAt: time.Now(), + metadata: []byte("opaque"), + } + store.rows["q2"] = &storeRow{ + jobID: "q2", lane: "ollama", callerID: "bob", + state: "queued", enqueuedAt: time.Now(), + } + + inner := New("ollama", 2) + pl := NewPersistedLane(inner, store) + + calls := make(chan string, 2) + reconstruct := func(ref QueuedJobRef) Job { + return &funcJob{ + id: ref.JobID, caller: ref.CallerID, + run: func(ctx context.Context) error { + calls <- ref.JobID + return nil + }, + } + } + if err := pl.Recover(context.Background(), reconstruct); err != nil { + t.Fatal(err) + } + + got := map[string]bool{} + for i := 0; i < 2; i++ { + select { + case id := <-calls: + got[id] = true + case <-time.After(time.Second): + t.Fatalf("expected 2 reconstructed runs; only got %v", got) + } + } + if !got["q1"] || !got["q2"] { + t.Fatalf("expected both q1 and q2 reconstructed; got %v", got) + } + + // After Run completes, both rows should be state=finished. + waitFor(t, func() bool { + return store.state("q1") == "finished" && store.state("q2") == "finished" + }) +} + +// TestRecover_NilReconstructMarksFailed verifies that when +// reconstructFn returns nil for a queued row, the row is marked +// failed. +func TestRecover_NilReconstructMarksFailed(t *testing.T) { + store := newFakeStore() + store.rows["q1"] = &storeRow{ + jobID: "q1", lane: "ollama", callerID: "alice", + state: "queued", enqueuedAt: time.Now(), + } + inner := New("ollama", 1) + pl := NewPersistedLane(inner, store) + + if err := pl.Recover(context.Background(), func(QueuedJobRef) Job { return nil }); err != nil { + t.Fatal(err) + } + if got := store.state("q1"); got != "failed" { + t.Fatalf("expected q1 → failed, got %s", got) + } +} + +// TestSweeper_PurgesFinishedRows verifies Sweep calls +// PurgeFinishedJobs with the right cutoff. 
+func TestSweeper_PurgesFinishedRows(t *testing.T) { + store := newFakeStore() + old := time.Now().Add(-25 * time.Hour) + finished := time.Now() + store.rows["old"] = &storeRow{ + jobID: "old", lane: "x", state: "finished", + enqueuedAt: old, finishedAt: &old, + } + store.rows["recent"] = &storeRow{ + jobID: "recent", lane: "x", state: "finished", + enqueuedAt: finished, finishedAt: &finished, + } + sw := NewSweeper(store, func() time.Duration { return 24 * time.Hour }, nil) + sw.Sweep(context.Background()) + if _, ok := store.rows["old"]; ok { + t.Fatal("old row should have been purged") + } + if _, ok := store.rows["recent"]; !ok { + t.Fatal("recent row should remain") + } +} + +// TestSweeper_RetentionIsDynamic verifies the retention function is +// called per Sweep, so a runtime convar change takes effect. +func TestSweeper_RetentionIsDynamic(t *testing.T) { + store := newFakeStore() + called := 0 + retention := func() time.Duration { + called++ + return time.Hour + } + sw := NewSweeper(store, retention, nil) + sw.Sweep(context.Background()) + sw.Sweep(context.Background()) + if called != 2 { + t.Fatalf("expected retention called twice, got %d", called) + } +} diff --git a/lane/policy_fair_share.go b/lane/policy_fair_share.go new file mode 100644 index 0000000..4da4855 --- /dev/null +++ b/lane/policy_fair_share.go @@ -0,0 +1,192 @@ +package lane + +import ( + "sort" + "time" +) + +// fairSharePolicy implements queuePolicy with per-user sub-queues. +// Dequeue rotates through users round-robin so one user can't starve +// others. Within a user's queue, higher priority comes first; ties +// broken FIFO. +// +// Why round-robin not weighted-fair: simple, no tuning. If user A has +// 5 queued and user B has 1, user B's job runs after at most one of +// user A's jobs. That matches the v6 spec's "Steve queues 10, Dave +// queues 1, Dave gets in after at most 1 of Steve's" guarantee. 
+// +// Why a separate file: keeps pool.go focused on the in-memory pool; +// the queue policy is a swap-out. v7 may add weighted fair share or +// strict priority. +type fairSharePolicy struct { + // perUser maps caller_id → ordered sub-queue (priority desc, + // FIFO ties). + perUser map[string][]*queuedJob + // users is the round-robin rotation order. A user is appended + // when they first enqueue; removed when their sub-queue empties. + users []string + // nextIdx is the index into users for the next Dequeue. Wraps + // modulo len(users). + nextIdx int +} + +// NewFairSharePolicy returns a queuePolicy with per-user round-robin +// dequeue and priority-ordered FIFO within each user's sub-queue. +// +// Why exported: lets tests (and future callers in pkg/logic/skills) +// construct lanes with explicit fair-share policy via NewWithPolicy. +func NewFairSharePolicy() queuePolicy { + return &fairSharePolicy{ + perUser: make(map[string][]*queuedJob), + } +} + +// NewWithFairShare constructs a Lane backed by a pool with fair-share +// queueing. Convenience wrapper used by the registry default. +func NewWithFairShare(name string, maxConcurrent int) Lane { + return NewWithPolicy(name, maxConcurrent, NewFairSharePolicy()) +} + +// Enqueue adds the job to the caller's sub-queue, sorted by priority +// (higher first) with FIFO tie-breaking. +func (f *fairSharePolicy) Enqueue(j *queuedJob) { + user := j.job.CallerID() + if _, ok := f.perUser[user]; !ok { + f.perUser[user] = []*queuedJob{} + f.users = append(f.users, user) + } + sub := f.perUser[user] + // Insert sorted by priority desc; FIFO ties via stable insert + // after the last entry of equal-or-higher priority. + // + // Why sort.Search: O(log n) within a single user's queue. Since + // per-user backlog is typically small, even a linear scan would + // be fine, but sort.Search keeps the worst case bounded. 
+ i := sort.Search(len(sub), func(i int) bool { + return sub[i].job.Priority() < j.job.Priority() + }) + sub = append(sub, nil) + copy(sub[i+1:], sub[i:]) + sub[i] = j + f.perUser[user] = sub +} + +// Dequeue rotates users round-robin until it finds a non-empty +// sub-queue. Returns nil when all sub-queues are empty. +// +// Why a single-pass loop bounded by len(users): a user whose sub-queue +// is empty stays in `users` only briefly (we delete on the empty +// transition); a single rotation through `users` always finds a non- +// empty sub-queue if one exists, and an empty rotation means truly +// empty. +func (f *fairSharePolicy) Dequeue() *queuedJob { + if len(f.users) == 0 { + return nil + } + for tries := 0; tries < len(f.users); tries++ { + // Bounds-safe selection — len(users) might shrink during + // iteration, so re-bound on every iteration. + if f.nextIdx >= len(f.users) { + f.nextIdx = 0 + } + user := f.users[f.nextIdx] + sub := f.perUser[user] + // Advance the cursor for next time, regardless of whether + // we picked from this user. A round-robin pass that finds + // every user empty exits the loop. + f.nextIdx++ + if len(sub) == 0 { + continue + } + j := sub[0] + sub[0] = nil + sub = sub[1:] + if len(sub) == 0 { + // User's sub-queue is now empty — remove from rotation. + delete(f.perUser, user) + f.users = removeStringAt(f.users, f.nextIdx-1) + // f.nextIdx-1 is the index we just dequeued from. After + // removing, nextIdx now points at the next user (if any), + // so we don't decrement. + if f.nextIdx > len(f.users) { + f.nextIdx = 0 + } + } else { + f.perUser[user] = sub + } + return j + } + return nil +} + +// Cancel walks every sub-queue looking for a matching job ID. Returns +// true if found and removed. +// +// Why O(n) scan: callers cancel by job ID without knowing the user. +// Could maintain a jobID → user index for O(1) cancel; deferred to +// later if profiling shows it matters. n is bounded by total queued +// jobs across all users. 
+func (f *fairSharePolicy) Cancel(jobID string) bool { + for user, sub := range f.perUser { + for i, j := range sub { + if j.job.ID() == jobID { + // Remove from sub-queue. + j.done <- jobResult{err: ErrCancelled} + f.perUser[user] = append(sub[:i], sub[i+1:]...) + if len(f.perUser[user]) == 0 { + delete(f.perUser, user) + f.users = removeString(f.users, user) + if f.nextIdx > len(f.users) { + f.nextIdx = 0 + } + } + return true + } + } + } + return false +} + +// Len returns the total queued count across every sub-queue. +func (f *fairSharePolicy) Len() int { + total := 0 + for _, sub := range f.perUser { + total += len(sub) + } + return total +} + +// OldestEnqueueTime returns the earliest enqueue time across every +// sub-queue. Returns nil if every queue is empty. +func (f *fairSharePolicy) OldestEnqueueTime() *time.Time { + var oldest *time.Time + for _, sub := range f.perUser { + for _, j := range sub { + if oldest == nil || j.enqueuedAt.Before(*oldest) { + t := j.enqueuedAt + oldest = &t + } + } + } + return oldest +} + +// removeString returns a new slice with the first occurrence of target +// removed. Order is preserved (round-robin order matters). +func removeString(s []string, target string) []string { + for i, v := range s { + if v == target { + return append(s[:i], s[i+1:]...) + } + } + return s +} + +// removeStringAt returns a new slice with the element at idx removed. +// Order is preserved. idx is bounds-checked defensively. +func removeStringAt(s []string, idx int) []string { + if idx < 0 || idx >= len(s) { + return s + } + return append(s[:idx], s[idx+1:]...) +} diff --git a/lane/policy_fair_share_test.go b/lane/policy_fair_share_test.go new file mode 100644 index 0000000..7e8b0c9 --- /dev/null +++ b/lane/policy_fair_share_test.go @@ -0,0 +1,278 @@ +package lane + +import ( + "context" + "fmt" + "testing" + "time" +) + +// fakeJob is a Job impl that records its ID; doesn't block. Used by +// policy tests that need to enumerate dequeue order. 
+type fakeJob struct { + id string + caller string + priority int +} + +func (f *fakeJob) ID() string { return f.id } +func (f *fakeJob) CallerID() string { return f.caller } +func (f *fakeJob) Priority() int { return f.priority } +func (f *fakeJob) Run(ctx context.Context) error { return nil } + +// enq is a test helper that enqueues a fakeJob with the given fields +// directly on a fairSharePolicy. +func enq(p queuePolicy, id, user string, priority int) *queuedJob { + qj := &queuedJob{ + job: &fakeJob{id: id, caller: user, priority: priority}, + enqueuedAt: time.Now(), + done: make(chan jobResult, 1), + } + p.Enqueue(qj) + return qj +} + +// drainOrder returns the IDs in the order Dequeue produces them. +func drainOrder(p queuePolicy) []string { + var out []string + for { + j := p.Dequeue() + if j == nil { + return out + } + out = append(out, j.job.ID()) + } +} + +// TestFairShare_RoundRobinAcrossUsers covers the spec's headline +// guarantee: A submits 10, B submits 1, B's job runs after at most 1 +// of A's. +func TestFairShare_RoundRobinAcrossUsers(t *testing.T) { + p := NewFairSharePolicy() + for i := 0; i < 10; i++ { + enq(p, fmt.Sprintf("a%d", i), "userA", 0) + } + enq(p, "b1", "userB", 0) + + order := drainOrder(p) + // First two dequeues should be one A then b1 (or b1 then A, + // depending on rotation start). Either way, b1 must appear within + // the first two entries. + foundB := -1 + for i, id := range order { + if id == "b1" { + foundB = i + break + } + } + if foundB == -1 { + t.Fatalf("b1 was never dequeued; order=%v", order) + } + if foundB > 1 { + t.Fatalf("b1 dequeued at position %d; expected 0 or 1; order=%v", + foundB, order) + } + if len(order) != 11 { + t.Fatalf("expected 11 dequeues, got %d (%v)", len(order), order) + } +} + +// TestFairShare_PriorityWithinUser covers per-user priority ordering. +// Within one user, priority 5 > 1 > 0, FIFO ties. 
+func TestFairShare_PriorityWithinUser(t *testing.T) { + p := NewFairSharePolicy() + enq(p, "lo1", "u1", 0) + enq(p, "hi", "u1", 5) + enq(p, "mid", "u1", 1) + enq(p, "lo2", "u1", 0) + + order := drainOrder(p) + if got := order[0]; got != "hi" { + t.Fatalf("expected hi first, got %v", order) + } + if got := order[1]; got != "mid" { + t.Fatalf("expected mid second, got %v", order) + } + // lo1 was enqueued before lo2 — FIFO preserves order. + if order[2] != "lo1" || order[3] != "lo2" { + t.Fatalf("expected lo1 then lo2 (FIFO ties), got %v", order) + } +} + +// TestFairShare_PrioritySortStable covers a regression-prone case: +// when an existing job at priority N is in the queue, a new job at +// priority N appended afterward must come AFTER (FIFO ties), not +// before. +func TestFairShare_PrioritySortStable(t *testing.T) { + p := NewFairSharePolicy() + enq(p, "a", "u1", 1) + enq(p, "b", "u1", 1) + enq(p, "c", "u1", 1) + order := drainOrder(p) + want := []string{"a", "b", "c"} + for i, id := range want { + if order[i] != id { + t.Fatalf("expected FIFO order %v, got %v", want, order) + } + } +} + +// TestFairShare_CancelRemovesFromSubQueue verifies Cancel removes a +// queued job and rotation continues correctly. +func TestFairShare_CancelRemovesFromSubQueue(t *testing.T) { + p := NewFairSharePolicy() + a := enq(p, "a1", "userA", 0) + enq(p, "b1", "userB", 0) + enq(p, "a2", "userA", 0) + + if !p.Cancel("a1") { + t.Fatal("expected Cancel(a1) to return true") + } + // Verify a's done channel got cancelled signal. + select { + case res := <-a.done: + if res.err != ErrCancelled { + t.Fatalf("expected ErrCancelled, got %v", res.err) + } + default: + t.Fatal("expected a1.done to have a cancellation signal") + } + + if p.Len() != 2 { + t.Fatalf("expected len=2 after cancel, got %d", p.Len()) + } + + // Drain — should be one of (b1, a2) or (a2, b1). 
+ order := drainOrder(p) + if len(order) != 2 { + t.Fatalf("expected 2 dequeues, got %v", order) + } +} + +// TestFairShare_CancelLastInUserRemovesFromRotation verifies that +// cancelling the last queued job in a user's sub-queue removes the +// user from the rotation (no empty-user spinning on next Dequeue). +func TestFairShare_CancelLastInUserRemovesFromRotation(t *testing.T) { + p := NewFairSharePolicy().(*fairSharePolicy) + enq(p, "a1", "userA", 0) + enq(p, "b1", "userB", 0) + if !p.Cancel("a1") { + t.Fatal("cancel a1 failed") + } + if _, ok := p.perUser["userA"]; ok { + t.Fatal("userA should have been removed from perUser map") + } + for _, u := range p.users { + if u == "userA" { + t.Fatal("userA should have been removed from rotation") + } + } +} + +// TestFairShare_OldestEnqueueTime verifies the earliest enqueue time +// across all sub-queues is reported. +func TestFairShare_OldestEnqueueTime(t *testing.T) { + p := NewFairSharePolicy() + t1 := time.Now().Add(-10 * time.Second) + t2 := time.Now().Add(-5 * time.Second) + + p.Enqueue(&queuedJob{ + job: &fakeJob{id: "a", caller: "uA"}, + enqueuedAt: t1, + done: make(chan jobResult, 1), + }) + p.Enqueue(&queuedJob{ + job: &fakeJob{id: "b", caller: "uB"}, + enqueuedAt: t2, + done: make(chan jobResult, 1), + }) + got := p.OldestEnqueueTime() + if got == nil { + t.Fatal("expected non-nil oldest") + } + if !got.Equal(t1) { + t.Fatalf("expected %v, got %v", t1, *got) + } +} + +// TestFairShare_EmptyDequeue verifies Dequeue returns nil on empty +// queue. +func TestFairShare_EmptyDequeue(t *testing.T) { + p := NewFairSharePolicy() + if j := p.Dequeue(); j != nil { + t.Fatalf("expected nil dequeue, got %v", j) + } +} + +// TestFairShare_LaneIntegration verifies NewWithFairShare wires a +// fair-share lane that respects the same scheduling guarantees. +// +// Two users, A submits 4, B submits 1 — with maxConcurrent=1, B's +// job must dispatch within the first two queued positions (after at +// most one of A's jobs). 
+// +// We capture dispatch order by recording the run order via a shared +// channel; each Run sends its id then waits for release. +func TestFairShare_LaneIntegration(t *testing.T) { + lane := NewWithFairShare("test", 1) + + // Block dispatch with a single running job so subsequent submits + // queue. + blocker := newTestJob("blocker") + blocker.caller = "blocker-user" + if _, _, err := lane.Submit(context.Background(), blocker); err != nil { + t.Fatal(err) + } + <-blocker.started + + startOrder := make(chan string, 5) + mkJob := func(id, caller string) *funcJob { + return &funcJob{ + id: id, caller: caller, + run: func(ctx context.Context) error { + startOrder <- id + return nil + }, + } + } + for _, id := range []string{"a1", "a2", "a3", "a4"} { + if _, _, err := lane.Submit(context.Background(), mkJob(id, "userA")); err != nil { + t.Fatal(err) + } + } + if _, _, err := lane.Submit(context.Background(), mkJob("b1", "userB")); err != nil { + t.Fatal(err) + } + + // Release blocker; queued jobs dispatch one at a time as each + // previous one finishes (Run returns immediately after sending + // to startOrder). + close(blocker.release) + + var observed []string + deadline := time.After(2 * time.Second) + for i := 0; i < 5; i++ { + select { + case id := <-startOrder: + observed = append(observed, id) + case <-deadline: + t.Fatalf("never observed all dispatches; got %v", observed) + } + } + + // b1 must run at position 0 or 1 (after at most one A). 
+ foundB := -1 + for i, id := range observed { + if id == "b1" { + foundB = i + break + } + } + if foundB == -1 { + t.Fatalf("b1 was never dispatched; order=%v", observed) + } + if foundB > 1 { + t.Fatalf("b1 ran at position %d among %v; expected 0 or 1", + foundB, observed) + } +} diff --git a/lane/pool.go b/lane/pool.go new file mode 100644 index 0000000..9adc057 --- /dev/null +++ b/lane/pool.go @@ -0,0 +1,694 @@ +package lane + +import ( + "context" + "sync" + "time" +) + +// pool implements Lane with a slot-counting mutex + a pluggable queue +// policy. A single dispatch path lives inside complete(): when a job +// finishes it pulls the next queued job (if any) under the same lock, +// guaranteeing a strict "release one slot, fill one slot" rhythm with +// no goroutine racing to pick the same job. +// +// Why a mutex + map vs a buffered channel as semaphore: we need to +// inspect "running" + "queued" state for Stats, Cancel, and the +// dispatch decision. A single mutex over both maps keeps that cheap +// and consistent. +// +// Test: pool_test.go covers slot-available, slot-full, cancel, +// SubmitWait blocking, Stats accuracy, throughput sampling, and +// SetMaxConcurrent. +type pool struct { + name string + + mu sync.Mutex + maxConcurrent int + running map[string]*runningJob + queue queuePolicy + closed bool + + // completions is a sliding window of job-finish timestamps used + // for the Throughput1m stat. Append on every complete(); prune + // entries older than 60s on read + on each append. Bounded by + // the throughput rate, not by an explicit cap — at 60s/window + // even a tight loop tops out at a few thousand entries. + completions []time.Time + + // runtimes is a bounded sliding window of completed-job wall-clock + // runtimes used by SubmitWithMaxWait's ETA estimator. Capped at + // the configured eta window size (default 16). v9. + runtimes []time.Duration + etaWindowSize int + + // preemption configuration. 
Both can be reconfigured after + // construction via SetPreemptionPolicy. nil-safe defaults preserve + // pre-v9 behavior (no preemption). v9. + preemptPolicy PreemptionPolicy +} + +type runningJob struct { + job Job + // startedAt captures dispatch wall-clock for future ETA tuning; + // not currently surfaced. + startedAt time.Time + // runCtx is the context passed to Job.Run; cancel calls the + // associated CancelCauseFunc. v9. + runCtx context.Context + cancel context.CancelCauseFunc + // preempted is set true when the lane scheduler chose this job for + // preemption. The worker reads this on Run-return to deliver + // ErrPreempted instead of the actual ctx.Cause. v9. + preempted bool +} + +// queuedJob is the in-queue representation of a Submit. done is buffered +// so the dispatch goroutine can signal completion without blocking +// (SubmitWait may have given up on ctx.Done before the job runs; +// dispatch must still be able to deliver the result without leaking). +type queuedJob struct { + job Job + enqueuedAt time.Time + // done is closed (or sent on) exactly once when the job's outcome + // is known: either Run returned, or the job was cancelled before + // dispatch. + done chan jobResult +} + +type jobResult struct { + err error +} + +// queuePolicy is the pluggable queue ordering. fifoPolicy is the +// default; fairSharePolicy lives in policy_fair_share.go. +// +// Why pluggable: the LLM-transport lane wants fair-share, but +// single-resource lanes (e.g. gpu-imagine, max_concurrent=1) work +// fine with FIFO. Future v7 work might add weighted fair share or +// strict priority — keeping the policy small lets us evolve. +type queuePolicy interface { + // Enqueue adds a job to the queue. Implementations may reorder + // the queue based on caller / priority. + Enqueue(j *queuedJob) + // Dequeue returns the next job to run, removing it from the + // queue. Returns nil when empty. 
+ Dequeue() *queuedJob + // Cancel removes a job by ID and signals its done channel with + // ErrCancelled. Returns true if found. + Cancel(jobID string) bool + // Len returns the number of queued jobs. + Len() int + // OldestEnqueueTime returns the earliest enqueue timestamp, or + // nil if the queue is empty. + OldestEnqueueTime() *time.Time +} + +// New constructs a pool with FIFO queueing. +// +// Why a separate New / NewWithFairShare instead of a single function +// taking a policy: lanes are usually instantiated by name from convars +// — keeping the constructor selection explicit makes call sites read +// clearly ("we want fair-share for the ollama lane"). +func New(name string, maxConcurrent int) Lane { + if maxConcurrent <= 0 { + maxConcurrent = 1 + } + return &pool{ + name: name, + maxConcurrent: maxConcurrent, + running: make(map[string]*runningJob), + queue: newFIFOPolicy(), + } +} + +// NewWithPolicy constructs a pool with a caller-supplied queue policy. +// Used by NewWithFairShare and by tests that exercise custom orderings. +func NewWithPolicy(name string, maxConcurrent int, policy queuePolicy) Lane { + if maxConcurrent <= 0 { + maxConcurrent = 1 + } + if policy == nil { + policy = newFIFOPolicy() + } + return &pool{ + name: name, + maxConcurrent: maxConcurrent, + running: make(map[string]*runningJob), + queue: policy, + } +} + +func (p *pool) Name() string { return p.name } + +func (p *pool) Submit(ctx context.Context, job Job) (int, time.Duration, error) { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return 0, 0, ErrLaneClosed + } + if len(p.running) < p.maxConcurrent { + // Slot available — dispatch immediately. + rj := p.newRunningJobLocked(job) + p.running[job.ID()] = rj + // We need a done channel even for fire-and-forget Submit so + // complete() has somewhere to signal; it's discarded. 
+ done := make(chan jobResult, 1) + p.mu.Unlock() + go p.run(rj, done) + return 0, 0, nil + } + // V9 preemption: incoming job has higher priority than at least one + // preemptible running job that has been running for the min-runtime + // guard. If we can find such a victim, cancel it and dispatch the + // new job into the freed slot. The victim's worker delivers + // ErrPreempted on its done channel. + if p.tryPreemptLocked(job) { + rj := p.newRunningJobLocked(job) + p.running[job.ID()] = rj + done := make(chan jobResult, 1) + p.mu.Unlock() + go p.run(rj, done) + return 0, 0, nil + } + // Queue. + qj := &queuedJob{ + job: job, + enqueuedAt: time.Now(), + done: make(chan jobResult, 1), + } + p.queue.Enqueue(qj) + pos := p.queue.Len() + eta := p.estimateETALocked(pos) + p.mu.Unlock() + return pos, eta, nil +} + +// SubmitWithMaxWait is like Submit but returns ErrLaneBusy without +// enqueueing if the estimated wait time would exceed maxWait. maxWait +// <= 0 disables the gate (equivalent to Submit). v9. +// +// ETA is computed from the recent completed-job runtime window; with +// no history the estimator falls back to a conservative 1s/slot. +// Callers ARE NOT charged for an ErrLaneBusy submission — the job is +// never enqueued. The estimated wait at the time of decision is +// returned alongside the error so callers can log/report the exact +// gate value. 
+func (p *pool) SubmitWithMaxWait(ctx context.Context, job Job, maxWait time.Duration) (int, time.Duration, error) { + if maxWait <= 0 { + return p.Submit(ctx, job) + } + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return 0, 0, ErrLaneClosed + } + if len(p.running) < p.maxConcurrent { + rj := p.newRunningJobLocked(job) + p.running[job.ID()] = rj + done := make(chan jobResult, 1) + p.mu.Unlock() + go p.run(rj, done) + return 0, 0, nil + } + if p.tryPreemptLocked(job) { + rj := p.newRunningJobLocked(job) + p.running[job.ID()] = rj + done := make(chan jobResult, 1) + p.mu.Unlock() + go p.run(rj, done) + return 0, 0, nil + } + // Estimate wait at queue tail (current depth + 1). + pos := p.queue.Len() + 1 + eta := p.estimateWaitLocked(pos) + if eta > maxWait { + p.mu.Unlock() + return pos, eta, ErrLaneBusy + } + qj := &queuedJob{ + job: job, + enqueuedAt: time.Now(), + done: make(chan jobResult, 1), + } + p.queue.Enqueue(qj) + p.mu.Unlock() + return pos, eta, nil +} + +// newRunningJobLocked allocates the per-running-job state. Caller MUST +// hold p.mu. v9: every running job carries its own context so the +// preemption path has somewhere to deliver cancellation. +func (p *pool) newRunningJobLocked(job Job) *runningJob { + jobCtx, cancel := context.WithCancelCause(context.Background()) + return &runningJob{ + job: job, + startedAt: time.Now(), + runCtx: jobCtx, + cancel: cancel, + } +} + +// tryPreemptLocked picks a preemption victim and cancels it. Returns +// true if a slot was freed. Caller MUST hold p.mu and MUST verify +// the lane is full before calling. v9. +// +// Selection: among running jobs that (a) implement Preemptible and +// IsPreemptible() returns true, AND (b) have a strictly LOWER priority +// than the incoming job, AND (c) have been running for >= MinRuntime, +// pick the one with the LOWEST priority; FIFO tie-break by oldest +// startedAt. We pick lowest priority first so we always sacrifice the +// least-valuable running job. 
The min-runtime guard prevents thrashing +// (a just-dispatched job staying alive long enough to make progress). +func (p *pool) tryPreemptLocked(incoming Job) bool { + if p.preemptPolicy != nil && !p.preemptPolicy.Enabled() { + return false + } + pol, ok := incoming.(Preemptible) + _ = pol + _ = ok + // We don't gate by "incoming is preemptible". Even non-preemptible + // incoming jobs may preempt a preemptible victim: the goal is to + // give higher-priority work the slot, regardless of whether THAT + // work is itself preemptible. Mark a skill preemptible only when + // you'd accept losing its work to whatever priority arrives next. + minRuntime := p.minRuntimeLocked() + now := time.Now() + var victim *runningJob + for _, rj := range p.running { + pj, isPre := rj.job.(Preemptible) + if !isPre || !pj.IsPreemptible() { + continue + } + if rj.preempted { + continue // already chosen in a prior race; don't double-cancel + } + if rj.job.Priority() >= incoming.Priority() { + continue + } + if now.Sub(rj.startedAt) < minRuntime { + continue + } + if victim == nil || + rj.job.Priority() < victim.job.Priority() || + (rj.job.Priority() == victim.job.Priority() && rj.startedAt.Before(victim.startedAt)) { + victim = rj + } + } + if victim == nil { + return false + } + victim.preempted = true + if victim.cancel != nil { + victim.cancel(ErrPreempted) + } + // We DO NOT remove the victim from p.running here — the worker + // goroutine's Run() may take some non-trivial time to honour + // cancellation. The slot will free when the worker calls + // complete(). Until then, we count this victim as still occupying + // a slot. The caller MUST not assume an immediate slot is + // available; it should still go through the normal "queue if + // full" path. We return true to signal "preemption requested" so + // the caller can elect to immediately enqueue at queue head. + // + // However, the v9 spec wants the higher-priority job to take the + // slot directly. 
We accomplish this by NOT going through the + // queue: the caller already verified len(running) >= + // maxConcurrent, but by setting victim.preempted=true and + // signalling cancel, the victim's worker will exit imminently. + // We dispatch the incoming job NOW, accepting that running may + // briefly exceed maxConcurrent. The complete() path doesn't + // re-enforce the cap; SetMaxConcurrent uses the same "let + // in-flight finish" semantics. So the incoming job runs in + // parallel with the about-to-die victim, and order-of-magnitude + // the lane may briefly hold maxConcurrent+1 jobs. This is + // acceptable because preemption is opt-in and rare. + return true +} + +// minRuntimeLocked returns the configured preemption min-runtime, or +// the default of 30s when the policy is nil. Caller MUST hold p.mu. +// +// A configured policy returning d == 0 is honored as "no min-runtime +// guard" (preempt immediately). d < 0 falls back to the default. +func (p *pool) minRuntimeLocked() time.Duration { + if p.preemptPolicy == nil { + return 30 * time.Second + } + d := p.preemptPolicy.MinRuntime() + if d < 0 { + return 30 * time.Second + } + return d +} + +// SetPreemptionPolicy installs a new preemption policy. Existing +// running jobs are unaffected; future dispatch decisions consult the +// new policy. v9. +func (p *pool) SetPreemptionPolicy(policy PreemptionPolicy) { + p.mu.Lock() + p.preemptPolicy = policy + p.mu.Unlock() +} + +// SetETAWindowSize updates the rolling window size used by +// SubmitWithMaxWait's ETA estimator. v9. 
+func (p *pool) SetETAWindowSize(n int) { + if n <= 0 { + return + } + p.mu.Lock() + p.etaWindowSize = n + if len(p.runtimes) > n { + p.runtimes = p.runtimes[len(p.runtimes)-n:] + } + p.mu.Unlock() +} + +func (p *pool) SubmitWait(ctx context.Context, job Job) error { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return ErrLaneClosed + } + if len(p.running) < p.maxConcurrent { + rj := p.newRunningJobLocked(job) + p.running[job.ID()] = rj + done := make(chan jobResult, 1) + p.mu.Unlock() + go p.run(rj, done) + select { + case res := <-done: + return res.err + case <-ctx.Done(): + // Run has its own context; we cannot kill it from here. + // Wait for it to finish and return ctx.Err to the caller. + <-done + return ctx.Err() + } + } + // V9 preemption: same path as Submit. + if p.tryPreemptLocked(job) { + rj := p.newRunningJobLocked(job) + p.running[job.ID()] = rj + done := make(chan jobResult, 1) + p.mu.Unlock() + go p.run(rj, done) + select { + case res := <-done: + return res.err + case <-ctx.Done(): + <-done + return ctx.Err() + } + } + qj := &queuedJob{ + job: job, + enqueuedAt: time.Now(), + done: make(chan jobResult, 1), + } + p.queue.Enqueue(qj) + p.mu.Unlock() + + select { + case res := <-qj.done: + return res.err + case <-ctx.Done(): + // Try to cancel before dispatch picks it up. + if p.Cancel(job.ID()) == nil { + return ctx.Err() + } + // Already dequeued and running — wait for the run to finish. + <-qj.done + return ctx.Err() + } +} + +// run executes the job and arranges for the next queued job to be +// dispatched on completion. The done channel is signaled exactly once +// with the run's error. +// +// v9: each running job carries its own cancellable context so the +// preemption path can deliver cancellation. Pre-v9 callers passed +// context.Background; that semantic is preserved for jobs that ignore +// ctx.Done. Jobs that respect ctx will see cancellation immediately +// when the lane scheduler chooses them as a preemption victim. 
+func (p *pool) run(rj *runningJob, done chan<- jobResult) { + jobCtx := p.newJobContext(rj) + err := rj.job.Run(jobCtx) + // If the lane chose this job for preemption, override the worker's + // returned error with ErrPreempted so SubmitWait callers can + // distinguish "preempted" from a generic ctx.Cause. + p.mu.Lock() + preempted := rj.preempted + startedAt := rj.startedAt + p.mu.Unlock() + if preempted { + err = ErrPreempted + } + done <- jobResult{err: err} + p.complete(rj.job.ID(), startedAt, time.Now()) +} + +// runQueued is the dispatch path for jobs that were queued, not +// dispatched immediately. Identical to run() except it signals the +// queued job's done channel (the caller's SubmitWait waits on it). +func (p *pool) runQueued(rj *runningJob, qj *queuedJob) { + jobCtx := p.newJobContext(rj) + err := qj.job.Run(jobCtx) + p.mu.Lock() + preempted := rj.preempted + startedAt := rj.startedAt + p.mu.Unlock() + if preempted { + err = ErrPreempted + } + qj.done <- jobResult{err: err} + p.complete(qj.job.ID(), startedAt, time.Now()) +} + +// newJobContext returns the context the worker passes to Job.Run. v9: +// every running job has a cancellable context backing rj.cancel, so +// the preemption path can interrupt it. +func (p *pool) newJobContext(rj *runningJob) context.Context { + if rj.runCtx == nil { + return context.Background() + } + return rj.runCtx +} + +// complete is called when a job's Run returns. It removes the job +// from the running map, records throughput, and pulls the next queued +// job (if any) to fill the freed slot. +func (p *pool) complete(jobID string, startedAt, finishedAt time.Time) { + p.mu.Lock() + delete(p.running, jobID) + p.completions = append(p.completions, finishedAt) + p.pruneCompletionsLocked(finishedAt) + // V9: track runtime for ETA estimator. + if !startedAt.IsZero() { + p.recordRuntimeLocked(finishedAt.Sub(startedAt)) + } + + // Pull next queued job under the same lock. 
+ if !p.closed && len(p.running) < p.maxConcurrent { + next := p.queue.Dequeue() + if next != nil { + rj := p.newRunningJobLocked(next.job) + p.running[next.job.ID()] = rj + p.mu.Unlock() + go p.runQueued(rj, next) + return + } + } + p.mu.Unlock() +} + +// recordRuntimeLocked appends to the rolling runtime window used by +// SubmitWithMaxWait's ETA estimator. Caller MUST hold p.mu. v9. +func (p *pool) recordRuntimeLocked(d time.Duration) { + if d <= 0 { + return + } + cap := p.etaWindowSize + if cap <= 0 { + cap = 16 + } + p.runtimes = append(p.runtimes, d) + if len(p.runtimes) > cap { + p.runtimes = p.runtimes[len(p.runtimes)-cap:] + } +} + +func (p *pool) Cancel(jobID string) error { + p.mu.Lock() + defer p.mu.Unlock() + if p.queue.Cancel(jobID) { + return nil + } + return ErrNotQueued +} + +func (p *pool) Stats() LaneStats { + p.mu.Lock() + defer p.mu.Unlock() + now := time.Now() + p.pruneCompletionsLocked(now) + return LaneStats{ + Name: p.name, + MaxConcurrent: p.maxConcurrent, + Running: len(p.running), + Queued: p.queue.Len(), + OldestQueuedAt: p.queue.OldestEnqueueTime(), + Throughput1m: len(p.completions), + } +} + +func (p *pool) SetMaxConcurrent(n int) { + if n <= 0 { + return + } + p.mu.Lock() + p.maxConcurrent = n + // If we just raised the cap, dispatch backlog. + for len(p.running) < p.maxConcurrent && !p.closed { + next := p.queue.Dequeue() + if next == nil { + break + } + rj := p.newRunningJobLocked(next.job) + p.running[next.job.ID()] = rj + // Spin up the goroutine while still holding the lock; the + // goroutine itself doesn't take p.mu until complete(). + go p.runQueued(rj, next) + } + p.mu.Unlock() +} + +// pruneCompletionsLocked drops completion timestamps older than 60s. +// Caller must hold p.mu. The slice is rebuilt rather than truncated +// in place because the throughput counts are typically small (hundreds +// at most); avoiding pointer churn here is not worth the complexity +// of an in-place compaction. 
+func (p *pool) pruneCompletionsLocked(now time.Time) { + cutoff := now.Add(-time.Minute) + if len(p.completions) == 0 { + return + } + // Find the first entry within the window — completions is + // append-only so it's already sorted ascending. + first := 0 + for first < len(p.completions) && p.completions[first].Before(cutoff) { + first++ + } + if first == 0 { + return + } + if first >= len(p.completions) { + p.completions = p.completions[:0] + return + } + // Copy tail down to head; reuse the backing array. + n := copy(p.completions, p.completions[first:]) + p.completions = p.completions[:n] +} + +// estimateWaitLocked returns the best-effort wait time before the +// given queue position is dispatched. Caller MUST hold p.mu. v9 — +// uses the recent-runtime window when available, falling back to the +// throughput-based estimate. The result reflects the time the +// position-`pos` job will sit in the queue: with `maxConcurrent` +// running jobs the wait is `(pos / maxConcurrent) * avgRuntime`. +func (p *pool) estimateWaitLocked(pos int) time.Duration { + if pos <= 0 { + return 0 + } + if len(p.runtimes) == 0 { + return p.estimateETALocked(pos) + } + var total time.Duration + for _, d := range p.runtimes { + total += d + } + avg := total / time.Duration(len(p.runtimes)) + if avg <= 0 { + return p.estimateETALocked(pos) + } + concurrency := p.maxConcurrent + if concurrency <= 0 { + concurrency = 1 + } + // Each "round" through the slots drains `concurrency` jobs in + // avg runtime. Position `pos` waits ceil(pos / concurrency) rounds. + rounds := (pos + concurrency - 1) / concurrency + return avg * time.Duration(rounds) +} + +// estimateETALocked returns a rough ETA for a job at the given +// 1-based queue position. Caller must hold p.mu. +// +// Why best-effort: production callers (Discord "queued (~30s)" reply) +// only need an order-of-magnitude estimate. 
Throughput is sampled over +// a 1-minute window; if the window is empty we fall back to a +// conservative default of 1s/slot * pos. +func (p *pool) estimateETALocked(pos int) time.Duration { + if pos <= 0 { + return 0 + } + // throughput per second over the window + thr := len(p.completions) + if thr == 0 { + // Fallback: assume each slot takes ~1s — better than zero. + return time.Duration(pos) * time.Second + } + // We have N completions in the last 60s; the lane's "effective + // throughput" is N jobs / 60s. ETA for position `pos` is the + // time needed to drain pos jobs at that rate. + perJob := 60.0 / float64(thr) + return time.Duration(perJob * float64(pos) * float64(time.Second)) +} + +// fifoPolicy is a simple slice-backed FIFO queue. Used by the v1 +// constructor (New). +type fifoPolicy struct { + queue []*queuedJob +} + +func newFIFOPolicy() queuePolicy { return &fifoPolicy{} } + +func (f *fifoPolicy) Enqueue(j *queuedJob) { + f.queue = append(f.queue, j) +} + +func (f *fifoPolicy) Dequeue() *queuedJob { + if len(f.queue) == 0 { + return nil + } + j := f.queue[0] + // Avoid retaining the old reference. + f.queue[0] = nil + f.queue = f.queue[1:] + return j +} + +func (f *fifoPolicy) Cancel(jobID string) bool { + for i, j := range f.queue { + if j.job.ID() == jobID { + // Remove and signal cancelled. + f.queue = append(f.queue[:i], f.queue[i+1:]...) 
+ j.done <- jobResult{err: ErrCancelled} + return true + } + } + return false +} + +func (f *fifoPolicy) Len() int { return len(f.queue) } + +func (f *fifoPolicy) OldestEnqueueTime() *time.Time { + if len(f.queue) == 0 { + return nil + } + t := f.queue[0].enqueuedAt + return &t +} diff --git a/lane/pool_test.go b/lane/pool_test.go new file mode 100644 index 0000000..3f3585e --- /dev/null +++ b/lane/pool_test.go @@ -0,0 +1,485 @@ +package lane + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +// testJob is a Job impl that signals when Run starts and blocks until +// release is closed. Used by tests to control dispatch ordering +// deterministically. +type testJob struct { + id string + caller string + priority int + started chan struct{} + release chan struct{} + err error + + // runCount is incremented inside Run; tests assert "exactly once". + runCount int32 +} + +func newTestJob(id string) *testJob { + return &testJob{ + id: id, + started: make(chan struct{}, 1), + release: make(chan struct{}), + } +} + +func (t *testJob) ID() string { return t.id } +func (t *testJob) CallerID() string { + if t.caller == "" { + return "anon" + } + return t.caller +} +func (t *testJob) Priority() int { return t.priority } +func (t *testJob) Run(ctx context.Context) error { + atomic.AddInt32(&t.runCount, 1) + // Non-blocking send so a test that doesn't drain `started` does + // not deadlock. + select { + case t.started <- struct{}{}: + default: + } + <-t.release + return t.err +} + +// TestPool_Submit_SlotAvailable verifies that Submit dispatches +// immediately when a slot is free. 
+func TestPool_Submit_SlotAvailable(t *testing.T) { + p := New("test", 1).(*pool) + job := newTestJob("j1") + pos, eta, err := p.Submit(context.Background(), job) + if err != nil { + t.Fatalf("submit err: %v", err) + } + if pos != 0 { + t.Fatalf("expected pos=0 (dispatched), got %d", pos) + } + if eta != 0 { + t.Fatalf("expected eta=0, got %v", eta) + } + // Wait for Run to start. + select { + case <-job.started: + case <-time.After(time.Second): + t.Fatalf("job did not start within 1s") + } + close(job.release) + // Drain completion. + waitForRunning(t, p, 0) +} + +// TestPool_Submit_QueuedWhenFull verifies queue position reporting. +func TestPool_Submit_QueuedWhenFull(t *testing.T) { + p := New("test", 1).(*pool) + j1 := newTestJob("j1") + if _, _, err := p.Submit(context.Background(), j1); err != nil { + t.Fatal(err) + } + <-j1.started + + j2 := newTestJob("j2") + pos, _, err := p.Submit(context.Background(), j2) + if err != nil { + t.Fatal(err) + } + if pos != 1 { + t.Fatalf("expected pos=1 for first queued, got %d", pos) + } + + j3 := newTestJob("j3") + pos, _, err = p.Submit(context.Background(), j3) + if err != nil { + t.Fatal(err) + } + if pos != 2 { + t.Fatalf("expected pos=2 for second queued, got %d", pos) + } + + stats := p.Stats() + if stats.Running != 1 || stats.Queued != 2 { + t.Fatalf("expected running=1 queued=2, got %+v", stats) + } + + // Drain. + close(j1.release) + close(j2.release) + close(j3.release) +} + +// TestPool_SubmitWait_Blocks verifies SubmitWait blocks until Run +// completes and returns Run's error. 
+func TestPool_SubmitWait_Blocks(t *testing.T) { + p := New("test", 2) + expected := errors.New("boom") + j := newTestJob("j1") + j.err = expected + + var got error + done := make(chan struct{}) + go func() { + got = p.SubmitWait(context.Background(), j) + close(done) + }() + + <-j.started + close(j.release) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("SubmitWait did not return within 1s") + } + if !errors.Is(got, expected) { + t.Fatalf("expected %v, got %v", expected, got) + } +} + +// TestPool_SubmitWait_CtxCancelledWhileQueued verifies that cancelling +// the ctx while queued returns ctx.Err and removes the job. +func TestPool_SubmitWait_CtxCancelledWhileQueued(t *testing.T) { + p := New("test", 1).(*pool) + blocker := newTestJob("blocker") + if _, _, err := p.Submit(context.Background(), blocker); err != nil { + t.Fatal(err) + } + <-blocker.started + + ctx, cancel := context.WithCancel(context.Background()) + target := newTestJob("target") + done := make(chan error, 1) + go func() { + done <- p.SubmitWait(ctx, target) + }() + + // Wait until target is enqueued. + waitFor(t, func() bool { return p.Stats().Queued == 1 }) + cancel() + + select { + case err := <-done: + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + case <-time.After(time.Second): + t.Fatal("SubmitWait did not return after cancel") + } + + // target.Run must never have been called. + if atomic.LoadInt32(&target.runCount) != 0 { + t.Fatalf("target.Run was called %d times, want 0", + target.runCount) + } + + close(blocker.release) +} + +// TestPool_Cancel_RemovesQueued verifies Cancel removes a queued job +// and that a subsequent SubmitWait observer would see ErrCancelled. +// Here we use Submit (fire-and-forget) so we just check that Cancel +// returns nil and the queue shrinks. 
+func TestPool_Cancel_RemovesQueued(t *testing.T) { + p := New("test", 1).(*pool) + blocker := newTestJob("blocker") + if _, _, err := p.Submit(context.Background(), blocker); err != nil { + t.Fatal(err) + } + <-blocker.started + + target := newTestJob("target") + if _, _, err := p.Submit(context.Background(), target); err != nil { + t.Fatal(err) + } + + if err := p.Cancel("target"); err != nil { + t.Fatalf("cancel: %v", err) + } + if p.Stats().Queued != 0 { + t.Fatalf("expected queued=0 after cancel, got %d", + p.Stats().Queued) + } + + // Cancelling again or cancelling a missing job returns ErrNotQueued. + if err := p.Cancel("target"); !errors.Is(err, ErrNotQueued) { + t.Fatalf("expected ErrNotQueued, got %v", err) + } + + close(blocker.release) +} + +// TestPool_Cancel_PropagatesToSubmitWait verifies that cancelling a +// job whose caller is in SubmitWait returns ErrCancelled. +func TestPool_Cancel_PropagatesToSubmitWait(t *testing.T) { + p := New("test", 1).(*pool) + blocker := newTestJob("blocker") + if _, _, err := p.Submit(context.Background(), blocker); err != nil { + t.Fatal(err) + } + <-blocker.started + + target := newTestJob("target") + done := make(chan error, 1) + go func() { + done <- p.SubmitWait(context.Background(), target) + }() + waitFor(t, func() bool { return p.Stats().Queued == 1 }) + + if err := p.Cancel("target"); err != nil { + t.Fatalf("cancel: %v", err) + } + + select { + case err := <-done: + if !errors.Is(err, ErrCancelled) { + t.Fatalf("expected ErrCancelled, got %v", err) + } + case <-time.After(time.Second): + t.Fatal("SubmitWait did not return after cancel") + } + + close(blocker.release) +} + +// TestPool_Stats_Accurate covers Running + Queued + OldestQueuedAt. 
+func TestPool_Stats_Accurate(t *testing.T) { + p := New("test", 1).(*pool) + j1 := newTestJob("j1") + if _, _, err := p.Submit(context.Background(), j1); err != nil { + t.Fatal(err) + } + <-j1.started + + beforeQueue := time.Now() + j2 := newTestJob("j2") + if _, _, err := p.Submit(context.Background(), j2); err != nil { + t.Fatal(err) + } + j3 := newTestJob("j3") + if _, _, err := p.Submit(context.Background(), j3); err != nil { + t.Fatal(err) + } + + stats := p.Stats() + if stats.Running != 1 { + t.Errorf("running=%d, want 1", stats.Running) + } + if stats.Queued != 2 { + t.Errorf("queued=%d, want 2", stats.Queued) + } + if stats.OldestQueuedAt == nil { + t.Errorf("OldestQueuedAt is nil") + } else if stats.OldestQueuedAt.Before(beforeQueue.Add(-time.Second)) { + t.Errorf("OldestQueuedAt seems too old: %v vs %v", + *stats.OldestQueuedAt, beforeQueue) + } + + close(j1.release) + close(j2.release) + close(j3.release) +} + +// TestPool_Throughput1m: complete 5 jobs, throughput=5; sleep 1.1s +// would be slow — instead manipulate the completions slice directly. +// The test verifies the slice trimming logic. +func TestPool_Throughput1m(t *testing.T) { + p := New("test", 1).(*pool) + now := time.Now() + // Fill completions slice manually. + p.completions = []time.Time{ + now.Add(-90 * time.Second), + now.Add(-30 * time.Second), + now.Add(-10 * time.Second), + now.Add(-1 * time.Second), + now, + } + stats := p.Stats() + if stats.Throughput1m != 4 { + t.Fatalf("expected 4 (only the last 4 are within 60s), got %d", + stats.Throughput1m) + } +} + +// TestPool_SetMaxConcurrent verifies that raising the cap drains +// queued backlog onto the new slots. 
+func TestPool_SetMaxConcurrent(t *testing.T) { + p := New("test", 1).(*pool) + j1 := newTestJob("j1") + j2 := newTestJob("j2") + j3 := newTestJob("j3") + + if _, _, err := p.Submit(context.Background(), j1); err != nil { + t.Fatal(err) + } + <-j1.started + + if _, _, err := p.Submit(context.Background(), j2); err != nil { + t.Fatal(err) + } + if _, _, err := p.Submit(context.Background(), j3); err != nil { + t.Fatal(err) + } + + if got := p.Stats().Queued; got != 2 { + t.Fatalf("expected queued=2, got %d", got) + } + + // Raise cap to 3 — should drain both queued jobs immediately. + p.SetMaxConcurrent(3) + waitFor(t, func() bool { return p.Stats().Running == 3 }) + + if got := p.Stats().Queued; got != 0 { + t.Fatalf("expected queued=0 after raise, got %d", got) + } + + close(j1.release) + close(j2.release) + close(j3.release) + waitForRunning(t, p, 0) +} + +// TestPool_SetMaxConcurrent_NoOpZeroOrNegative verifies n<=0 is +// ignored. +func TestPool_SetMaxConcurrent_NoOpZeroOrNegative(t *testing.T) { + p := New("test", 2).(*pool) + p.SetMaxConcurrent(0) + if got := p.Stats().MaxConcurrent; got != 2 { + t.Fatalf("zero set should be no-op, got %d", got) + } + p.SetMaxConcurrent(-1) + if got := p.Stats().MaxConcurrent; got != 2 { + t.Fatalf("negative set should be no-op, got %d", got) + } +} + +// TestPool_DispatchOnComplete verifies that finishing a running job +// pulls the next queued job onto the freed slot. +func TestPool_DispatchOnComplete(t *testing.T) { + p := New("test", 1).(*pool) + + j1 := newTestJob("j1") + j2 := newTestJob("j2") + + if _, _, err := p.Submit(context.Background(), j1); err != nil { + t.Fatal(err) + } + <-j1.started + + if _, _, err := p.Submit(context.Background(), j2); err != nil { + t.Fatal(err) + } + + if got := p.Stats().Queued; got != 1 { + t.Fatalf("expected queued=1, got %d", got) + } + + // Release j1; j2 should auto-dispatch. 
+ close(j1.release) + select { + case <-j2.started: + case <-time.After(time.Second): + t.Fatal("j2 did not dispatch after j1 finished") + } + if got := p.Stats().Queued; got != 0 { + t.Errorf("expected queued=0 after dispatch, got %d", got) + } + close(j2.release) + waitForRunning(t, p, 0) +} + +// TestPool_ConcurrencyLimitRespected fires N jobs at a lane with +// maxConcurrent=2 and verifies at most 2 ever run simultaneously. +func TestPool_ConcurrencyLimitRespected(t *testing.T) { + p := New("test", 2) + const N = 8 + + var inflight int32 + var maxObserved int32 + done := make(chan struct{}, N) + + for i := 0; i < N; i++ { + i := i + j := &funcJob{ + id: fmt.Sprintf("j%d", i), + caller: "u1", + run: func(ctx context.Context) error { + cur := atomic.AddInt32(&inflight, 1) + for { + m := atomic.LoadInt32(&maxObserved) + if cur <= m || atomic.CompareAndSwapInt32(&maxObserved, m, cur) { + break + } + } + time.Sleep(20 * time.Millisecond) + atomic.AddInt32(&inflight, -1) + done <- struct{}{} + return nil + }, + } + if _, _, err := p.Submit(context.Background(), j); err != nil { + t.Fatal(err) + } + } + + for i := 0; i < N; i++ { + <-done + } + if max := atomic.LoadInt32(&maxObserved); max > 2 { + t.Fatalf("expected max in-flight <= 2, observed %d", max) + } +} + +// funcJob is a Job impl driven by a closure. Used by tests that don't +// need the started/release plumbing. +type funcJob struct { + id string + caller string + priority int + run func(ctx context.Context) error +} + +func (f *funcJob) ID() string { return f.id } +func (f *funcJob) CallerID() string { return f.caller } +func (f *funcJob) Priority() int { return f.priority } +func (f *funcJob) Run(ctx context.Context) error { return f.run(ctx) } + +// waitForRunning waits up to 1s for stats.Running == n. 
+func waitForRunning(t *testing.T, p *pool, n int) { + t.Helper() + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if p.Stats().Running == n { + return + } + time.Sleep(2 * time.Millisecond) + } + t.Fatalf("running != %d after 1s; have %d", n, p.Stats().Running) +} + +// waitFor polls cond up to 1s. +func waitFor(t *testing.T, cond func() bool) { + t.Helper() + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(2 * time.Millisecond) + } + t.Fatalf("condition did not become true within 1s") +} + +// Verify funcJob compiles under the Job interface. +var _ Job = (*funcJob)(nil) + +// silence unused import warning if reached during refactoring +var _ = sync.Mutex{} diff --git a/lane/precreate_test.go b/lane/precreate_test.go new file mode 100644 index 0000000..90dd11f --- /dev/null +++ b/lane/precreate_test.go @@ -0,0 +1,55 @@ +package lane + +import ( + "context" + "sort" + "testing" +) + +// TestRegistry_PreCreateMakesLanesVisible is the lane-level anchor for +// hotfix-5 Bug 4. The mort.go boot path now pre-creates the well-known +// lanes (skill-default, webhook-default, etc.) so they appear on +// `/skills/admin/queues` and pass set-lane validation BEFORE any run +// has ever hit them. +// +// Why this test (vs the existing GetOrCreate idempotency test): the +// production bug was specifically about the registry's lazy-creation +// behaviour combined with the queues page only listing materialised +// lanes. This test asserts the missing piece: after pre-creation, both +// Get and List return the lane immediately, regardless of whether a +// job ever touched it. 
+func TestRegistry_PreCreateMakesLanesVisible(t *testing.T) { + r := NewRegistry(nil) + wellKnown := []string{ + "ollama", "anthropic-thinking", "anthropic-default", "llm-default", + "skill-default", "skill-heavy", "webhook-default", + } + ctx := context.Background() + for _, name := range wellKnown { + _ = r.GetOrCreate(ctx, name) + } + // Get must return non-nil for every lane WITHOUT going through + // GetOrCreate again — that's the pre-creation guarantee. + for _, name := range wellKnown { + if l := r.Get(name); l == nil { + t.Errorf("lane %q not registered after pre-create; admin "+ + "queues page would be missing it (Bug 4 regression)", name) + } + } + // List must enumerate every pre-created lane. + got := make([]string, 0) + for _, l := range r.List() { + got = append(got, l.Name()) + } + sort.Strings(got) + want := append([]string{}, wellKnown...) + sort.Strings(want) + if len(got) != len(want) { + t.Fatalf("List length: got %d, want %d (got=%v, want=%v)", len(got), len(want), got, want) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("List[%d]: got %q, want %q", i, got[i], want[i]) + } + } +} diff --git a/lane/preemption_test.go b/lane/preemption_test.go new file mode 100644 index 0000000..62998bd --- /dev/null +++ b/lane/preemption_test.go @@ -0,0 +1,283 @@ +package lane + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" +) + +// preemptibleJob is a testJob extension that opts into preemption and +// honours ctx cancellation by returning ctx.Err on the cancel path. +type preemptibleJob struct { + *testJob + preemptible bool + // ranWith is set inside Run with the actual error returned by the + // honest ctx.Done observer so tests can distinguish "preempted" + // from "ran to completion". 
+ ranWith atomic.Value // error +} + +func newPreemptibleJob(id string, priority int, preemptible bool) *preemptibleJob { + pj := &preemptibleJob{testJob: newTestJob(id), preemptible: preemptible} + pj.priority = priority + return pj +} + +func (p *preemptibleJob) IsPreemptible() bool { return p.preemptible } + +// finishedSentinel is a non-nil error stored when Run finishes via +// p.release (no preemption). atomic.Value cannot store nil, so we use +// this sentinel to disambiguate "Run completed normally" from "not +// yet finished". +var finishedSentinel = errors.New("test: finished normally") + +// Run blocks until either ctx is cancelled (preemption) or release is +// closed (normal finish). Records which path won so the test asserts. +func (p *preemptibleJob) Run(ctx context.Context) error { + atomic.AddInt32(&p.runCount, 1) + select { + case p.started <- struct{}{}: + default: + } + select { + case <-ctx.Done(): + err := context.Cause(ctx) + p.ranWith.Store(err) + return err + case <-p.release: + p.ranWith.Store(finishedSentinel) + return p.err + } +} + +// fixedPreemptionPolicy is a test PreemptionPolicy with knobs for +// MinRuntime + Enabled. +type fixedPreemptionPolicy struct { + min time.Duration + enabled bool +} + +func (f *fixedPreemptionPolicy) MinRuntime() time.Duration { return f.min } +func (f *fixedPreemptionPolicy) Enabled() bool { return f.enabled } + +// TestPool_Preemption_FiresOnHigherPriority verifies that a high- +// priority Submit at a full lane preempts a preemptible low-priority +// running job that has been running for at least min-runtime. +func TestPool_Preemption_FiresOnHigherPriority(t *testing.T) { + p := NewWithPolicy("test", 1, NewFairSharePolicy()).(*pool) + p.SetPreemptionPolicy(&fixedPreemptionPolicy{min: 0, enabled: true}) + + low := newPreemptibleJob("low", 0, true) + low.caller = "u1" + + if err := submitNoBlock(p, low); err != nil { + t.Fatalf("submit low: %v", err) + } + <-low.started + + // Slot is full. 
Submit a higher-priority job — should preempt. + high := newPreemptibleJob("high", 5, false) + high.caller = "u2" + pos, _, err := p.Submit(context.Background(), high) + if err != nil { + t.Fatalf("submit high: %v", err) + } + if pos != 0 { + t.Errorf("high pos = %d, want 0 (dispatched after preempt)", pos) + } + + // Wait for the low's Run to return with ctx.Cause = ErrPreempted. + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if v := low.ranWith.Load(); v != nil { + if err, ok := v.(error); ok && errors.Is(err, ErrPreempted) { + goto done + } + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("low Run never returned with ErrPreempted; ranWith=%v", low.ranWith.Load()) + +done: + // Let high finish. + close(high.release) + // Drain low's release channel to release the goroutine cleanly. + close(low.release) +} + +// TestPool_Preemption_RespectsMinRuntime verifies that a high-priority +// Submit does NOT preempt a job younger than the min-runtime guard. +func TestPool_Preemption_RespectsMinRuntime(t *testing.T) { + p := NewWithPolicy("test", 1, NewFairSharePolicy()).(*pool) + // Min-runtime in the future so no running job qualifies. + p.SetPreemptionPolicy(&fixedPreemptionPolicy{min: time.Hour, enabled: true}) + + low := newPreemptibleJob("low", 0, true) + low.caller = "u1" + if err := submitNoBlock(p, low); err != nil { + t.Fatalf("submit low: %v", err) + } + <-low.started + + high := newPreemptibleJob("high", 5, false) + high.caller = "u2" + pos, _, err := p.Submit(context.Background(), high) + if err != nil { + t.Fatalf("submit high: %v", err) + } + if pos == 0 { + t.Errorf("high pos = 0; expected to be queued (preemption blocked by min-runtime)") + } + + // Confirm low was NOT preempted: ranWith stays nil. 
+ time.Sleep(20 * time.Millisecond) + if v := low.ranWith.Load(); v != nil { + if err, ok := v.(error); ok && err != nil { + t.Errorf("low was preempted unexpectedly: %v", err) + } + } + + close(low.release) + close(high.release) +} + +// TestPool_Preemption_NonPreemptibleProtected verifies that a +// non-preemptible running job is not chosen as a victim even when a +// higher-priority job arrives. +func TestPool_Preemption_NonPreemptibleProtected(t *testing.T) { + p := NewWithPolicy("test", 1, NewFairSharePolicy()).(*pool) + p.SetPreemptionPolicy(&fixedPreemptionPolicy{min: 0, enabled: true}) + + low := newPreemptibleJob("low", 0, false /* not preemptible */) + low.caller = "u1" + if err := submitNoBlock(p, low); err != nil { + t.Fatalf("submit low: %v", err) + } + <-low.started + + high := newPreemptibleJob("high", 5, false) + high.caller = "u2" + pos, _, err := p.Submit(context.Background(), high) + if err != nil { + t.Fatalf("submit high: %v", err) + } + if pos == 0 { + t.Errorf("high pos = 0; expected queued (non-preemptible victim)") + } + + close(low.release) + close(high.release) +} + +// TestPool_SubmitWithMaxWait_ZeroBlocks verifies that maxWait=0 falls +// back to the default Submit path (no early-return). +func TestPool_SubmitWithMaxWait_ZeroBlocks(t *testing.T) { + p := New("test", 1).(*pool) + + first := newTestJob("j1") + if err := submitNoBlock(p, first); err != nil { + t.Fatalf("submit first: %v", err) + } + <-first.started + + second := newTestJob("j2") + pos, _, err := p.SubmitWithMaxWait(context.Background(), second, 0) + if err != nil { + t.Fatalf("submit second: %v", err) + } + if pos == 0 { + t.Errorf("expected second to be queued, got pos=0") + } + + close(first.release) + close(second.release) +} + +// TestPool_SubmitWithMaxWait_RejectsWhenETAExceedsCap verifies that +// SubmitWithMaxWait returns ErrLaneBusy without enqueueing when the +// estimated wait exceeds maxWait. 
+func TestPool_SubmitWithMaxWait_RejectsWhenETAExceedsCap(t *testing.T) { + p := New("test", 1).(*pool) + p.SetETAWindowSize(4) + + // Run a job that takes ~30ms so the estimator has runtime data. + timed := newTestJob("timed") + go func() { + time.Sleep(30 * time.Millisecond) + close(timed.release) + }() + if err := p.SubmitWait(context.Background(), timed); err != nil { + t.Fatalf("timed: %v", err) + } + + // Block the lane. + blocker := newTestJob("blocker") + go func() { + _ = p.SubmitWait(context.Background(), blocker) + }() + <-blocker.started + + // Try to submit with maxWait=1ns — definitely shorter than the + // average runtime. + hopeless := newTestJob("hopeless") + pos, eta, err := p.SubmitWithMaxWait(context.Background(), hopeless, time.Nanosecond) + if !errors.Is(err, ErrLaneBusy) { + t.Fatalf("err = %v, want ErrLaneBusy; pos=%d eta=%s", err, pos, eta) + } + if eta == 0 { + t.Errorf("expected non-zero eta on busy reject, got 0") + } + + // Was hopeless enqueued? Stats should show 0 queued (only blocker + // running). + stats := p.Stats() + if stats.Queued != 0 { + t.Errorf("hopeless was enqueued despite ErrLaneBusy: queued=%d", stats.Queued) + } + + close(blocker.release) +} + +// TestPool_SubmitWithMaxWait_AllowsWhenETAUnderCap verifies that +// SubmitWithMaxWait does enqueue when the estimated wait is under the +// max. 
+func TestPool_SubmitWithMaxWait_AllowsWhenETAUnderCap(t *testing.T) { + p := New("test", 1).(*pool) + + first := newTestJob("first") + if err := submitNoBlock(p, first); err != nil { + t.Fatalf("submit first: %v", err) + } + <-first.started + + second := newTestJob("second") + pos, _, err := p.SubmitWithMaxWait(context.Background(), second, time.Hour) + if err != nil { + t.Fatalf("submit second: %v", err) + } + if pos != 1 { + t.Errorf("second pos = %d, want 1", pos) + } + + close(first.release) + close(second.release) +} + +// submitNoBlock is a helper that asynchronously calls SubmitWait so the +// caller can inspect the running job's state without blocking on +// completion. +func submitNoBlock(p Lane, job Job) error { + errCh := make(chan error, 1) + go func() { + errCh <- p.SubmitWait(context.Background(), job) + }() + // Give the dispatch goroutine a chance to start. + select { + case err := <-errCh: + return err + case <-time.After(50 * time.Millisecond): + return nil + } +} diff --git a/lane/registry.go b/lane/registry.go new file mode 100644 index 0000000..0b51791 --- /dev/null +++ b/lane/registry.go @@ -0,0 +1,196 @@ +package lane + +import ( + "context" + "sync" +) + +// ConvarReader is the narrow surface the registry uses to read +// per-lane concurrency caps from convars at startup and on Reload. +// +// Why an interface (not pkg/convar directly): registry is a generic +// primitive and shouldn't import the application convar package. +// Production wires a thin adapter; tests pass a fake. +type ConvarReader interface { + Int(ctx context.Context, name string, def int) int +} + +// ConvarReaderFunc adapts a closure into a ConvarReader. +type ConvarReaderFunc func(ctx context.Context, name string, def int) int + +// Int satisfies ConvarReader. +func (f ConvarReaderFunc) Int(ctx context.Context, name string, def int) int { + if f == nil { + return def + } + return f(ctx, name, def) +} + +// Registry is a manager of named lanes. 
The default policy is +// fair-share; lanes are created lazily on first GetOrCreate, with +// concurrency read from convar `lanes..max_concurrent` (default +// 1). Reload re-reads convars and updates each lane's MaxConcurrent +// in place — useful for runtime tuning without losing in-flight work. +// +// Why a singleton-ish manager vs constructing lanes ad-hoc: the +// registry is the integration point where mort.go wires lanes once +// and every subsystem (LLM transport, skill runner) looks them up by +// name. Lazy creation lets the registry stay schema-free — adding a +// new lane is just "ask for it by name". +// +// Test: registry_test.go covers GetOrCreate identity, convar read, +// and Reload. +type Registry struct { + mu sync.RWMutex + lanes map[string]Lane + convars ConvarReader + // policyFactory is the queue policy constructor used for new + // lanes. Defaults to NewFairSharePolicy. Tests substitute FIFO + // when they want deterministic ordering. + policyFactory func() queuePolicy +} + +// NewRegistry constructs a registry. convars may be nil — lanes +// fall back to the registry's default concurrency (1). +func NewRegistry(convars ConvarReader) *Registry { + return &Registry{ + lanes: make(map[string]Lane), + convars: convars, + policyFactory: NewFairSharePolicy, + } +} + +// SetPolicyFactory overrides the default policy used for new lanes. +// Existing lanes are unchanged. Used by tests; production keeps the +// fair-share default. +func (r *Registry) SetPolicyFactory(f func() queuePolicy) { + if f == nil { + f = NewFairSharePolicy + } + r.mu.Lock() + r.policyFactory = f + r.mu.Unlock() +} + +// Get returns the named lane or nil if it has not been created. +// Useful in admin/UI code that wants to show only existing lanes +// without creating new ones as a side effect. 
+func (r *Registry) Get(name string) Lane { + r.mu.RLock() + defer r.mu.RUnlock() + return r.lanes[name] +} + +// StatsReader is the read-only stats surface exposed to admin / user +// dashboards (Discord queue commands, /skills/admin/queues web view). +// *Registry satisfies it; tests substitute a fake. +// +// Why a narrow interface (vs passing *Registry around): the consumers +// only need stats and lane lookup — no creation or mutation surface. +// Keeping the dep narrow makes mocks trivial in webui + skills tests. +type StatsReader interface { + // List returns a snapshot of every registered lane. + List() []Lane + + // Lookup returns the lane by name, or nil. Mirrors Registry.Get + // (named differently to avoid the "Get" verb confusion in + // dashboards that primarily call Stats). + Lookup(name string) Lane +} + +// Lookup satisfies the StatsReader surface alongside Registry.Get. We +// expose both verbs so the dashboard code reads naturally without +// forcing existing call sites that use Get() to migrate. +func (r *Registry) Lookup(name string) Lane { return r.Get(name) } + +// GetOrCreate returns the named lane, creating it lazily on first +// call. Concurrency is read from convar `lanes..max_concurrent` +// (default 1). The policy is the registry's policy factory (default +// fair-share). +// +// Why convar name `lanes..max_concurrent` (not +// `skills.lane..max_concurrent`): pkg/lane is generic — the +// skills system happens to be the first caller, but the LLM transport +// wrapper (Phase 3) and other future runners will use the same +// registry. The convar namespace `lanes.*` keeps lane configuration +// in one place. The skills system can adopt different convar names +// if it prefers; in that case, mort.go reads them and calls +// SetMaxConcurrent on the resulting lanes after creation. 
+func (r *Registry) GetOrCreate(ctx context.Context, name string) Lane { + r.mu.RLock() + if l, ok := r.lanes[name]; ok { + r.mu.RUnlock() + return l + } + r.mu.RUnlock() + + r.mu.Lock() + defer r.mu.Unlock() + // Double-check after upgrading the lock. + if l, ok := r.lanes[name]; ok { + return l + } + maxConcurrent := r.readConcurrency(ctx, name) + policy := r.policyFactory() + if policy == nil { + policy = NewFairSharePolicy() + } + l := NewWithPolicy(name, maxConcurrent, policy) + r.lanes[name] = l + return l +} + +// List returns a snapshot of all registered lanes. Iteration order is +// not guaranteed (Go map randomization). +func (r *Registry) List() []Lane { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]Lane, 0, len(r.lanes)) + for _, l := range r.lanes { + out = append(out, l) + } + return out +} + +// Names returns the registered lane names. Used for the admin +// "list all lanes" surface. +func (r *Registry) Names() []string { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]string, 0, len(r.lanes)) + for n := range r.lanes { + out = append(out, n) + } + return out +} + +// Reload re-reads convars for every registered lane and calls +// SetMaxConcurrent on each. Existing running jobs continue to run; +// new dispatches respect the updated cap. +// +// Why a manual Reload instead of reading convars at every dispatch: +// dispatch is on the hot path; reading a convar there for every +// queued job is wasteful. A periodic Reload (every minute, say) is +// cheap and good enough for human-driven config changes. +func (r *Registry) Reload(ctx context.Context) { + r.mu.RLock() + defer r.mu.RUnlock() + for name, l := range r.lanes { + n := r.readConcurrency(ctx, name) + l.SetMaxConcurrent(n) + } +} + +// readConcurrency reads `lanes..max_concurrent` with a default +// of 1. Defensive against a nil ConvarReader and against negative +// values (clamped to 1). 
+func (r *Registry) readConcurrency(ctx context.Context, name string) int { + if r.convars == nil { + return 1 + } + n := r.convars.Int(ctx, "lanes."+name+".max_concurrent", 1) + if n <= 0 { + return 1 + } + return n +} diff --git a/lane/registry_test.go b/lane/registry_test.go new file mode 100644 index 0000000..26ec90d --- /dev/null +++ b/lane/registry_test.go @@ -0,0 +1,202 @@ +package lane + +import ( + "context" + "sync" + "testing" + "time" +) + +// fakeConvars is a ConvarReader fake backed by a map. +type fakeConvars struct { + mu sync.Mutex + vals map[string]int +} + +func newFakeConvars() *fakeConvars { return &fakeConvars{vals: map[string]int{}} } + +func (f *fakeConvars) set(name string, v int) { + f.mu.Lock() + defer f.mu.Unlock() + f.vals[name] = v +} + +func (f *fakeConvars) Int(_ context.Context, name string, def int) int { + f.mu.Lock() + defer f.mu.Unlock() + if v, ok := f.vals[name]; ok { + return v + } + return def +} + +// TestRegistry_GetOrCreate verifies GetOrCreate creates the lane on +// first call and returns the same instance on subsequent calls. +func TestRegistry_GetOrCreate(t *testing.T) { + r := NewRegistry(nil) + l1 := r.GetOrCreate(context.Background(), "ollama") + l2 := r.GetOrCreate(context.Background(), "ollama") + if l1 != l2 { + t.Fatalf("expected same lane instance on second GetOrCreate") + } + if got := l1.Name(); got != "ollama" { + t.Fatalf("expected name=ollama, got %s", got) + } +} + +// TestRegistry_ConvarConcurrency verifies the convar value drives the +// lane's MaxConcurrent at creation time. +func TestRegistry_ConvarConcurrency(t *testing.T) { + c := newFakeConvars() + c.set("lanes.ollama.max_concurrent", 3) + r := NewRegistry(c) + l := r.GetOrCreate(context.Background(), "ollama") + if got := l.Stats().MaxConcurrent; got != 3 { + t.Fatalf("expected MaxConcurrent=3, got %d", got) + } +} + +// TestRegistry_DefaultConcurrencyOne verifies that a missing convar +// falls back to 1. 
+func TestRegistry_DefaultConcurrencyOne(t *testing.T) { + r := NewRegistry(nil) + l := r.GetOrCreate(context.Background(), "default") + if got := l.Stats().MaxConcurrent; got != 1 { + t.Fatalf("expected default MaxConcurrent=1, got %d", got) + } +} + +// TestRegistry_NegativeConvarClamped verifies that a negative or zero +// convar value is clamped to 1. +func TestRegistry_NegativeConvarClamped(t *testing.T) { + c := newFakeConvars() + c.set("lanes.bad.max_concurrent", -5) + r := NewRegistry(c) + l := r.GetOrCreate(context.Background(), "bad") + if got := l.Stats().MaxConcurrent; got != 1 { + t.Fatalf("expected clamped to 1, got %d", got) + } +} + +// TestRegistry_Reload picks up convar changes for existing lanes. +func TestRegistry_Reload(t *testing.T) { + c := newFakeConvars() + c.set("lanes.x.max_concurrent", 2) + r := NewRegistry(c) + l := r.GetOrCreate(context.Background(), "x") + if got := l.Stats().MaxConcurrent; got != 2 { + t.Fatalf("expected 2 at create, got %d", got) + } + + c.set("lanes.x.max_concurrent", 5) + r.Reload(context.Background()) + + if got := l.Stats().MaxConcurrent; got != 5 { + t.Fatalf("expected 5 after Reload, got %d", got) + } +} + +// TestRegistry_List returns all created lanes. +func TestRegistry_List(t *testing.T) { + r := NewRegistry(nil) + r.GetOrCreate(context.Background(), "a") + r.GetOrCreate(context.Background(), "b") + r.GetOrCreate(context.Background(), "c") + if got := len(r.List()); got != 3 { + t.Fatalf("expected 3 lanes, got %d", got) + } + names := r.Names() + if len(names) != 3 { + t.Fatalf("expected 3 names, got %v", names) + } +} + +// TestRegistry_Get returns nil for missing lane (no implicit create). 
+func TestRegistry_Get(t *testing.T) { + r := NewRegistry(nil) + if got := r.Get("nope"); got != nil { + t.Fatalf("expected nil for missing lane, got %v", got) + } + r.GetOrCreate(context.Background(), "yes") + if got := r.Get("yes"); got == nil { + t.Fatalf("expected non-nil for existing lane") + } +} + +// TestRegistry_PolicyFactoryDefault verifies the default factory +// produces fair-share lanes (round-robins across users). +func TestRegistry_PolicyFactoryDefault(t *testing.T) { + c := newFakeConvars() + c.set("lanes.fair.max_concurrent", 1) + r := NewRegistry(c) + lane := r.GetOrCreate(context.Background(), "fair") + + // Block lane with one job so subsequent submits queue. + blocker := newTestJob("blocker") + blocker.caller = "blocker-user" + if _, _, err := lane.Submit(context.Background(), blocker); err != nil { + t.Fatal(err) + } + <-blocker.started + + startOrder := make(chan string, 3) + mkJob := func(id, caller string) *funcJob { + return &funcJob{ + id: id, caller: caller, + run: func(ctx context.Context) error { + startOrder <- id + return nil + }, + } + } + if _, _, err := lane.Submit(context.Background(), mkJob("a1", "userA")); err != nil { + t.Fatal(err) + } + if _, _, err := lane.Submit(context.Background(), mkJob("a2", "userA")); err != nil { + t.Fatal(err) + } + if _, _, err := lane.Submit(context.Background(), mkJob("b1", "userB")); err != nil { + t.Fatal(err) + } + if got := lane.Stats().Queued; got != 3 { + t.Fatalf("expected queued=3, got %d", got) + } + close(blocker.release) + + var order []string + for i := 0; i < 3; i++ { + select { + case id := <-startOrder: + order = append(order, id) + case <-time.After(time.Second): + t.Fatalf("did not observe all dispatches; got %v", order) + } + } + // b1 must run at position 0 or 1 (after at most one A). 
+ pos := -1 + for i, id := range order { + if id == "b1" { + pos = i + break + } + } + if pos > 1 { + t.Fatalf("b1 ran at position %d among %v; expected 0 or 1", + pos, order) + } +} + +// TestRegistry_SetPolicyFactory verifies tests can override the +// default factory. +func TestRegistry_SetPolicyFactory(t *testing.T) { + r := NewRegistry(nil) + called := false + r.SetPolicyFactory(func() queuePolicy { + called = true + return newFIFOPolicy() + }) + r.GetOrCreate(context.Background(), "x") + if !called { + t.Fatal("custom policy factory was not called") + } +} diff --git a/lane/sampler.go b/lane/sampler.go new file mode 100644 index 0000000..f63acd9 --- /dev/null +++ b/lane/sampler.go @@ -0,0 +1,203 @@ +// Package lane — sampler.go: periodic occupancy sampler (v7). +// +// Why a dedicated sampler goroutine: /skills/admin/queues shows current +// state but operators need a timeline ("ollama lane was saturated for 4 +// hours yesterday afternoon"). Sampling at fixed intervals is the +// simplest way to capture that without instrumenting every Submit/ +// complete path. Sampling is best-effort observability — if the +// goroutine dies, charts show a gap; nothing else breaks. +// +// Why in pkg/lane (vs pkg/logic/skills/lane_sampler.go): the sampler +// reads from the lane registry which lives here. The persistence layer +// (skill_lane_samples table) lives in skills, so the sampler takes a +// narrow LaneSampleSink interface — production wires +// `skills.Storage.RecordLaneSample`; tests substitute a fake. +package lane + +import ( + "context" + "log/slog" + "sync" + "time" +) + +// LaneSampleSink is the persistence surface the sampler writes to. +// Production wires skills.Storage; tests substitute a recording fake. +// +// Why a narrow interface (vs importing skills.Storage): pkg/lane is a +// generic primitive that must NOT import the application's skills +// package — that would create an import cycle. Using a small typed +// interface keeps lane decoupled. 
+type LaneSampleSink interface { + RecordLaneSample(ctx context.Context, lane string, running, queued int, sampledAt time.Time) error +} + +// LaneSampleSinkFunc adapts a closure to LaneSampleSink. Useful in +// production wiring (mort.go) where the underlying storage method has +// a different shape. +type LaneSampleSinkFunc func(ctx context.Context, lane string, running, queued int, sampledAt time.Time) error + +// RecordLaneSample satisfies LaneSampleSink. +func (f LaneSampleSinkFunc) RecordLaneSample(ctx context.Context, lane string, running, queued int, sampledAt time.Time) error { + if f == nil { + return nil + } + return f(ctx, lane, running, queued, sampledAt) +} + +// LaneSamplePurger is the periodic-sweeper surface. Production wires +// skills.Storage.PurgeLaneSamples. +type LaneSamplePurger interface { + PurgeLaneSamples(ctx context.Context, olderThan time.Time) (int64, error) +} + +// LaneSamplePurgerFunc adapts a closure. +type LaneSamplePurgerFunc func(ctx context.Context, olderThan time.Time) (int64, error) + +// PurgeLaneSamples satisfies LaneSamplePurger. +func (f LaneSamplePurgerFunc) PurgeLaneSamples(ctx context.Context, olderThan time.Time) (int64, error) { + if f == nil { + return 0, nil + } + return f(ctx, olderThan) +} + +// Sampler periodically reads stats from every lane in the registry and +// writes one sample row per lane via the configured Sink. Optionally +// runs a daily retention sweep that purges samples older than +// RetentionDays via Purger. +// +// Test: sampler_test.go drives Sample() synchronously with a fake +// clock + recording sink. +type Sampler struct { + registry *Registry + sink LaneSampleSink + purger LaneSamplePurger + + interval time.Duration + retention time.Duration + purgeInterval time.Duration + clock func() time.Time + + // run-time state + mu sync.Mutex + running bool + stopCh chan struct{} + doneCh chan struct{} +} + +// NewSampler constructs the sampler. 
+// +// interval — sample cadence (typically 30s in production). +// retention — purge cutoff (typically 7d). +// clock=nil → time.Now. +func NewSampler(registry *Registry, sink LaneSampleSink, purger LaneSamplePurger, + interval, retention time.Duration, clock func() time.Time) *Sampler { + if interval <= 0 { + interval = 30 * time.Second + } + if retention <= 0 { + retention = 7 * 24 * time.Hour + } + if clock == nil { + clock = time.Now + } + return &Sampler{ + registry: registry, + sink: sink, + purger: purger, + interval: interval, + retention: retention, + purgeInterval: 24 * time.Hour, + clock: clock, + } +} + +// Start launches the sampler goroutine. Cancelling ctx stops it. +// Idempotent — calling Start twice without an intervening Stop is a +// no-op for the second call. +func (s *Sampler) Start(ctx context.Context) { + s.mu.Lock() + if s.running { + s.mu.Unlock() + return + } + s.running = true + s.stopCh = make(chan struct{}) + s.doneCh = make(chan struct{}) + s.mu.Unlock() + + go s.loop(ctx) +} + +// Stop signals the sampler to exit and waits for the goroutine to +// finish. Idempotent. +func (s *Sampler) Stop() { + s.mu.Lock() + if !s.running { + s.mu.Unlock() + return + } + close(s.stopCh) + doneCh := s.doneCh + s.running = false + s.mu.Unlock() + <-doneCh +} + +// Sample runs one sampling pass synchronously. Test entry point — +// production callers use Start. +func (s *Sampler) Sample(ctx context.Context) { + if s.registry == nil || s.sink == nil { + return + } + now := s.clock() + for _, l := range s.registry.List() { + st := l.Stats() + if err := s.sink.RecordLaneSample(ctx, st.Name, st.Running, st.Queued, now); err != nil { + // Best-effort observability — log and continue, never block. + slog.Warn("lane sampler: record failed", "lane", st.Name, "error", err) + } + } +} + +// PurgeOnce runs one retention sweep synchronously. Test entry point. 
+func (s *Sampler) PurgeOnce(ctx context.Context) { + if s.purger == nil { + return + } + cutoff := s.clock().Add(-s.retention) + if _, err := s.purger.PurgeLaneSamples(ctx, cutoff); err != nil { + slog.Warn("lane sampler: purge failed", "error", err) + } +} + +// loop is the sampler's main goroutine. Calls Sample at the interval +// cadence and PurgeOnce daily. Exits on ctx.Done OR Stop. +func (s *Sampler) loop(ctx context.Context) { + defer func() { + s.mu.Lock() + if s.doneCh != nil { + close(s.doneCh) + s.doneCh = nil + } + s.mu.Unlock() + }() + sampleTicker := time.NewTicker(s.interval) + defer sampleTicker.Stop() + purgeTicker := time.NewTicker(s.purgeInterval) + defer purgeTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-s.stopCh: + return + case <-sampleTicker.C: + s.Sample(ctx) + case <-purgeTicker.C: + s.PurgeOnce(ctx) + } + } +} diff --git a/lane/sampler_test.go b/lane/sampler_test.go new file mode 100644 index 0000000..ecc5ed1 --- /dev/null +++ b/lane/sampler_test.go @@ -0,0 +1,113 @@ +package lane + +import ( + "context" + "sync" + "testing" + "time" +) + +// recordingSink captures every RecordLaneSample call. +type recordingSink struct { + mu sync.Mutex + samples []sampleRow +} + +type sampleRow struct { + lane string + running int + queued int + sampledAt time.Time +} + +func (r *recordingSink) RecordLaneSample(_ context.Context, lane string, running, queued int, sampledAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + r.samples = append(r.samples, sampleRow{lane, running, queued, sampledAt}) + return nil +} + +// recordingPurger captures PurgeLaneSamples cutoffs. 
+type recordingPurger struct { + mu sync.Mutex + cutoffs []time.Time +} + +func (r *recordingPurger) PurgeLaneSamples(_ context.Context, olderThan time.Time) (int64, error) { + r.mu.Lock() + defer r.mu.Unlock() + r.cutoffs = append(r.cutoffs, olderThan) + return 0, nil +} + +func TestSampler_SamplesAllLanes(t *testing.T) { + reg := NewRegistry(nil) + reg.GetOrCreate(context.Background(), "ollama") + reg.GetOrCreate(context.Background(), "anthropic-default") + + sink := &recordingSink{} + now := time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC) + clock := func() time.Time { return now } + s := NewSampler(reg, sink, nil, 30*time.Second, 7*24*time.Hour, clock) + + s.Sample(context.Background()) + + sink.mu.Lock() + defer sink.mu.Unlock() + if len(sink.samples) != 2 { + t.Fatalf("expected 2 samples (one per lane), got %d", len(sink.samples)) + } + seen := map[string]bool{} + for _, sm := range sink.samples { + seen[sm.lane] = true + if !sm.sampledAt.Equal(now) { + t.Errorf("expected sampledAt=%v, got %v", now, sm.sampledAt) + } + } + if !seen["ollama"] || !seen["anthropic-default"] { + t.Fatalf("missing lane: %+v", seen) + } +} + +func TestSampler_PurgeOnceUsesRetentionWindow(t *testing.T) { + reg := NewRegistry(nil) + purger := &recordingPurger{} + now := time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC) + clock := func() time.Time { return now } + s := NewSampler(reg, nil, purger, 30*time.Second, 7*24*time.Hour, clock) + + s.PurgeOnce(context.Background()) + + purger.mu.Lock() + defer purger.mu.Unlock() + if len(purger.cutoffs) != 1 { + t.Fatalf("expected 1 purge call, got %d", len(purger.cutoffs)) + } + want := now.Add(-7 * 24 * time.Hour) + if !purger.cutoffs[0].Equal(want) { + t.Fatalf("cutoff: want %v, got %v", want, purger.cutoffs[0]) + } +} + +func TestSampler_NilSinkOrRegistryIsSafe(t *testing.T) { + // nil registry — no-op, no panic. 
+ s := NewSampler(nil, &recordingSink{}, nil, 30*time.Second, 7*24*time.Hour, nil) + s.Sample(context.Background()) + + // nil sink — no-op. + reg := NewRegistry(nil) + reg.GetOrCreate(context.Background(), "ollama") + s2 := NewSampler(reg, nil, nil, 30*time.Second, 7*24*time.Hour, nil) + s2.Sample(context.Background()) +} + +func TestSampler_StartStopIdempotent(t *testing.T) { + reg := NewRegistry(nil) + sink := &recordingSink{} + s := NewSampler(reg, sink, nil, 30*time.Second, 7*24*time.Hour, nil) + ctx := context.Background() + s.Start(ctx) + s.Start(ctx) // second Start is a no-op + s.Stop() + s.Stop() // second Stop is a no-op +} diff --git a/pendingattach/pendingattach.go b/pendingattach/pendingattach.go new file mode 100644 index 0000000..5727446 --- /dev/null +++ b/pendingattach/pendingattach.go @@ -0,0 +1,54 @@ +// Package pendingattach holds the canonical pending-attachment row type +// and its dedupe helper. It imports nothing from the rest of mort so +// every consumer (the skills storage layer, the send_attachments tool, +// and all delivery drainers) can depend on it without import cycles. +// +// CONTRACT: enqueue (Storage.AddPendingAttachment) is NOT idempotent — a +// model that calls send_attachments more than once for the same file +// leaves multiple rows. Every consumer that drains pending attachments +// MUST call Dedupe before delivery, or the artifact double-posts. +package pendingattach + +import "time" + +// Attachment is one deferred-attachment row. Field-for-field the shape +// the skills storage layer persists. +type Attachment struct { + ID string + RunID string + SkillID string + FileID string + Filename string + Mime string + SizeBytes int64 + MessageText string + HostedURL string + Ord int + + // CreatedAt is the enqueue time. Used only by the retention sweeper + // (PurgePendingAttachments) and storage round-trip tests; delivery + // drainers ignore it. Zero on the enqueue path — the storage layer + // defaults it to time.Now(). 
+ CreatedAt time.Time +} + +// Dedupe removes rows whose FileID has already been seen (first +// occurrence wins, preserving input order). Rows with an empty FileID +// are never collapsed. +func Dedupe(rows []Attachment) []Attachment { + if len(rows) < 2 { + return rows + } + seen := make(map[string]struct{}, len(rows)) + out := make([]Attachment, 0, len(rows)) + for _, row := range rows { + if row.FileID != "" { + if _, dup := seen[row.FileID]; dup { + continue + } + seen[row.FileID] = struct{}{} + } + out = append(out, row) + } + return out +} diff --git a/pendingattach/pendingattach_test.go b/pendingattach/pendingattach_test.go new file mode 100644 index 0000000..bfa86ba --- /dev/null +++ b/pendingattach/pendingattach_test.go @@ -0,0 +1,35 @@ +package pendingattach + +import "testing" + +func TestDedupe_CollapsesByFileID(t *testing.T) { + rows := []Attachment{ + {ID: "1", FileID: "a", Ord: 0}, + {ID: "2", FileID: "b", Ord: 1}, + {ID: "3", FileID: "a", Ord: 2}, // dup of file a + } + out := Dedupe(rows) + if len(out) != 2 { + t.Fatalf("want 2 rows, got %d: %+v", len(out), out) + } + if out[0].ID != "1" || out[1].ID != "2" { + t.Fatalf("first-wins order not preserved: %+v", out) + } +} + +func TestDedupe_EmptyFileIDNeverCollapsed(t *testing.T) { + rows := []Attachment{ + {ID: "1", FileID: ""}, + {ID: "2", FileID: ""}, + } + if got := Dedupe(rows); len(got) != 2 { + t.Fatalf("empty FileID rows must not collapse, got %d", len(got)) + } +} + +func TestDedupe_ShortInputPassthrough(t *testing.T) { + rows := []Attachment{{ID: "1", FileID: "a"}} + if got := Dedupe(rows); len(got) != 1 { + t.Fatalf("single row should pass through, got %d", len(got)) + } +} diff --git a/run/progress.go b/run/progress.go new file mode 100644 index 0000000..49b62d0 --- /dev/null +++ b/run/progress.go @@ -0,0 +1,86 @@ +package run + +import "context" + +// ProgressSink reports a one-line progress note for the current run upward to +// any ancestor run that is being watched by a run-critic. 
// It exists to solve a specific false-positive: when an agent calls a
// long-running skill/agent as a single tool, the parent's agent loop is
// BLOCKED on that one tool call for the whole child run, so the parent's
// progress recorder sees "zero iterations, zero new tokens, no activity" and
// its critic concludes the tool "hung indefinitely" — even though the child is
// iterating happily. Forwarding the child's per-step activity up the chain
// keeps every blocked ancestor's last-activity fresh, so a healthy-but-slow
// child is no longer mistaken for a hang. A nil ProgressSink is safe to ignore
// (there is no ancestor to notify).
type ProgressSink func(note string)

// progressSinkKey is the unexported context key under which the sink is
// stored; being a private type, it cannot collide with keys from other
// packages.
type progressSinkKey struct{}

// WithProgressSink returns a context carrying sink for descendant runs to find
// via ProgressSinkFrom. A nil sink is stored as-is (ProgressSinkFrom returns
// nil), which callers treat as "no ancestor watching".
func WithProgressSink(ctx context.Context, sink ProgressSink) context.Context {
	var key progressSinkKey
	return context.WithValue(ctx, key, sink)
}

// ProgressSinkFrom returns the ancestor progress sink carried on ctx, or nil
// if none is wired. The returned sink, when non-nil, forwards a note to the
// immediate parent run's recorder AND (transitively) to every further
// ancestor, because each level installs a sink that forwards upward.
func ProgressSinkFrom(ctx context.Context) ProgressSink {
	// The comma-ok assertion yields a nil sink both when the key is absent
	// and when the stored value is not a ProgressSink.
	sink, _ := ctx.Value(progressSinkKey{}).(ProgressSink)
	return sink
}

// InstallProgressBridge wires the current run into the ancestor progress chain.
//
//	report — this run's own recorder hook (e.g. recorder.OnStatus). nil when
//	         the run has no critic recorder of its own (the common skill case);
//	         the bridge then purely forwards descendants' progress upward.
+// +// It returns: +// +// childCtx — pass this to the agent loop / toolbox so descendant runs +// (invoked as tools) forward their progress into this chain. +// notifyAncestors — call this on each of THIS run's own loop steps to keep +// every ancestor critic's last-activity fresh. nil when this +// run has no ancestors (it is a top-level run); nil-safe to +// call only via the returned value being checked, so callers +// should guard `if notifyAncestors != nil`. +// +// The chain is built so that a note from any descendant bumps the recorders of +// ALL of its blocked ancestors, not just its immediate parent. +func InstallProgressBridge(ctx context.Context, report ProgressSink) (childCtx context.Context, notifyAncestors ProgressSink) { + parent := ProgressSinkFrom(ctx) + + // The sink descendants will call. It must bump this run's own recorder + // (report) AND forward to all ancestors (parent). Collapse to the minimal + // closure so we don't stack a no-op wrapper for recorder-less runs. + var child ProgressSink + switch { + case report == nil: + // No recorder of our own: descendants forward straight to ancestors. + child = parent + case parent == nil: + // Top-level run with a recorder: descendants feed only our recorder. + child = report + default: + child = func(note string) { + report(note) + parent(note) + } + } + + childCtx = ctx + if child != nil { + childCtx = WithProgressSink(ctx, child) + } + // This run's own steps notify ancestors directly (its own recorder is fed + // separately by its step observer, so we deliberately do not call report + // here — only the ancestors need waking). + return childCtx, parent +} diff --git a/run/progress_test.go b/run/progress_test.go new file mode 100644 index 0000000..de37a3d --- /dev/null +++ b/run/progress_test.go @@ -0,0 +1,123 @@ +package run + +import ( + "context" + "testing" +) + +// ProgressSinkFrom on a bare context returns nil (nothing wired). 
+func TestProgressSinkFrom_Empty(t *testing.T) { + if got := ProgressSinkFrom(context.Background()); got != nil { + t.Fatalf("expected nil sink on bare context, got non-nil") + } +} + +// WithProgressSink round-trips a sink through the context. +func TestWithProgressSink_RoundTrip(t *testing.T) { + var got string + ctx := WithProgressSink(context.Background(), func(n string) { got = n }) + sink := ProgressSinkFrom(ctx) + if sink == nil { + t.Fatal("expected non-nil sink") + } + sink("hello") + if got != "hello" { + t.Fatalf("sink did not deliver note; got %q", got) + } +} + +// InstallProgressBridge with no parent and a report only feeds the report, +// and notifyAncestors is nil (there are no ancestors to notify). +func TestInstallProgressBridge_NoParent(t *testing.T) { + var reported []string + childCtx, notify := InstallProgressBridge(context.Background(), func(n string) { + reported = append(reported, n) + }) + if notify != nil { + t.Fatal("expected nil notifyAncestors when there is no parent") + } + // A child installed under childCtx should reach our report. + child := ProgressSinkFrom(childCtx) + if child == nil { + t.Fatal("expected a child sink installed") + } + child("from-child") + if len(reported) != 1 || reported[0] != "from-child" { + t.Fatalf("report did not receive child note; got %v", reported) + } +} + +// InstallProgressBridge with a nil report and an existing parent must pass the +// parent through unchanged (no needless wrapper layer) — this is the skill +// case: a skill run has no recorder of its own but must forward its progress +// to the ancestor agent's critic. +func TestInstallProgressBridge_NilReportPassesParentThrough(t *testing.T) { + var ancestor []string + base := WithProgressSink(context.Background(), func(n string) { ancestor = append(ancestor, n) }) + + childCtx, notify := InstallProgressBridge(base, nil) + // This run's own steps must notify the ancestor. 
+ if notify == nil { + t.Fatal("expected non-nil notifyAncestors when a parent exists") + } + notify("my-step") + + // And a descendant under childCtx must also reach the ancestor. + ProgressSinkFrom(childCtx)("grandchild-step") + + if len(ancestor) != 2 || ancestor[0] != "my-step" || ancestor[1] != "grandchild-step" { + t.Fatalf("ancestor did not receive both notes; got %v", ancestor) + } +} + +// The full three-level chain: grandchild progress must bump BOTH the child's +// own report and the root ancestor — this is the depth>=2 case (agent -> +// sub-agent -> sub-sub-agent) where every blocked ancestor must stay alive. +func TestInstallProgressBridge_ThreeLevelChain(t *testing.T) { + var root, mid []string + + // Level 0 (root agent): has a recorder (report), no parent. + rootCtx, rootNotify := InstallProgressBridge(context.Background(), + func(n string) { root = append(root, n) }) + if rootNotify != nil { + t.Fatal("root should have no ancestors") + } + + // Level 1 (child agent): has its own recorder, parent = root. + midCtx, midNotify := InstallProgressBridge(rootCtx, + func(n string) { mid = append(mid, n) }) + if midNotify == nil { + t.Fatal("mid should notify root") + } + + // Level 1's own step notifies root only (its own recorder is fed by its + // own step observer, not via notifyAncestors). + midNotify("mid-step") + if len(root) != 1 || root[0] != "mid-step" { + t.Fatalf("root missed mid-step; root=%v", root) + } + + // Level 2 (grandchild): no recorder, parent = mid. + gcCtx, gcNotify := InstallProgressBridge(midCtx, nil) + if gcNotify == nil { + t.Fatal("grandchild should notify its ancestors") + } + // Grandchild's own step must bump BOTH mid (its parent's recorder) and + // root (mid forwards upward). 
+ gcNotify("gc-step") + if len(mid) != 1 || mid[0] != "gc-step" { + t.Fatalf("mid missed gc-step; mid=%v", mid) + } + if len(root) != 2 || root[1] != "gc-step" { + t.Fatalf("root missed forwarded gc-step; root=%v", root) + } + + // A descendant installed under gcCtx still reaches mid + root. + ProgressSinkFrom(gcCtx)("ggc-step") + if len(mid) != 2 || mid[1] != "ggc-step" { + t.Fatalf("mid missed ggc-step; mid=%v", mid) + } + if len(root) != 3 || root[2] != "ggc-step" { + t.Fatalf("root missed ggc-step; root=%v", root) + } +}