From 0147a79d187b91461632aefcef0abf4684e19311 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Wed, 10 Jun 2026 13:30:06 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20conversion-driven=20extensions=20?= =?UTF-8?q?=E2=80=94=20resolvers,=20DefineTool,=20hooks,=20ops=20controls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 9a (ADR-0014): Registry.RegisterResolver for dynamic tiers; DefineTool[Args] typed tools; Usage cache/reasoning detail fields wired through anthropic/openai/google; WithPromptCaching (Anthropic cache_control); agent supervision hooks (WithMaxStepsFunc, WithSteer, WithCompactor, WithToolErrorLimits + ErrToolLoop); health Bench/Unbench/Snapshot; ChainConfig.Observer failover events. Co-Authored-By: Claude Fable 5 --- agent/agent.go | 124 ++++++++++++- agent/hooks_test.go | 175 ++++++++++++++++++ chain.go | 9 + docs/adr/0014-conversion-driven-extensions.md | 50 +++++ docs/adr/README.md | 1 + extensions_test.go | 172 +++++++++++++++++ generate.go | 6 + health/health.go | 53 ++++++ llm/request.go | 11 ++ llm/response.go | 17 +- llm/tool.go | 34 ++++ majordomo.go | 1 + parse.go | 13 +- progress.md | 14 ++ provider/anthropic/wire.go | 18 +- provider/google/google.go | 6 +- provider/google/stream.go | 6 +- provider/openai/model.go | 5 +- provider/openai/stream.go | 5 +- provider/openai/wire.go | 29 ++- registry.go | 47 +++++ 21 files changed, 767 insertions(+), 29 deletions(-) create mode 100644 agent/hooks_test.go create mode 100644 docs/adr/0014-conversion-driven-extensions.md create mode 100644 extensions_test.go diff --git a/agent/agent.go b/agent/agent.go index c2806ec..05519ee 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -27,6 +27,11 @@ const DefaultMaxSteps = 10 // carrying the transcript so far. var ErrMaxSteps = errors.New("agent: max steps reached without a final answer") +// ErrToolLoop reports that the loop tripped a tool-error guard +// (consecutive all-error steps or identical repeated calls; see +// WithToolErrorLimits). Run returns it alongside the partial *Result. +var ErrToolLoop = errors.New("agent: tool-error guard tripped") + // Skill is the contract skills satisfy (defined here so agent does not // depend on the skill package; package skill provides implementations). // Instructions are appended to the agent's system prompt; Tools (optional, @@ -67,13 +72,17 @@ type Result struct { // it later. Agents are safe to share across goroutines only after // configuration is complete. type Agent struct { - model llm.Model - system string - toolboxes []*llm.Toolbox - skills []Skill - maxSteps int - reqOpts []llm.Option - observers []func(Step) + model llm.Model + system string + toolboxes []*llm.Toolbox + skills []Skill + maxSteps int + maxStepsFunc func() int + compactor func(ctx context.Context, msgs []llm.Message) ([]llm.Message, error) + maxConsecutiveToolErrors int + maxSameCallRepeats int + reqOpts []llm.Option + observers []func(Step) } // Option configures an Agent at construction. @@ -99,6 +108,34 @@ func WithMaxSteps(n int) Option { return func(a *Agent) { a.maxSteps = n } } +// WithMaxStepsFunc makes the step ceiling dynamic: the function is +// consulted before every step, so a supervisor can extend (or shrink) a +// running agent's budget. It overrides WithMaxSteps while non-nil; a +// non-positive return falls back to the static value. +func WithMaxStepsFunc(fn func() int) Option { + return func(a *Agent) { a.maxStepsFunc = fn } +} + +// WithCompactor installs a context-compaction hook, called with the full +// message slice before every model call; whatever it returns is sent +// instead (e.g. summarize the middle of a long transcript). A compactor +// error is non-fatal: the original messages are used. +func WithCompactor(fn func(ctx context.Context, msgs []llm.Message) ([]llm.Message, error)) Option { + return func(a *Agent) { a.compactor = fn } +} + +// WithToolErrorLimits installs loop guards: maxConsecutiveErrors bounds +// successive steps whose tool results were ALL errors, and +// maxSameCallRepeats bounds identical (name + arguments) tool calls within +// one run. Either guard tripping ends the run with ErrToolLoop and the +// partial result. Zero disables a guard. +func WithToolErrorLimits(maxConsecutiveErrors, maxSameCallRepeats int) Option { + return func(a *Agent) { + a.maxConsecutiveToolErrors = maxConsecutiveErrors + a.maxSameCallRepeats = maxSameCallRepeats + } +} + // WithRequestOptions sets default request options (temperature, max // tokens, ...) applied to every step of every run. func WithRequestOptions(opts ...llm.Option) Option { @@ -134,6 +171,7 @@ type runConfig struct { history []llm.Message reqOpts []llm.Option onStep []func(Step) + steer func() []llm.Message } // WithHistory seeds the run with prior conversation messages (e.g. a @@ -153,6 +191,15 @@ func OnStep(fn func(Step)) RunOption { return func(rc *runConfig) { rc.onStep = append(rc.onStep, fn) } } +// WithSteer installs a steering source for this run: the function is +// drained before every step and any returned messages are appended to the +// conversation — the mechanism for a supervisor nudging a running agent +// ("wrap up", "focus on X"). It is called from Run's goroutine; the +// function owns its own synchronization. +func WithSteer(fn func() []llm.Message) RunOption { + return func(rc *runConfig) { rc.steer = fn } +} + // systemPrompt composes the agent's system prompt with each skill's // instructions, in attachment order. func (a *Agent) systemPrompt() string { @@ -227,8 +274,34 @@ func (a *Agent) Run(ctx context.Context, input string, opts ...RunOption) (*Resu reqOpts := append(append([]llm.Option(nil), a.reqOpts...), rc.reqOpts...) system := a.systemPrompt() - for stepIdx := range a.maxSteps { - req := llm.Request{System: system, Messages: msgs, Tools: ordered} + // Loop-guard state (WithToolErrorLimits). + consecutiveErrorSteps := 0 + callCounts := make(map[string]int) + + maxSteps := func() int { + if a.maxStepsFunc != nil { + if n := a.maxStepsFunc(); n > 0 { + return n + } + } + return a.maxSteps + } + + for stepIdx := 0; stepIdx < maxSteps(); stepIdx++ { + // Steering: drain supervisor nudges into the conversation. + if rc.steer != nil { + msgs = append(msgs, rc.steer()...) + } + + sendMsgs := msgs + if a.compactor != nil { + // Compaction failures are non-fatal: send the original. + if compacted, err := a.compactor(ctx, msgs); err == nil && compacted != nil { + sendMsgs = compacted + } + } + + req := llm.Request{System: system, Messages: sendMsgs, Tools: ordered} resp, err := a.model.Generate(ctx, req, reqOpts...) if err != nil { result.Messages = msgs @@ -249,11 +322,19 @@ func (a *Agent) Run(ctx context.Context, input string, opts ...RunOption) (*Resu } results := make([]llm.ToolResult, 0, len(resp.ToolCalls)) + repeatTripped := "" for _, call := range resp.ToolCalls { if err := ctx.Err(); err != nil { result.Messages = msgs return result, err } + if a.maxSameCallRepeats > 0 { + sig := call.Name + "\x00" + string(call.Arguments) + callCounts[sig]++ + if callCounts[sig] > a.maxSameCallRepeats { + repeatTripped = call.Name + } + } tool, ok := byName[call.Name] if !ok { results = append(results, llm.ToolResult{ @@ -272,10 +353,33 @@ func (a *Agent) Run(ctx context.Context, input string, opts ...RunOption) (*Resu result.Steps = append(result.Steps, step) a.notify(rc, step) msgs = append(msgs, llm.ToolResultsMessage(results...)) + + if repeatTripped != "" { + result.Messages = msgs + return result, fmt.Errorf("%w: %q called identically more than %d times", + ErrToolLoop, repeatTripped, a.maxSameCallRepeats) + } + allErrors := len(results) > 0 + for _, r := range results { + if !r.IsError { + allErrors = false + break + } + } + if allErrors { + consecutiveErrorSteps++ + if a.maxConsecutiveToolErrors > 0 && consecutiveErrorSteps >= a.maxConsecutiveToolErrors { + result.Messages = msgs + return result, fmt.Errorf("%w: %d consecutive steps with only failing tool calls", + ErrToolLoop, consecutiveErrorSteps) + } + } else { + consecutiveErrorSteps = 0 + } } result.Messages = msgs - return result, fmt.Errorf("%w (max %d)", ErrMaxSteps, a.maxSteps) + return result, fmt.Errorf("%w (max %d)", ErrMaxSteps, maxSteps()) } // notify fans a step out to agent observers and run callbacks; observer diff --git a/agent/hooks_test.go b/agent/hooks_test.go new file mode 100644 index 0000000..51e4a60 --- /dev/null +++ b/agent/hooks_test.go @@ -0,0 +1,175 @@ +package agent + +import ( + "context" + "encoding/json" + "errors" + "strings" + "sync/atomic" + "testing" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" + "gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake" +) + +// TestMaxStepsFuncExtendsBudget: a supervisor raising the ceiling mid-run +// lets the loop continue past the static budget. +func TestMaxStepsFuncExtendsBudget(t *testing.T) { + fp := fake.New("fp") + fp.Enqueue("test-model", + toolCallReply("c1", "add", `{"a":1,"b":1}`), + toolCallReply("c2", "add", `{"a":2,"b":2}`), + toolCallReply("c3", "add", `{"a":3,"b":3}`), + fake.Reply("done"), + ) + + var ceiling atomic.Int64 + ceiling.Store(2) + a := New(newModel(t, fp), "", + WithToolbox(adderToolbox(t)), + WithMaxSteps(2), + WithMaxStepsFunc(func() int { return int(ceiling.Load()) }), + WithStepObserver(func(s Step) { + if s.Index == 1 { + ceiling.Store(10) // the "critic" extends the budget + } + }), + ) + res, err := a.Run(context.Background(), "go") + if err != nil { + t.Fatalf("Run: %v (budget should have been extended)", err) + } + if res.Output != "done" || len(res.Steps) != 4 { + t.Errorf("output=%q steps=%d", res.Output, len(res.Steps)) + } +} + +// TestSteerInjectsMessages: steering messages appear in the conversation +// before the next model call. +func TestSteerInjectsMessages(t *testing.T) { + fp := fake.New("fp") + fp.Enqueue("test-model", + toolCallReply("c1", "add", `{"a":1,"b":1}`), + fake.Reply("ok"), + ) + + var pending []llm.Message + pending = append(pending, llm.UserText("SUPERVISOR: wrap it up")) + a := New(newModel(t, fp), "", WithToolbox(adderToolbox(t))) + _, err := a.Run(context.Background(), "go", WithSteer(func() []llm.Message { + out := pending + pending = nil + return out + })) + if err != nil { + t.Fatalf("Run: %v", err) + } + first := fp.Calls()[0].Request.Messages + if len(first) != 2 || !strings.Contains(first[1].Text(), "SUPERVISOR") { + t.Errorf("first call messages = %+v, want steered message", first) + } + // Drained: second call must not duplicate it. + second := fp.Calls()[1].Request.Messages + count := 0 + for _, m := range second { + if strings.Contains(m.Text(), "SUPERVISOR") { + count++ + } + } + if count != 1 { + t.Errorf("steer message appears %d times in second call, want 1", count) + } +} + +// TestCompactorShrinksOutboundContext: the model sees the compacted view; +// the canonical transcript keeps everything. +func TestCompactorShrinksOutboundContext(t *testing.T) { + fp := fake.New("fp") + fp.Enqueue("test-model", fake.Reply("answer")) + + history := []llm.Message{ + llm.UserText("old 1"), llm.AssistantText("old reply 1"), + llm.UserText("old 2"), llm.AssistantText("old reply 2"), + } + a := New(newModel(t, fp), "", WithCompactor(func(_ context.Context, msgs []llm.Message) ([]llm.Message, error) { + // Keep only the last message, prefixed by a synthetic summary. + return append([]llm.Message{llm.UserText("[summary of earlier conversation]")}, msgs[len(msgs)-1]), nil + })) + res, err := a.Run(context.Background(), "new question", WithHistory(history)) + if err != nil { + t.Fatalf("Run: %v", err) + } + sent := fp.Calls()[0].Request.Messages + if len(sent) != 2 || !strings.Contains(sent[0].Text(), "summary") { + t.Errorf("sent = %+v, want compacted view", sent) + } + if len(res.Messages) != 6 { + t.Errorf("transcript = %d messages, want full uncompacted history", len(res.Messages)) + } +} + +// TestCompactorErrorIsNonFatal: a failing compactor falls back to the +// original messages. +func TestCompactorErrorIsNonFatal(t *testing.T) { + fp := fake.New("fp") + fp.Enqueue("test-model", fake.Reply("fine")) + + a := New(newModel(t, fp), "", WithCompactor(func(context.Context, []llm.Message) ([]llm.Message, error) { + return nil, errors.New("summarizer down") + })) + res, err := a.Run(context.Background(), "go") + if err != nil || res.Output != "fine" { + t.Errorf("res=%v err=%v", res, err) + } + if len(fp.Calls()[0].Request.Messages) != 1 { + t.Error("original messages must be sent when compaction fails") + } +} + +// TestConsecutiveToolErrorGuard: steps whose tools ALL fail trip the guard. +func TestConsecutiveToolErrorGuard(t *testing.T) { + fp := fake.New("fp", fake.WithDefault(func(string, llm.Request) fake.Step { + return toolCallReply("c", "bomb", `{}`) + })) + bomb := llm.NewToolbox("danger", llm.Tool{ + Name: "bomb", + Handler: func(context.Context, json.RawMessage) (any, error) { return nil, errors.New("always fails") }, + }) + + a := New(newModel(t, fp), "", WithToolbox(bomb), WithToolErrorLimits(2, 0), WithMaxSteps(10)) + res, err := a.Run(context.Background(), "go") + if !errors.Is(err, ErrToolLoop) { + t.Fatalf("err = %v, want ErrToolLoop", err) + } + if len(res.Steps) != 2 { + t.Errorf("steps = %d, want guard to trip after 2", len(res.Steps)) + } +} + +// TestSameCallRepeatGuard: identical (name+args) calls beyond the limit +// trip the guard; varied calls do not. +func TestSameCallRepeatGuard(t *testing.T) { + fp := fake.New("fp", fake.WithDefault(func(string, llm.Request) fake.Step { + return toolCallReply("c", "add", `{"a":1,"b":1}`) + })) + + a := New(newModel(t, fp), "", WithToolbox(adderToolbox(t)), WithToolErrorLimits(0, 3), WithMaxSteps(10)) + _, err := a.Run(context.Background(), "go") + if !errors.Is(err, ErrToolLoop) || !strings.Contains(err.Error(), `"add"`) { + t.Fatalf("err = %v, want repeat-guard ErrToolLoop naming add", err) + } + + // Varied arguments never trip it. + n := 0 + fp2 := fake.New("fp", fake.WithDefault(func(string, llm.Request) fake.Step { + n++ + if n > 4 { + return fake.Reply("done") + } + return toolCallReply("c", "add", `{"a":1,"b":`+string(rune('0'+n))+`}`) + })) + a2 := New(newModel(t, fp2), "", WithToolbox(adderToolbox(t)), WithToolErrorLimits(0, 3), WithMaxSteps(10)) + if _, err := a2.Run(context.Background(), "go"); err != nil { + t.Errorf("varied calls must not trip the guard: %v", err) + } +} diff --git a/chain.go b/chain.go index 26928de..a86d87f 100644 --- a/chain.go +++ b/chain.go @@ -91,10 +91,17 @@ func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func var zero T var failures []error + observe := func(ev FailoverEvent) { + if c.cfg.Observer != nil { + c.cfg.Observer(ev) + } + } + for _, t := range c.targets { if !c.tracker.Available(t.key) { until := c.tracker.BackedOffUntil(t.key) failures = append(failures, fmt.Errorf("%s: skipped (backed off until %s)", t.key, until.Format("15:04:05.000"))) + observe(FailoverEvent{Target: t.key, Skipped: true}) continue } @@ -119,6 +126,7 @@ func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func class := c.cfg.classify(err) if class == llm.ClassPermanent { + observe(FailoverEvent{Target: t.key, Err: err, Class: class, Attempt: attemptN}) if errors.Is(err, llm.ErrModelNotFound) || errors.Is(err, llm.ErrUnsupported) || c.cfg.AdvanceOnPermanent { // Not a health problem (or policy says keep going): // advance without penalizing the target. @@ -134,6 +142,7 @@ func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func // attempts remain — but advance as soon as the tracker benches // it (a freshly backed-off target is not worth more retries). benched := c.tracker.ReportFailure(t.key) + observe(FailoverEvent{Target: t.key, Err: err, Class: class, Attempt: attemptN, Benched: benched}) if !benched && attemptN < retries { continue } diff --git a/docs/adr/0014-conversion-driven-extensions.md b/docs/adr/0014-conversion-driven-extensions.md new file mode 100644 index 0000000..46b8f14 --- /dev/null +++ b/docs/adr/0014-conversion-driven-extensions.md @@ -0,0 +1,50 @@ +# ADR-0014: Conversion-driven extensions (resolvers, typed tools, hooks, ops controls) + +**Status:** Accepted — 2026-06-10 + +## Context + +Executing the mort conversion (docs/mort-migration.md) surfaced seven +capabilities mort's orchestration layer needs from its model substrate. +Each is general-purpose — none encodes anything mort-specific — and each +was promised in the migration blueprint. + +## Decision + +1. **`Registry.RegisterResolver`** — dynamic alias resolution for DB/ + config-backed tiers. Checked after static aliases (static wins), in + registration order, without holding the registry lock (resolvers do + I/O); output expands recursively under the existing cycle guard. +2. **`DefineTool[Args]`** — typed tools: parameters from `SchemaFor[Args]`, + arguments unmarshaled before the handler. Schema failure panics + (startup-time programming error, mirroring NewToolbox). +3. **`Usage` detail fields** — CacheReadTokens / CacheWriteTokens / + ReasoningTokens, populated where providers report them (Anthropic cache + fields; OpenAI prompt/completion detail objects; Google cached-content + + thoughts). Input/Output remain totals; details are portions. +4. **`WithPromptCaching`** — Anthropic top-level `cache_control` + auto-placement; a no-op for providers that cache implicitly or not at + all. +5. **Agent loop hooks** — `WithMaxStepsFunc` (dynamic ceiling, consulted + every step, for supervisor budget adjustment), `WithSteer` (per-run + message injection drained before each step), `WithCompactor` + (pre-Generate transcript transformation; errors fall back to the + original — losing a request to a broken summarizer is worse than a long + prompt), `WithToolErrorLimits` (consecutive all-error steps and + identical-call repeats end the run with `ErrToolLoop` + partial result). + The canonical transcript in `Result.Messages` is always the + uncompacted truth; compaction affects only what is sent. +6. **Manual health controls** — `Tracker.Bench/Unbench/Snapshot` for + ops surfaces (".failover bench" commands, dashboards). +7. **`ChainConfig.Observer`** — one synchronous callback per failover + decision (failed attempt with classification, bench, benched-skip). + This is a hook, not an observability stack; persistence/metrics remain + the consumer's business (anti-creep guardrail intact). + +## Consequences + +- mort's tier resolver, `.failover` admin surface, failover-event log, + run-critic budget control, steering, and compaction all rebase onto + public API. +- The agent loop gains supervision points without changing its + never-panic, partial-result-on-error contract. diff --git a/docs/adr/README.md b/docs/adr/README.md index 9816e9c..5846dca 100644 --- a/docs/adr/README.md +++ b/docs/adr/README.md @@ -17,3 +17,4 @@ One decision per file, append-only; supersede rather than rewrite. | [0011](0011-google-provider.md) | Google provider on the official Gen AI SDK | Accepted | | [0012](0012-agent-loop.md) | Agent run loop | Accepted | | [0013](0013-skill-model.md) | Skill model — additive instruction+tool bundles | Accepted | +| [0014](0014-conversion-driven-extensions.md) | Conversion-driven extensions (resolvers, typed tools, hooks, ops controls) | Accepted | diff --git a/extensions_test.go b/extensions_test.go new file mode 100644 index 0000000..67ac411 --- /dev/null +++ b/extensions_test.go @@ -0,0 +1,172 @@ +package majordomo + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + "time" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" + "gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake" +) + +// TestRegisterResolver: dynamic tiers resolve after static aliases, expand +// recursively, and respect cycle detection. +func TestRegisterResolver(t *testing.T) { + r := newTestRegistry(t) + r.RegisterProvider(fake.New("fp")) + r.RegisterAlias("static", "fp/a") + + tiers := map[string]string{ + "db-tier": "fp/x,fp/y", + "db-nested": "db-tier,fp/z", + "db-cycle": "db-cycle", + } + r.RegisterResolver(ResolverFunc(func(name string) (string, bool) { + spec, ok := tiers[name] + return spec, ok + })) + + m, err := r.Parse("static,db-nested") + if err != nil { + t.Fatalf("Parse: %v", err) + } + want := []string{"fp/a", "fp/x", "fp/y", "fp/z"} + if got := targetsOf(t, m); strings.Join(got, ",") != strings.Join(want, ",") { + t.Errorf("targets = %v, want %v", got, want) + } + + if _, err := r.Parse("db-cycle"); !errors.Is(err, ErrAliasCycle) { + t.Errorf("cycle error = %v, want ErrAliasCycle", err) + } + + // Static aliases shadow resolvers. + tiers["static"] = "fp/wrong" + m, _ = r.Parse("static") + if got := targetsOf(t, m); got[0] != "fp/a" { + t.Errorf("static alias must win over resolver, got %v", got) + } +} + +// TestDefineTool: schema from Args, decoded handler arguments. +func TestDefineTool(t *testing.T) { + type addArgs struct { + A int `json:"a" description:"first addend"` + B int `json:"b"` + } + tool := DefineTool("add", "Add two ints", func(_ context.Context, args addArgs) (any, error) { + return args.A + args.B, nil + }) + + var schema map[string]any + if err := json.Unmarshal(tool.Parameters, &schema); err != nil { + t.Fatalf("schema: %v", err) + } + props := schema["properties"].(map[string]any) + if props["a"].(map[string]any)["description"] != "first addend" { + t.Errorf("schema = %v", schema) + } + + res := llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "1", Name: "add", Arguments: json.RawMessage(`{"a":2,"b":40}`)}) + if res.IsError || res.Content != "42" { + t.Errorf("result = %+v", res) + } + res = llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "2", Name: "add", Arguments: json.RawMessage(`{"a":"nope"}`)}) + if !res.IsError || !strings.Contains(res.Content, "invalid arguments") { + t.Errorf("bad-args result = %+v", res) + } +} + +// TestChainObserver: failover decisions emit events (attempt, bench, skip). +func TestChainObserver(t *testing.T) { + var events []FailoverEvent + r := newTestRegistry(t, WithChainConfig(ChainConfig{ + Observer: func(ev FailoverEvent) { events = append(events, ev) }, + })) + fp := fake.New("fp") + r.RegisterProvider(fp) + fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a"))) + fp.Enqueue("b", fake.Reply("ok"), fake.Reply("ok")) + + m, _ := r.Parse("fp/a,fp/b") + if _, err := generate(t, m); err != nil { + t.Fatalf("Generate: %v", err) + } + if len(events) != 2 { + t.Fatalf("events = %+v, want 2 failed attempts", events) + } + if events[0].Target != "fp/a" || events[0].Attempt != 0 || events[0].Benched { + t.Errorf("event 0 = %+v", events[0]) + } + if !events[1].Benched { + t.Errorf("event 1 = %+v, want Benched", events[1]) + } + + // Second request: skipped-while-benched event. + events = nil + if _, err := generate(t, m); err != nil { + t.Fatalf("Generate #2: %v", err) + } + if len(events) != 1 || !events[0].Skipped || events[0].Target != "fp/a" { + t.Errorf("events = %+v, want one skip event", events) + } +} + +// TestManualBenchControls: ops surfaces can bench/unbench/inspect. +func TestManualBenchControls(t *testing.T) { + clock := newFakeClock() + r := newTestRegistry(t, WithClock(clock.Now)) + fp := fake.New("fp") + r.RegisterProvider(fp) + fp.Enqueue("b", fake.Reply("from-b")) + + r.Health().Bench("fp/a", clock.Now().Add(time.Hour)) + if r.Health().Available("fp/a") { + t.Fatal("manual bench must take effect") + } + snap := r.Health().Snapshot() + if len(snap) != 1 || snap[0].Key != "fp/a" { + t.Errorf("snapshot = %+v", snap) + } + + m, _ := r.Parse("fp/a,fp/b") + resp, err := generate(t, m) + if err != nil || resp.Text() != "from-b" { + t.Fatalf("resp=%v err=%v (benched head must be skipped)", resp, err) + } + + r.Health().Unbench("fp/a") + if !r.Health().Available("fp/a") { + t.Error("unbench must re-admit") + } + if len(r.Health().Snapshot()) != 0 { + t.Error("snapshot must be empty after unbench") + } +} + +// TestPromptCachingOptionIsCarried: the request flag round-trips (the +// anthropic wire mapping is asserted in its own package). +func TestPromptCachingOptionIsCarried(t *testing.T) { + r := newTestRegistry(t) + fp := fake.New("fp") + r.RegisterProvider(fp) + + m, _ := r.Parse("fp/x") + if _, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}, WithPromptCaching()); err != nil { + t.Fatalf("Generate: %v", err) + } + if !fp.Calls()[0].Request.PromptCache { + t.Error("PromptCache flag must reach the provider") + } +} + +// TestUsageDetailAccumulation: Usage.Add sums the detail fields. +func TestUsageDetailAccumulation(t *testing.T) { + u := Usage{InputTokens: 10, OutputTokens: 5, CacheReadTokens: 4, CacheWriteTokens: 2, ReasoningTokens: 3} + u.Add(Usage{InputTokens: 1, OutputTokens: 1, CacheReadTokens: 1, CacheWriteTokens: 1, ReasoningTokens: 1}) + if u.CacheReadTokens != 5 || u.CacheWriteTokens != 3 || u.ReasoningTokens != 4 { + t.Errorf("usage = %+v", u) + } +} diff --git a/generate.go b/generate.go index b3f39fd..d8f3c38 100644 --- a/generate.go +++ b/generate.go @@ -14,6 +14,12 @@ import ( // suitable for WithSchema and tool parameters. func SchemaFor[T any]() (json.RawMessage, error) { return llm.SchemaFor[T]() } +// DefineTool re-exports llm.DefineTool: a typed tool whose parameter schema +// is derived from Args and whose handler receives decoded arguments. +func DefineTool[Args any](name, description string, fn func(ctx context.Context, args Args) (any, error)) Tool { + return llm.DefineTool(name, description, fn) +} + // Generate performs a structured-output request and unmarshals the result // into T: the schema is derived from T (llm.SchemaFor), injected via the // provider's native structured-output mechanism, and the response text is diff --git a/health/health.go b/health/health.go index e31146c..f33fb90 100644 --- a/health/health.go +++ b/health/health.go @@ -15,6 +15,8 @@ package health import ( + "slices" + "strings" "sync" "time" ) @@ -133,6 +135,57 @@ func (t *Tracker) ReportFailure(key string) (backedOff bool) { return true } +// Bench manually benches a key until the given time, regardless of its +// failure history (ops surfaces: ".failover bench"). A zero time unbenches. +func (t *Tracker) Bench(key string, until time.Time) { + t.mu.Lock() + defer t.mu.Unlock() + e, ok := t.entries[key] + if !ok { + e = &entry{} + t.entries[key] = e + } + e.until = until +} + +// Unbench clears a key entirely: bench window, failure count, and backoff +// history. +func (t *Tracker) Unbench(key string) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.entries, key) +} + +// Status describes one tracked key (diagnostics). +type Status struct { + Key string + // Until is the end of the current bench window (zero = not benched). + Until time.Time + // ConsecutiveFailures since the last success or bench trigger. + ConsecutiveFailures int + // Benches is the consecutive backoff count driving the exponent. + Benches int +} + +// Snapshot returns the status of every currently-benched key, sorted by +// key, evaluated at the tracker's clock. +func (t *Tracker) Snapshot() []Status { + t.mu.Lock() + defer t.mu.Unlock() + now := t.cfg.Clock() + out := make([]Status, 0, len(t.entries)) + for key, e := range t.entries { + if now.Before(e.until) { + out = append(out, Status{ + Key: key, Until: e.until, + ConsecutiveFailures: e.consecutiveFailures, Benches: e.backoffs, + }) + } + } + slices.SortFunc(out, func(a, b Status) int { return strings.Compare(a.Key, b.Key) }) + return out +} + // BackedOffUntil returns the end of the key's current backoff window, or the // zero time when the key is not backed off. Useful for diagnostics and error // messages. diff --git a/llm/request.go b/llm/request.go index ce07c72..a085596 100644 --- a/llm/request.go +++ b/llm/request.go @@ -43,6 +43,11 @@ type Request struct { // Providers map it to their native knob (OpenAI reasoning_effort, // Ollama think levels) and ignore it where no mapping exists. ReasoningEffort string + + // PromptCache opts the request into the provider's prompt caching + // (Anthropic cache_control; ignored by providers that cache + // automatically or not at all). + PromptCache bool } // Option mutates a Request before it is sent. Options passed to Generate or @@ -100,6 +105,12 @@ func WithReasoningEffort(level string) Option { return func(r *Request) { r.ReasoningEffort = level } } +// WithPromptCaching opts into provider prompt caching where it is an +// explicit feature (Anthropic); a no-op elsewhere. +func WithPromptCaching() Option { + return func(r *Request) { r.PromptCache = true } +} + // Apply returns a copy of the request with all options applied. Providers // and wrappers call this once at the top of Generate/Stream. func (r Request) Apply(opts ...Option) Request { diff --git a/llm/response.go b/llm/response.go index 79a871d..068b9e0 100644 --- a/llm/response.go +++ b/llm/response.go @@ -18,10 +18,22 @@ const ( FinishOther FinishReason = "other" ) -// Usage reports token accounting for one request. +// Usage reports token accounting for one request. InputTokens and +// OutputTokens are always totals; the detail fields break out portions of +// those totals where the provider reports them (0 = not reported). type Usage struct { InputTokens int OutputTokens int + + // CacheReadTokens is the portion of InputTokens served from the + // provider's prompt cache. + CacheReadTokens int + // CacheWriteTokens is the portion of InputTokens written to the + // provider's prompt cache. + CacheWriteTokens int + // ReasoningTokens is the portion of OutputTokens spent on + // thinking/reasoning. + ReasoningTokens int } // Total returns input plus output tokens. @@ -31,6 +43,9 @@ func (u Usage) Total() int { return u.InputTokens + u.OutputTokens } func (u *Usage) Add(o Usage) { u.InputTokens += o.InputTokens u.OutputTokens += o.OutputTokens + u.CacheReadTokens += o.CacheReadTokens + u.CacheWriteTokens += o.CacheWriteTokens + u.ReasoningTokens += o.ReasoningTokens } // Response is the canonical generation result. diff --git a/llm/tool.go b/llm/tool.go index ecb1e4c..68414d9 100644 --- a/llm/tool.go +++ b/llm/tool.go @@ -22,6 +22,40 @@ type Tool struct { Handler func(ctx context.Context, args json.RawMessage) (any, error) } +// DefineTool builds a typed tool: the parameter schema is derived from +// Args (see SchemaFor) and the raw JSON arguments are unmarshaled into an +// Args value before the handler runs. +// +// weather := llm.DefineTool("get_weather", "Current weather for a city", +// func(ctx context.Context, args struct { +// City string `json:"city" description:"city name"` +// }) (any, error) { +// return lookup(args.City) +// }) +// +// Schema derivation failures panic: tools are defined at startup and an +// unschematizable Args type is a programming error worth failing loudly on. +func DefineTool[Args any](name, description string, fn func(ctx context.Context, args Args) (any, error)) Tool { + schema, err := SchemaFor[Args]() + if err != nil { + panic(fmt.Sprintf("llm: DefineTool(%q): %v", name, err)) + } + return Tool{ + Name: name, + Description: description, + Parameters: schema, + Handler: func(ctx context.Context, raw json.RawMessage) (any, error) { + var args Args + if len(raw) > 0 { + if err := json.Unmarshal(raw, &args); err != nil { + return nil, fmt.Errorf("invalid arguments for %s: %w", name, err) + } + } + return fn(ctx, args) + }, + } +} + // ToolCall is a model's request to invoke a tool. type ToolCall struct { // ID is the provider-assigned call id; majordomo synthesizes one for diff --git a/majordomo.go b/majordomo.go index d8c4f4d..d5f6639 100644 --- a/majordomo.go +++ b/majordomo.go @@ -100,6 +100,7 @@ func WithTopP(p float64) Option { return llm.WithTop func WithMaxTokens(n int) Option { return llm.WithMaxTokens(n) } func WithStopSequences(stops ...string) Option { return llm.WithStopSequences(stops...) } func WithReasoningEffort(level string) Option { return llm.WithReasoningEffort(level) } +func WithPromptCaching() Option { return llm.WithPromptCaching() } // WithModelCapabilities re-exports llm.WithCapabilities for Provider.Model // calls made through this package. diff --git a/parse.go b/parse.go index 50467dc..f54c1b9 100644 --- a/parse.go +++ b/parse.go @@ -97,12 +97,23 @@ func (r *Registry) expand(spec string, visiting []string) ([]element, error) { continue } - // Bare token: must be a registered alias. + // Bare token: a registered alias, or a dynamic resolver hit. r.mu.RLock() target, isAlias := r.aliases[raw] _, isProvider := r.providers[raw] + resolvers := slices.Clone(r.resolvers) r.mu.RUnlock() + if !isAlias { + // Resolvers run without the lock — they may call back into the + // registry (and DB-backed ones block on I/O). + for _, res := range resolvers { + if spec, ok := res.Resolve(raw); ok && spec != "" { + target, isAlias = spec, true + break + } + } + } if !isAlias { if isProvider { return nil, fmt.Errorf("%q is a provider, not an alias — use %q", raw, raw+"/") diff --git a/progress.md b/progress.md index 78f327a..da070d3 100644 --- a/progress.md +++ b/progress.md @@ -1,5 +1,19 @@ # progress +## 2026-06-10 — Phase 9a: conversion-driven library extensions + +**Landed (ADR-0014):** RegisterResolver (dynamic DB-backed tiers, static +aliases win, recursive + cycle-guarded), DefineTool[Args] (typed tools +over SchemaFor), Usage cache/reasoning detail fields populated by +anthropic/openai/google, WithPromptCaching (Anthropic top-level +cache_control), agent hooks (WithMaxStepsFunc, WithSteer, WithCompactor — +non-fatal on error, canonical transcript stays uncompacted — +WithToolErrorLimits with ErrToolLoop), health Bench/Unbench/Snapshot, +ChainConfig.Observer failover events (attempt/bench/skip). Full hermetic +coverage for each. + +**Next:** Phase 9b — the mort conversion branch. + ## 2026-06-10 — Phase 8: live validation against real Ollama Cloud **All six checks PASS** (examples/live harness, OLLAMA_API_KEY from .env): diff --git a/provider/anthropic/wire.go b/provider/anthropic/wire.go index d80b5da..dba7e14 100644 --- a/provider/anthropic/wire.go +++ b/provider/anthropic/wire.go @@ -24,6 +24,13 @@ type wireRequest struct { TopP *float64 `json:"top_p,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` OutputConfig *wireOutputConfig `json:"output_config,omitempty"` + // CacheControl is the top-level auto-placement form of prompt caching: + // the API puts the breakpoint on the last cacheable block. + CacheControl *wireCacheControl `json:"cache_control,omitempty"` +} + +type wireCacheControl struct { + Type string `json:"type"` } type wireMessage struct { @@ -109,8 +116,10 @@ type wireUsage struct { // real total input is input + cache_creation + cache_read. func (u wireUsage) toUsage() llm.Usage { return llm.Usage{ - InputTokens: u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens, - OutputTokens: u.OutputTokens, + InputTokens: u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens, + OutputTokens: u.OutputTokens, + CacheReadTokens: u.CacheReadInputTokens, + CacheWriteTokens: u.CacheCreationInputTokens, } } @@ -157,6 +166,11 @@ func buildWireRequest(modelID string, req llm.Request, defaultMax int, stream bo Schema: req.Schema, }} } + if req.PromptCache { + // Top-level auto-placement: the API puts the cache breakpoint on + // the last cacheable block. + wr.CacheControl = &wireCacheControl{Type: "ephemeral"} + } return wr } diff --git a/provider/google/google.go b/provider/google/google.go index 2689fec..f19a848 100644 --- a/provider/google/google.go +++ b/provider/google/google.go @@ -364,8 +364,10 @@ func (m *model) toResponse(resp *genai.GenerateContentResponse) *llm.Response { out := &llm.Response{Model: m.qualified(), Raw: resp} if resp.UsageMetadata != nil { out.Usage = llm.Usage{ - InputTokens: int(resp.UsageMetadata.PromptTokenCount), - OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount + resp.UsageMetadata.ThoughtsTokenCount), + InputTokens: int(resp.UsageMetadata.PromptTokenCount), + OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount + resp.UsageMetadata.ThoughtsTokenCount), + CacheReadTokens: int(resp.UsageMetadata.CachedContentTokenCount), + ReasoningTokens: int(resp.UsageMetadata.ThoughtsTokenCount), } } if len(resp.Candidates) == 0 { diff --git a/provider/google/stream.go b/provider/google/stream.go index fe0c0a4..a1b6441 100644 --- a/provider/google/stream.go +++ b/provider/google/stream.go @@ -78,8 +78,10 @@ func (s *stream) Next() (llm.StreamEvent, error) { if chunk.UsageMetadata != nil { s.usage = llm.Usage{ - InputTokens: int(chunk.UsageMetadata.PromptTokenCount), - OutputTokens: int(chunk.UsageMetadata.CandidatesTokenCount + chunk.UsageMetadata.ThoughtsTokenCount), + InputTokens: int(chunk.UsageMetadata.PromptTokenCount), + OutputTokens: int(chunk.UsageMetadata.CandidatesTokenCount + chunk.UsageMetadata.ThoughtsTokenCount), + CacheReadTokens: int(chunk.UsageMetadata.CachedContentTokenCount), + ReasoningTokens: int(chunk.UsageMetadata.ThoughtsTokenCount), } } if len(chunk.Candidates) == 0 { diff --git a/provider/openai/model.go b/provider/openai/model.go index 2744a25..4aa9822 100644 --- a/provider/openai/model.go +++ b/provider/openai/model.go @@ -130,10 +130,7 @@ func (m *model) apiError(httpResp *http.Response) error { func (m *model) toResponse(wire *chatResponse) *llm.Response { resp := &llm.Response{Model: m.p.name + "/" + m.id, Raw: wire} if wire.Usage != nil { - resp.Usage = llm.Usage{ - InputTokens: wire.Usage.PromptTokens, - OutputTokens: wire.Usage.CompletionTokens, - } + resp.Usage = wire.Usage.toUsage() } if len(wire.Choices) == 0 { resp.FinishReason = llm.FinishOther diff --git a/provider/openai/stream.go b/provider/openai/stream.go index a6a7e76..0a88e1c 100644 --- a/provider/openai/stream.go +++ b/provider/openai/stream.go @@ -104,10 +104,7 @@ func (s *stream) handleChunk(data []byte) error { return apiErr } if chunk.Usage != nil { - s.usage = llm.Usage{ - InputTokens: chunk.Usage.PromptTokens, - OutputTokens: chunk.Usage.CompletionTokens, - } + s.usage = chunk.Usage.toUsage() } // Why the guard: the include_usage chunk arrives with an EMPTY choices // array; indexing choices[0] unconditionally would panic on it. diff --git a/provider/openai/wire.go b/provider/openai/wire.go index 5553ccc..3246bb7 100644 --- a/provider/openai/wire.go +++ b/provider/openai/wire.go @@ -125,9 +125,32 @@ type wireRespMessage struct { } type wireUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *wirePromptDetail `json:"prompt_tokens_details"` + CompletionTokensDetails *wireOutputDetail `json:"completion_tokens_details"` +} + +type wirePromptDetail struct { + CachedTokens int `json:"cached_tokens"` +} + +type wireOutputDetail struct { + ReasoningTokens int `json:"reasoning_tokens"` +} + +// toUsage maps wire usage (with optional detail objects — absent on many +// compat servers) onto the canonical Usage. +func (u *wireUsage) toUsage() llm.Usage { + out := llm.Usage{InputTokens: u.PromptTokens, OutputTokens: u.CompletionTokens} + if u.PromptTokensDetails != nil { + out.CacheReadTokens = u.PromptTokensDetails.CachedTokens + } + if u.CompletionTokensDetails != nil { + out.ReasoningTokens = u.CompletionTokensDetails.ReasoningTokens + } + return out } type errorEnvelope struct { diff --git a/registry.go b/registry.go index 2992b01..e657e2a 100644 --- a/registry.go +++ b/registry.go @@ -18,6 +18,7 @@ type Registry struct { mu sync.RWMutex providers map[string]llm.Provider aliases map[string]string + resolvers []Resolver schemes map[string]SchemeFactory // envErrs records LLM_* entries that failed to load so the failure // surfaces when (and only when) that provider name is actually used. @@ -28,6 +29,22 @@ type Registry struct { envLookup func(string) string } +// Resolver dynamically resolves a bare alias token to a spec string — +// the hook for DB- or config-backed tiers that change at runtime. Checked +// after static aliases, in registration order; the returned spec is +// expanded recursively (chains, nested aliases) with cycle detection. +type Resolver interface { + // Resolve returns the spec for name, or ok=false when this resolver + // does not handle it. + Resolve(name string) (spec string, ok bool) +} + +// ResolverFunc adapts a function to the Resolver interface. +type ResolverFunc func(name string) (string, bool) + +// Resolve implements Resolver. +func (f ResolverFunc) Resolve(name string) (string, bool) { return f(name) } + // SchemeFactory builds a provider instance from an env DSN. name is the // registry name the provider will be registered under (e.g. "m1" for // LLM_M1); dsn carries the scheme, credential, and host. @@ -49,6 +66,27 @@ type ChainConfig struct { // Classify overrides the default error classifier (llm.Classify). Classify func(error) llm.ErrorClass + + // Observer, when non-nil, receives one event per failover decision + // (failed attempt, bench, benched-skip) — the hook for persisting + // failover logs. Called synchronously; keep it fast or hand off. + Observer func(FailoverEvent) +} + +// FailoverEvent describes one failover decision in a chain. +type FailoverEvent struct { + // Target is the "provider/model" key the event concerns. + Target string + // Err is the failure (nil for benched-skip events). + Err error + // Class is the error classification (meaningful when Err != nil). + Class llm.ErrorClass + // Attempt is the 0-based attempt number on this target in this request. + Attempt int + // Benched reports that this failure benched the target. + Benched bool + // Skipped reports the target was skipped because it is benched. + Skipped bool } // DefaultTransientRetries is the default number of same-target retries @@ -180,6 +218,15 @@ func (r *Registry) RegisterAlias(name, spec string) { r.aliases[name] = spec } +// RegisterResolver appends a dynamic alias resolver (e.g. database-backed +// tiers). Resolvers are consulted in registration order, after static +// aliases. +func (r *Registry) RegisterResolver(res Resolver) { + r.mu.Lock() + defer r.mu.Unlock() + r.resolvers = append(r.resolvers, res) +} + // RegisterScheme adds or replaces an env-DSN scheme factory, letting // consumers wire custom provider kinds into LLM_* definitions. func (r *Registry) RegisterScheme(scheme string, factory SchemeFactory) {