feat: conversion-driven extensions — resolvers, DefineTool, hooks, ops controls
CI / Tidy (push) Successful in 9m31s
CI / Build & Test (push) Successful in 10m13s

Phase 9a (ADR-0014): Registry.RegisterResolver for dynamic tiers;
DefineTool[Args] typed tools; Usage cache/reasoning detail fields wired
through anthropic/openai/google; WithPromptCaching (Anthropic
cache_control); agent supervision hooks (WithMaxStepsFunc, WithSteer,
WithCompactor, WithToolErrorLimits + ErrToolLoop); health
Bench/Unbench/Snapshot; ChainConfig.Observer failover events.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-10 13:30:06 +02:00
parent 04b21fdad2
commit 0147a79d18
21 changed files with 767 additions and 29 deletions
+114 -10
View File
@@ -27,6 +27,11 @@ const DefaultMaxSteps = 10
// carrying the transcript so far. // carrying the transcript so far.
var ErrMaxSteps = errors.New("agent: max steps reached without a final answer") var ErrMaxSteps = errors.New("agent: max steps reached without a final answer")
// ErrToolLoop reports that the loop tripped a tool-error guard
// (consecutive all-error steps or identical repeated calls; see
// WithToolErrorLimits). Run returns it alongside the partial *Result.
var ErrToolLoop = errors.New("agent: tool-error guard tripped")
// Skill is the contract skills satisfy (defined here so agent does not // Skill is the contract skills satisfy (defined here so agent does not
// depend on the skill package; package skill provides implementations). // depend on the skill package; package skill provides implementations).
// Instructions are appended to the agent's system prompt; Tools (optional, // Instructions are appended to the agent's system prompt; Tools (optional,
@@ -67,13 +72,17 @@ type Result struct {
// it later. Agents are safe to share across goroutines only after // it later. Agents are safe to share across goroutines only after
// configuration is complete. // configuration is complete.
type Agent struct { type Agent struct {
model llm.Model model llm.Model
system string system string
toolboxes []*llm.Toolbox toolboxes []*llm.Toolbox
skills []Skill skills []Skill
maxSteps int maxSteps int
reqOpts []llm.Option maxStepsFunc func() int
observers []func(Step) compactor func(ctx context.Context, msgs []llm.Message) ([]llm.Message, error)
maxConsecutiveToolErrors int
maxSameCallRepeats int
reqOpts []llm.Option
observers []func(Step)
} }
// Option configures an Agent at construction. // Option configures an Agent at construction.
@@ -99,6 +108,34 @@ func WithMaxSteps(n int) Option {
return func(a *Agent) { a.maxSteps = n } return func(a *Agent) { a.maxSteps = n }
} }
// WithMaxStepsFunc makes the step ceiling dynamic: the function is
// consulted before every step, so a supervisor can extend (or shrink) a
// running agent's budget. It overrides WithMaxSteps while non-nil; a
// non-positive return falls back to the static value.
func WithMaxStepsFunc(fn func() int) Option {
return func(a *Agent) { a.maxStepsFunc = fn }
}
// WithCompactor installs a context-compaction hook, called with the full
// message slice before every model call; whatever it returns is sent
// instead (e.g. summarize the middle of a long transcript). A compactor
// error is non-fatal: the original messages are used.
func WithCompactor(fn func(ctx context.Context, msgs []llm.Message) ([]llm.Message, error)) Option {
return func(a *Agent) { a.compactor = fn }
}
// WithToolErrorLimits installs loop guards: maxConsecutiveErrors bounds
// successive steps whose tool results were ALL errors, and
// maxSameCallRepeats bounds identical (name + arguments) tool calls within
// one run. Either guard tripping ends the run with ErrToolLoop and the
// partial result. Zero disables a guard.
func WithToolErrorLimits(maxConsecutiveErrors, maxSameCallRepeats int) Option {
return func(a *Agent) {
a.maxConsecutiveToolErrors = maxConsecutiveErrors
a.maxSameCallRepeats = maxSameCallRepeats
}
}
// WithRequestOptions sets default request options (temperature, max // WithRequestOptions sets default request options (temperature, max
// tokens, ...) applied to every step of every run. // tokens, ...) applied to every step of every run.
func WithRequestOptions(opts ...llm.Option) Option { func WithRequestOptions(opts ...llm.Option) Option {
@@ -134,6 +171,7 @@ type runConfig struct {
history []llm.Message history []llm.Message
reqOpts []llm.Option reqOpts []llm.Option
onStep []func(Step) onStep []func(Step)
steer func() []llm.Message
} }
// WithHistory seeds the run with prior conversation messages (e.g. a // WithHistory seeds the run with prior conversation messages (e.g. a
@@ -153,6 +191,15 @@ func OnStep(fn func(Step)) RunOption {
return func(rc *runConfig) { rc.onStep = append(rc.onStep, fn) } return func(rc *runConfig) { rc.onStep = append(rc.onStep, fn) }
} }
// WithSteer installs a steering source for this run: the function is
// drained before every step and any returned messages are appended to the
// conversation — the mechanism for a supervisor nudging a running agent
// ("wrap up", "focus on X"). It is called from Run's goroutine; the
// function owns its own synchronization.
func WithSteer(fn func() []llm.Message) RunOption {
return func(rc *runConfig) { rc.steer = fn }
}
// systemPrompt composes the agent's system prompt with each skill's // systemPrompt composes the agent's system prompt with each skill's
// instructions, in attachment order. // instructions, in attachment order.
func (a *Agent) systemPrompt() string { func (a *Agent) systemPrompt() string {
@@ -227,8 +274,34 @@ func (a *Agent) Run(ctx context.Context, input string, opts ...RunOption) (*Resu
reqOpts := append(append([]llm.Option(nil), a.reqOpts...), rc.reqOpts...) reqOpts := append(append([]llm.Option(nil), a.reqOpts...), rc.reqOpts...)
system := a.systemPrompt() system := a.systemPrompt()
for stepIdx := range a.maxSteps { // Loop-guard state (WithToolErrorLimits).
req := llm.Request{System: system, Messages: msgs, Tools: ordered} consecutiveErrorSteps := 0
callCounts := make(map[string]int)
maxSteps := func() int {
if a.maxStepsFunc != nil {
if n := a.maxStepsFunc(); n > 0 {
return n
}
}
return a.maxSteps
}
for stepIdx := 0; stepIdx < maxSteps(); stepIdx++ {
// Steering: drain supervisor nudges into the conversation.
if rc.steer != nil {
msgs = append(msgs, rc.steer()...)
}
sendMsgs := msgs
if a.compactor != nil {
// Compaction failures are non-fatal: send the original.
if compacted, err := a.compactor(ctx, msgs); err == nil && compacted != nil {
sendMsgs = compacted
}
}
req := llm.Request{System: system, Messages: sendMsgs, Tools: ordered}
resp, err := a.model.Generate(ctx, req, reqOpts...) resp, err := a.model.Generate(ctx, req, reqOpts...)
if err != nil { if err != nil {
result.Messages = msgs result.Messages = msgs
@@ -249,11 +322,19 @@ func (a *Agent) Run(ctx context.Context, input string, opts ...RunOption) (*Resu
} }
results := make([]llm.ToolResult, 0, len(resp.ToolCalls)) results := make([]llm.ToolResult, 0, len(resp.ToolCalls))
repeatTripped := ""
for _, call := range resp.ToolCalls { for _, call := range resp.ToolCalls {
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
result.Messages = msgs result.Messages = msgs
return result, err return result, err
} }
if a.maxSameCallRepeats > 0 {
sig := call.Name + "\x00" + string(call.Arguments)
callCounts[sig]++
if callCounts[sig] > a.maxSameCallRepeats {
repeatTripped = call.Name
}
}
tool, ok := byName[call.Name] tool, ok := byName[call.Name]
if !ok { if !ok {
results = append(results, llm.ToolResult{ results = append(results, llm.ToolResult{
@@ -272,10 +353,33 @@ func (a *Agent) Run(ctx context.Context, input string, opts ...RunOption) (*Resu
result.Steps = append(result.Steps, step) result.Steps = append(result.Steps, step)
a.notify(rc, step) a.notify(rc, step)
msgs = append(msgs, llm.ToolResultsMessage(results...)) msgs = append(msgs, llm.ToolResultsMessage(results...))
if repeatTripped != "" {
result.Messages = msgs
return result, fmt.Errorf("%w: %q called identically more than %d times",
ErrToolLoop, repeatTripped, a.maxSameCallRepeats)
}
allErrors := len(results) > 0
for _, r := range results {
if !r.IsError {
allErrors = false
break
}
}
if allErrors {
consecutiveErrorSteps++
if a.maxConsecutiveToolErrors > 0 && consecutiveErrorSteps >= a.maxConsecutiveToolErrors {
result.Messages = msgs
return result, fmt.Errorf("%w: %d consecutive steps with only failing tool calls",
ErrToolLoop, consecutiveErrorSteps)
}
} else {
consecutiveErrorSteps = 0
}
} }
result.Messages = msgs result.Messages = msgs
return result, fmt.Errorf("%w (max %d)", ErrMaxSteps, a.maxSteps) return result, fmt.Errorf("%w (max %d)", ErrMaxSteps, maxSteps())
} }
// notify fans a step out to agent observers and run callbacks; observer // notify fans a step out to agent observers and run callbacks; observer
+175
View File
@@ -0,0 +1,175 @@
package agent
import (
"context"
"encoding/json"
"errors"
"strings"
"sync/atomic"
"testing"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake"
)
// TestMaxStepsFuncExtendsBudget: a supervisor raising the ceiling mid-run
// lets the loop continue past the static budget.
func TestMaxStepsFuncExtendsBudget(t *testing.T) {
fp := fake.New("fp")
fp.Enqueue("test-model",
toolCallReply("c1", "add", `{"a":1,"b":1}`),
toolCallReply("c2", "add", `{"a":2,"b":2}`),
toolCallReply("c3", "add", `{"a":3,"b":3}`),
fake.Reply("done"),
)
var ceiling atomic.Int64
ceiling.Store(2)
a := New(newModel(t, fp), "",
WithToolbox(adderToolbox(t)),
WithMaxSteps(2),
WithMaxStepsFunc(func() int { return int(ceiling.Load()) }),
WithStepObserver(func(s Step) {
if s.Index == 1 {
ceiling.Store(10) // the "critic" extends the budget
}
}),
)
res, err := a.Run(context.Background(), "go")
if err != nil {
t.Fatalf("Run: %v (budget should have been extended)", err)
}
if res.Output != "done" || len(res.Steps) != 4 {
t.Errorf("output=%q steps=%d", res.Output, len(res.Steps))
}
}
// TestSteerInjectsMessages: steering messages appear in the conversation
// before the next model call.
func TestSteerInjectsMessages(t *testing.T) {
fp := fake.New("fp")
fp.Enqueue("test-model",
toolCallReply("c1", "add", `{"a":1,"b":1}`),
fake.Reply("ok"),
)
var pending []llm.Message
pending = append(pending, llm.UserText("SUPERVISOR: wrap it up"))
a := New(newModel(t, fp), "", WithToolbox(adderToolbox(t)))
_, err := a.Run(context.Background(), "go", WithSteer(func() []llm.Message {
out := pending
pending = nil
return out
}))
if err != nil {
t.Fatalf("Run: %v", err)
}
first := fp.Calls()[0].Request.Messages
if len(first) != 2 || !strings.Contains(first[1].Text(), "SUPERVISOR") {
t.Errorf("first call messages = %+v, want steered message", first)
}
// Drained: second call must not duplicate it.
second := fp.Calls()[1].Request.Messages
count := 0
for _, m := range second {
if strings.Contains(m.Text(), "SUPERVISOR") {
count++
}
}
if count != 1 {
t.Errorf("steer message appears %d times in second call, want 1", count)
}
}
// TestCompactorShrinksOutboundContext: the model sees the compacted view;
// the canonical transcript keeps everything.
func TestCompactorShrinksOutboundContext(t *testing.T) {
fp := fake.New("fp")
fp.Enqueue("test-model", fake.Reply("answer"))
history := []llm.Message{
llm.UserText("old 1"), llm.AssistantText("old reply 1"),
llm.UserText("old 2"), llm.AssistantText("old reply 2"),
}
a := New(newModel(t, fp), "", WithCompactor(func(_ context.Context, msgs []llm.Message) ([]llm.Message, error) {
// Keep only the last message, prefixed by a synthetic summary.
return append([]llm.Message{llm.UserText("[summary of earlier conversation]")}, msgs[len(msgs)-1]), nil
}))
res, err := a.Run(context.Background(), "new question", WithHistory(history))
if err != nil {
t.Fatalf("Run: %v", err)
}
sent := fp.Calls()[0].Request.Messages
if len(sent) != 2 || !strings.Contains(sent[0].Text(), "summary") {
t.Errorf("sent = %+v, want compacted view", sent)
}
if len(res.Messages) != 6 {
t.Errorf("transcript = %d messages, want full uncompacted history", len(res.Messages))
}
}
// TestCompactorErrorIsNonFatal: a failing compactor falls back to the
// original messages.
func TestCompactorErrorIsNonFatal(t *testing.T) {
fp := fake.New("fp")
fp.Enqueue("test-model", fake.Reply("fine"))
a := New(newModel(t, fp), "", WithCompactor(func(context.Context, []llm.Message) ([]llm.Message, error) {
return nil, errors.New("summarizer down")
}))
res, err := a.Run(context.Background(), "go")
if err != nil || res.Output != "fine" {
t.Errorf("res=%v err=%v", res, err)
}
if len(fp.Calls()[0].Request.Messages) != 1 {
t.Error("original messages must be sent when compaction fails")
}
}
// TestConsecutiveToolErrorGuard: steps whose tools ALL fail trip the guard.
func TestConsecutiveToolErrorGuard(t *testing.T) {
fp := fake.New("fp", fake.WithDefault(func(string, llm.Request) fake.Step {
return toolCallReply("c", "bomb", `{}`)
}))
bomb := llm.NewToolbox("danger", llm.Tool{
Name: "bomb",
Handler: func(context.Context, json.RawMessage) (any, error) { return nil, errors.New("always fails") },
})
a := New(newModel(t, fp), "", WithToolbox(bomb), WithToolErrorLimits(2, 0), WithMaxSteps(10))
res, err := a.Run(context.Background(), "go")
if !errors.Is(err, ErrToolLoop) {
t.Fatalf("err = %v, want ErrToolLoop", err)
}
if len(res.Steps) != 2 {
t.Errorf("steps = %d, want guard to trip after 2", len(res.Steps))
}
}
// TestSameCallRepeatGuard: identical (name+args) calls beyond the limit
// trip the guard; varied calls do not.
func TestSameCallRepeatGuard(t *testing.T) {
fp := fake.New("fp", fake.WithDefault(func(string, llm.Request) fake.Step {
return toolCallReply("c", "add", `{"a":1,"b":1}`)
}))
a := New(newModel(t, fp), "", WithToolbox(adderToolbox(t)), WithToolErrorLimits(0, 3), WithMaxSteps(10))
_, err := a.Run(context.Background(), "go")
if !errors.Is(err, ErrToolLoop) || !strings.Contains(err.Error(), `"add"`) {
t.Fatalf("err = %v, want repeat-guard ErrToolLoop naming add", err)
}
// Varied arguments never trip it.
n := 0
fp2 := fake.New("fp", fake.WithDefault(func(string, llm.Request) fake.Step {
n++
if n > 4 {
return fake.Reply("done")
}
return toolCallReply("c", "add", `{"a":1,"b":`+string(rune('0'+n))+`}`)
}))
a2 := New(newModel(t, fp2), "", WithToolbox(adderToolbox(t)), WithToolErrorLimits(0, 3), WithMaxSteps(10))
if _, err := a2.Run(context.Background(), "go"); err != nil {
t.Errorf("varied calls must not trip the guard: %v", err)
}
}
+9
View File
@@ -91,10 +91,17 @@ func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func
var zero T var zero T
var failures []error var failures []error
observe := func(ev FailoverEvent) {
if c.cfg.Observer != nil {
c.cfg.Observer(ev)
}
}
for _, t := range c.targets { for _, t := range c.targets {
if !c.tracker.Available(t.key) { if !c.tracker.Available(t.key) {
until := c.tracker.BackedOffUntil(t.key) until := c.tracker.BackedOffUntil(t.key)
failures = append(failures, fmt.Errorf("%s: skipped (backed off until %s)", t.key, until.Format("15:04:05.000"))) failures = append(failures, fmt.Errorf("%s: skipped (backed off until %s)", t.key, until.Format("15:04:05.000")))
observe(FailoverEvent{Target: t.key, Skipped: true})
continue continue
} }
@@ -119,6 +126,7 @@ func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func
class := c.cfg.classify(err) class := c.cfg.classify(err)
if class == llm.ClassPermanent { if class == llm.ClassPermanent {
observe(FailoverEvent{Target: t.key, Err: err, Class: class, Attempt: attemptN})
if errors.Is(err, llm.ErrModelNotFound) || errors.Is(err, llm.ErrUnsupported) || c.cfg.AdvanceOnPermanent { if errors.Is(err, llm.ErrModelNotFound) || errors.Is(err, llm.ErrUnsupported) || c.cfg.AdvanceOnPermanent {
// Not a health problem (or policy says keep going): // Not a health problem (or policy says keep going):
// advance without penalizing the target. // advance without penalizing the target.
@@ -134,6 +142,7 @@ func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func
// attempts remain — but advance as soon as the tracker benches // attempts remain — but advance as soon as the tracker benches
// it (a freshly backed-off target is not worth more retries). // it (a freshly backed-off target is not worth more retries).
benched := c.tracker.ReportFailure(t.key) benched := c.tracker.ReportFailure(t.key)
observe(FailoverEvent{Target: t.key, Err: err, Class: class, Attempt: attemptN, Benched: benched})
if !benched && attemptN < retries { if !benched && attemptN < retries {
continue continue
} }
@@ -0,0 +1,50 @@
# ADR-0014: Conversion-driven extensions (resolvers, typed tools, hooks, ops controls)
**Status:** Accepted — 2026-06-10
## Context
Executing the mort conversion (docs/mort-migration.md) surfaced seven
capabilities mort's orchestration layer needs from its model substrate.
Each is general-purpose — none encodes anything mort-specific — and each
was promised in the migration blueprint.
## Decision
1. **`Registry.RegisterResolver`** — dynamic alias resolution for DB/
config-backed tiers. Checked after static aliases (static wins), in
registration order, without holding the registry lock (resolvers do
I/O); output expands recursively under the existing cycle guard.
2. **`DefineTool[Args]`** — typed tools: parameters from `SchemaFor[Args]`,
arguments unmarshaled before the handler. Schema failure panics
(startup-time programming error, mirroring NewToolbox).
3. **`Usage` detail fields** — CacheReadTokens / CacheWriteTokens /
ReasoningTokens, populated where providers report them (Anthropic cache
fields; OpenAI prompt/completion detail objects; Google cached-content
+ thoughts). Input/Output remain totals; details are portions.
4. **`WithPromptCaching`** — Anthropic top-level `cache_control`
auto-placement; a no-op for providers that cache implicitly or not at
all.
5. **Agent loop hooks**`WithMaxStepsFunc` (dynamic ceiling, consulted
every step, for supervisor budget adjustment), `WithSteer` (per-run
message injection drained before each step), `WithCompactor`
(pre-Generate transcript transformation; errors fall back to the
original — losing a request to a broken summarizer is worse than a long
prompt), `WithToolErrorLimits` (consecutive all-error steps and
identical-call repeats end the run with `ErrToolLoop` + partial result).
The canonical transcript in `Result.Messages` is always the
uncompacted truth; compaction affects only what is sent.
6. **Manual health controls**`Tracker.Bench/Unbench/Snapshot` for
ops surfaces (".failover bench" commands, dashboards).
7. **`ChainConfig.Observer`** — one synchronous callback per failover
decision (failed attempt with classification, bench, benched-skip).
This is a hook, not an observability stack; persistence/metrics remain
the consumer's business (anti-creep guardrail intact).
## Consequences
- mort's tier resolver, `.failover` admin surface, failover-event log,
run-critic budget control, steering, and compaction all rebase onto
public API.
- The agent loop gains supervision points without changing its
never-panic, partial-result-on-error contract.
+1
View File
@@ -17,3 +17,4 @@ One decision per file, append-only; supersede rather than rewrite.
| [0011](0011-google-provider.md) | Google provider on the official Gen AI SDK | Accepted | | [0011](0011-google-provider.md) | Google provider on the official Gen AI SDK | Accepted |
| [0012](0012-agent-loop.md) | Agent run loop | Accepted | | [0012](0012-agent-loop.md) | Agent run loop | Accepted |
| [0013](0013-skill-model.md) | Skill model — additive instruction+tool bundles | Accepted | | [0013](0013-skill-model.md) | Skill model — additive instruction+tool bundles | Accepted |
| [0014](0014-conversion-driven-extensions.md) | Conversion-driven extensions (resolvers, typed tools, hooks, ops controls) | Accepted |
+172
View File
@@ -0,0 +1,172 @@
package majordomo
import (
"context"
"encoding/json"
"errors"
"strings"
"testing"
"time"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake"
)
// TestRegisterResolver: dynamic tiers resolve after static aliases, expand
// recursively, and respect cycle detection.
func TestRegisterResolver(t *testing.T) {
r := newTestRegistry(t)
r.RegisterProvider(fake.New("fp"))
r.RegisterAlias("static", "fp/a")
tiers := map[string]string{
"db-tier": "fp/x,fp/y",
"db-nested": "db-tier,fp/z",
"db-cycle": "db-cycle",
}
r.RegisterResolver(ResolverFunc(func(name string) (string, bool) {
spec, ok := tiers[name]
return spec, ok
}))
m, err := r.Parse("static,db-nested")
if err != nil {
t.Fatalf("Parse: %v", err)
}
want := []string{"fp/a", "fp/x", "fp/y", "fp/z"}
if got := targetsOf(t, m); strings.Join(got, ",") != strings.Join(want, ",") {
t.Errorf("targets = %v, want %v", got, want)
}
if _, err := r.Parse("db-cycle"); !errors.Is(err, ErrAliasCycle) {
t.Errorf("cycle error = %v, want ErrAliasCycle", err)
}
// Static aliases shadow resolvers.
tiers["static"] = "fp/wrong"
m, _ = r.Parse("static")
if got := targetsOf(t, m); got[0] != "fp/a" {
t.Errorf("static alias must win over resolver, got %v", got)
}
}
// TestDefineTool: schema from Args, decoded handler arguments.
func TestDefineTool(t *testing.T) {
type addArgs struct {
A int `json:"a" description:"first addend"`
B int `json:"b"`
}
tool := DefineTool("add", "Add two ints", func(_ context.Context, args addArgs) (any, error) {
return args.A + args.B, nil
})
var schema map[string]any
if err := json.Unmarshal(tool.Parameters, &schema); err != nil {
t.Fatalf("schema: %v", err)
}
props := schema["properties"].(map[string]any)
if props["a"].(map[string]any)["description"] != "first addend" {
t.Errorf("schema = %v", schema)
}
res := llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "1", Name: "add", Arguments: json.RawMessage(`{"a":2,"b":40}`)})
if res.IsError || res.Content != "42" {
t.Errorf("result = %+v", res)
}
res = llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "2", Name: "add", Arguments: json.RawMessage(`{"a":"nope"}`)})
if !res.IsError || !strings.Contains(res.Content, "invalid arguments") {
t.Errorf("bad-args result = %+v", res)
}
}
// TestChainObserver: failover decisions emit events (attempt, bench, skip).
func TestChainObserver(t *testing.T) {
var events []FailoverEvent
r := newTestRegistry(t, WithChainConfig(ChainConfig{
Observer: func(ev FailoverEvent) { events = append(events, ev) },
}))
fp := fake.New("fp")
r.RegisterProvider(fp)
fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")))
fp.Enqueue("b", fake.Reply("ok"), fake.Reply("ok"))
m, _ := r.Parse("fp/a,fp/b")
if _, err := generate(t, m); err != nil {
t.Fatalf("Generate: %v", err)
}
if len(events) != 2 {
t.Fatalf("events = %+v, want 2 failed attempts", events)
}
if events[0].Target != "fp/a" || events[0].Attempt != 0 || events[0].Benched {
t.Errorf("event 0 = %+v", events[0])
}
if !events[1].Benched {
t.Errorf("event 1 = %+v, want Benched", events[1])
}
// Second request: skipped-while-benched event.
events = nil
if _, err := generate(t, m); err != nil {
t.Fatalf("Generate #2: %v", err)
}
if len(events) != 1 || !events[0].Skipped || events[0].Target != "fp/a" {
t.Errorf("events = %+v, want one skip event", events)
}
}
// TestManualBenchControls: ops surfaces can bench/unbench/inspect.
func TestManualBenchControls(t *testing.T) {
clock := newFakeClock()
r := newTestRegistry(t, WithClock(clock.Now))
fp := fake.New("fp")
r.RegisterProvider(fp)
fp.Enqueue("b", fake.Reply("from-b"))
r.Health().Bench("fp/a", clock.Now().Add(time.Hour))
if r.Health().Available("fp/a") {
t.Fatal("manual bench must take effect")
}
snap := r.Health().Snapshot()
if len(snap) != 1 || snap[0].Key != "fp/a" {
t.Errorf("snapshot = %+v", snap)
}
m, _ := r.Parse("fp/a,fp/b")
resp, err := generate(t, m)
if err != nil || resp.Text() != "from-b" {
t.Fatalf("resp=%v err=%v (benched head must be skipped)", resp, err)
}
r.Health().Unbench("fp/a")
if !r.Health().Available("fp/a") {
t.Error("unbench must re-admit")
}
if len(r.Health().Snapshot()) != 0 {
t.Error("snapshot must be empty after unbench")
}
}
// TestPromptCachingOptionIsCarried: the request flag round-trips (the
// anthropic wire mapping is asserted in its own package).
func TestPromptCachingOptionIsCarried(t *testing.T) {
r := newTestRegistry(t)
fp := fake.New("fp")
r.RegisterProvider(fp)
m, _ := r.Parse("fp/x")
if _, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}, WithPromptCaching()); err != nil {
t.Fatalf("Generate: %v", err)
}
if !fp.Calls()[0].Request.PromptCache {
t.Error("PromptCache flag must reach the provider")
}
}
// TestUsageDetailAccumulation: Usage.Add sums the detail fields.
func TestUsageDetailAccumulation(t *testing.T) {
u := Usage{InputTokens: 10, OutputTokens: 5, CacheReadTokens: 4, CacheWriteTokens: 2, ReasoningTokens: 3}
u.Add(Usage{InputTokens: 1, OutputTokens: 1, CacheReadTokens: 1, CacheWriteTokens: 1, ReasoningTokens: 1})
if u.CacheReadTokens != 5 || u.CacheWriteTokens != 3 || u.ReasoningTokens != 4 {
t.Errorf("usage = %+v", u)
}
}
+6
View File
@@ -14,6 +14,12 @@ import (
// suitable for WithSchema and tool parameters. // suitable for WithSchema and tool parameters.
func SchemaFor[T any]() (json.RawMessage, error) { return llm.SchemaFor[T]() } func SchemaFor[T any]() (json.RawMessage, error) { return llm.SchemaFor[T]() }
// DefineTool re-exports llm.DefineTool: a typed tool whose parameter schema
// is derived from Args and whose handler receives decoded arguments.
func DefineTool[Args any](name, description string, fn func(ctx context.Context, args Args) (any, error)) Tool {
return llm.DefineTool(name, description, fn)
}
// Generate performs a structured-output request and unmarshals the result // Generate performs a structured-output request and unmarshals the result
// into T: the schema is derived from T (llm.SchemaFor), injected via the // into T: the schema is derived from T (llm.SchemaFor), injected via the
// provider's native structured-output mechanism, and the response text is // provider's native structured-output mechanism, and the response text is
+53
View File
@@ -15,6 +15,8 @@
package health package health
import ( import (
"slices"
"strings"
"sync" "sync"
"time" "time"
) )
@@ -133,6 +135,57 @@ func (t *Tracker) ReportFailure(key string) (backedOff bool) {
return true return true
} }
// Bench manually benches a key until the given time, regardless of its
// failure history (ops surfaces: ".failover bench"). A zero time unbenches.
func (t *Tracker) Bench(key string, until time.Time) {
t.mu.Lock()
defer t.mu.Unlock()
e, ok := t.entries[key]
if !ok {
e = &entry{}
t.entries[key] = e
}
e.until = until
}
// Unbench clears a key entirely: bench window, failure count, and backoff
// history.
func (t *Tracker) Unbench(key string) {
t.mu.Lock()
defer t.mu.Unlock()
delete(t.entries, key)
}
// Status describes one tracked key (diagnostics).
type Status struct {
Key string
// Until is the end of the current bench window (zero = not benched).
Until time.Time
// ConsecutiveFailures since the last success or bench trigger.
ConsecutiveFailures int
// Benches is the consecutive backoff count driving the exponent.
Benches int
}
// Snapshot returns the status of every currently-benched key, sorted by
// key, evaluated at the tracker's clock.
func (t *Tracker) Snapshot() []Status {
t.mu.Lock()
defer t.mu.Unlock()
now := t.cfg.Clock()
out := make([]Status, 0, len(t.entries))
for key, e := range t.entries {
if now.Before(e.until) {
out = append(out, Status{
Key: key, Until: e.until,
ConsecutiveFailures: e.consecutiveFailures, Benches: e.backoffs,
})
}
}
slices.SortFunc(out, func(a, b Status) int { return strings.Compare(a.Key, b.Key) })
return out
}
// BackedOffUntil returns the end of the key's current backoff window, or the // BackedOffUntil returns the end of the key's current backoff window, or the
// zero time when the key is not backed off. Useful for diagnostics and error // zero time when the key is not backed off. Useful for diagnostics and error
// messages. // messages.
+11
View File
@@ -43,6 +43,11 @@ type Request struct {
// Providers map it to their native knob (OpenAI reasoning_effort, // Providers map it to their native knob (OpenAI reasoning_effort,
// Ollama think levels) and ignore it where no mapping exists. // Ollama think levels) and ignore it where no mapping exists.
ReasoningEffort string ReasoningEffort string
// PromptCache opts the request into the provider's prompt caching
// (Anthropic cache_control; ignored by providers that cache
// automatically or not at all).
PromptCache bool
} }
// Option mutates a Request before it is sent. Options passed to Generate or // Option mutates a Request before it is sent. Options passed to Generate or
@@ -100,6 +105,12 @@ func WithReasoningEffort(level string) Option {
return func(r *Request) { r.ReasoningEffort = level } return func(r *Request) { r.ReasoningEffort = level }
} }
// WithPromptCaching opts into provider prompt caching where it is an
// explicit feature (Anthropic); a no-op elsewhere.
func WithPromptCaching() Option {
return func(r *Request) { r.PromptCache = true }
}
// Apply returns a copy of the request with all options applied. Providers // Apply returns a copy of the request with all options applied. Providers
// and wrappers call this once at the top of Generate/Stream. // and wrappers call this once at the top of Generate/Stream.
func (r Request) Apply(opts ...Option) Request { func (r Request) Apply(opts ...Option) Request {
+16 -1
View File
@@ -18,10 +18,22 @@ const (
FinishOther FinishReason = "other" FinishOther FinishReason = "other"
) )
// Usage reports token accounting for one request. // Usage reports token accounting for one request. InputTokens and
// OutputTokens are always totals; the detail fields break out portions of
// those totals where the provider reports them (0 = not reported).
type Usage struct { type Usage struct {
InputTokens int InputTokens int
OutputTokens int OutputTokens int
// CacheReadTokens is the portion of InputTokens served from the
// provider's prompt cache.
CacheReadTokens int
// CacheWriteTokens is the portion of InputTokens written to the
// provider's prompt cache.
CacheWriteTokens int
// ReasoningTokens is the portion of OutputTokens spent on
// thinking/reasoning.
ReasoningTokens int
} }
// Total returns input plus output tokens. // Total returns input plus output tokens.
@@ -31,6 +43,9 @@ func (u Usage) Total() int { return u.InputTokens + u.OutputTokens }
func (u *Usage) Add(o Usage) { func (u *Usage) Add(o Usage) {
u.InputTokens += o.InputTokens u.InputTokens += o.InputTokens
u.OutputTokens += o.OutputTokens u.OutputTokens += o.OutputTokens
u.CacheReadTokens += o.CacheReadTokens
u.CacheWriteTokens += o.CacheWriteTokens
u.ReasoningTokens += o.ReasoningTokens
} }
// Response is the canonical generation result. // Response is the canonical generation result.
+34
View File
@@ -22,6 +22,40 @@ type Tool struct {
Handler func(ctx context.Context, args json.RawMessage) (any, error) Handler func(ctx context.Context, args json.RawMessage) (any, error)
} }
// DefineTool builds a typed tool: the parameter schema is derived from
// Args (see SchemaFor) and the raw JSON arguments are unmarshaled into an
// Args value before the handler runs.
//
// weather := llm.DefineTool("get_weather", "Current weather for a city",
// func(ctx context.Context, args struct {
// City string `json:"city" description:"city name"`
// }) (any, error) {
// return lookup(args.City)
// })
//
// Schema derivation failures panic: tools are defined at startup and an
// unschematizable Args type is a programming error worth failing loudly on.
func DefineTool[Args any](name, description string, fn func(ctx context.Context, args Args) (any, error)) Tool {
schema, err := SchemaFor[Args]()
if err != nil {
panic(fmt.Sprintf("llm: DefineTool(%q): %v", name, err))
}
return Tool{
Name: name,
Description: description,
Parameters: schema,
Handler: func(ctx context.Context, raw json.RawMessage) (any, error) {
var args Args
if len(raw) > 0 {
if err := json.Unmarshal(raw, &args); err != nil {
return nil, fmt.Errorf("invalid arguments for %s: %w", name, err)
}
}
return fn(ctx, args)
},
}
}
// ToolCall is a model's request to invoke a tool. // ToolCall is a model's request to invoke a tool.
type ToolCall struct { type ToolCall struct {
// ID is the provider-assigned call id; majordomo synthesizes one for // ID is the provider-assigned call id; majordomo synthesizes one for
+1
View File
@@ -100,6 +100,7 @@ func WithTopP(p float64) Option { return llm.WithTop
func WithMaxTokens(n int) Option { return llm.WithMaxTokens(n) } func WithMaxTokens(n int) Option { return llm.WithMaxTokens(n) }
func WithStopSequences(stops ...string) Option { return llm.WithStopSequences(stops...) } func WithStopSequences(stops ...string) Option { return llm.WithStopSequences(stops...) }
func WithReasoningEffort(level string) Option { return llm.WithReasoningEffort(level) } func WithReasoningEffort(level string) Option { return llm.WithReasoningEffort(level) }
func WithPromptCaching() Option { return llm.WithPromptCaching() }
// WithModelCapabilities re-exports llm.WithCapabilities for Provider.Model // WithModelCapabilities re-exports llm.WithCapabilities for Provider.Model
// calls made through this package. // calls made through this package.
+12 -1
View File
@@ -97,12 +97,23 @@ func (r *Registry) expand(spec string, visiting []string) ([]element, error) {
continue continue
} }
// Bare token: must be a registered alias. // Bare token: a registered alias, or a dynamic resolver hit.
r.mu.RLock() r.mu.RLock()
target, isAlias := r.aliases[raw] target, isAlias := r.aliases[raw]
_, isProvider := r.providers[raw] _, isProvider := r.providers[raw]
resolvers := slices.Clone(r.resolvers)
r.mu.RUnlock() r.mu.RUnlock()
if !isAlias {
// Resolvers run without the lock — they may call back into the
// registry (and DB-backed ones block on I/O).
for _, res := range resolvers {
if spec, ok := res.Resolve(raw); ok && spec != "" {
target, isAlias = spec, true
break
}
}
}
if !isAlias { if !isAlias {
if isProvider { if isProvider {
return nil, fmt.Errorf("%q is a provider, not an alias — use %q", raw, raw+"/<model-id>") return nil, fmt.Errorf("%q is a provider, not an alias — use %q", raw, raw+"/<model-id>")
+14
View File
@@ -1,5 +1,19 @@
# progress # progress
## 2026-06-10 — Phase 9a: conversion-driven library extensions
**Landed (ADR-0014):** RegisterResolver (dynamic DB-backed tiers, static
aliases win, recursive + cycle-guarded), DefineTool[Args] (typed tools
over SchemaFor), Usage cache/reasoning detail fields populated by
anthropic/openai/google, WithPromptCaching (Anthropic top-level
cache_control), agent hooks (WithMaxStepsFunc, WithSteer, WithCompactor —
non-fatal on error, canonical transcript stays uncompacted —
WithToolErrorLimits with ErrToolLoop), health Bench/Unbench/Snapshot,
ChainConfig.Observer failover events (attempt/bench/skip). Full hermetic
coverage for each.
**Next:** Phase 9b — the mort conversion branch.
## 2026-06-10 — Phase 8: live validation against real Ollama Cloud ## 2026-06-10 — Phase 8: live validation against real Ollama Cloud
**All six checks PASS** (examples/live harness, OLLAMA_API_KEY from .env): **All six checks PASS** (examples/live harness, OLLAMA_API_KEY from .env):
+16 -2
View File
@@ -24,6 +24,13 @@ type wireRequest struct {
TopP *float64 `json:"top_p,omitempty"` TopP *float64 `json:"top_p,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
OutputConfig *wireOutputConfig `json:"output_config,omitempty"` OutputConfig *wireOutputConfig `json:"output_config,omitempty"`
// CacheControl is the top-level auto-placement form of prompt caching:
// the API puts the breakpoint on the last cacheable block.
CacheControl *wireCacheControl `json:"cache_control,omitempty"`
}
type wireCacheControl struct {
Type string `json:"type"`
} }
type wireMessage struct { type wireMessage struct {
@@ -109,8 +116,10 @@ type wireUsage struct {
// real total input is input + cache_creation + cache_read. // real total input is input + cache_creation + cache_read.
func (u wireUsage) toUsage() llm.Usage { func (u wireUsage) toUsage() llm.Usage {
return llm.Usage{ return llm.Usage{
InputTokens: u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens, InputTokens: u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens,
OutputTokens: u.OutputTokens, OutputTokens: u.OutputTokens,
CacheReadTokens: u.CacheReadInputTokens,
CacheWriteTokens: u.CacheCreationInputTokens,
} }
} }
@@ -157,6 +166,11 @@ func buildWireRequest(modelID string, req llm.Request, defaultMax int, stream bo
Schema: req.Schema, Schema: req.Schema,
}} }}
} }
if req.PromptCache {
// Top-level auto-placement: the API puts the cache breakpoint on
// the last cacheable block.
wr.CacheControl = &wireCacheControl{Type: "ephemeral"}
}
return wr return wr
} }
+4 -2
View File
@@ -364,8 +364,10 @@ func (m *model) toResponse(resp *genai.GenerateContentResponse) *llm.Response {
out := &llm.Response{Model: m.qualified(), Raw: resp} out := &llm.Response{Model: m.qualified(), Raw: resp}
if resp.UsageMetadata != nil { if resp.UsageMetadata != nil {
out.Usage = llm.Usage{ out.Usage = llm.Usage{
InputTokens: int(resp.UsageMetadata.PromptTokenCount), InputTokens: int(resp.UsageMetadata.PromptTokenCount),
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount + resp.UsageMetadata.ThoughtsTokenCount), OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount + resp.UsageMetadata.ThoughtsTokenCount),
CacheReadTokens: int(resp.UsageMetadata.CachedContentTokenCount),
ReasoningTokens: int(resp.UsageMetadata.ThoughtsTokenCount),
} }
} }
if len(resp.Candidates) == 0 { if len(resp.Candidates) == 0 {
+4 -2
View File
@@ -78,8 +78,10 @@ func (s *stream) Next() (llm.StreamEvent, error) {
if chunk.UsageMetadata != nil { if chunk.UsageMetadata != nil {
s.usage = llm.Usage{ s.usage = llm.Usage{
InputTokens: int(chunk.UsageMetadata.PromptTokenCount), InputTokens: int(chunk.UsageMetadata.PromptTokenCount),
OutputTokens: int(chunk.UsageMetadata.CandidatesTokenCount + chunk.UsageMetadata.ThoughtsTokenCount), OutputTokens: int(chunk.UsageMetadata.CandidatesTokenCount + chunk.UsageMetadata.ThoughtsTokenCount),
CacheReadTokens: int(chunk.UsageMetadata.CachedContentTokenCount),
ReasoningTokens: int(chunk.UsageMetadata.ThoughtsTokenCount),
} }
} }
if len(chunk.Candidates) == 0 { if len(chunk.Candidates) == 0 {
+1 -4
View File
@@ -130,10 +130,7 @@ func (m *model) apiError(httpResp *http.Response) error {
func (m *model) toResponse(wire *chatResponse) *llm.Response { func (m *model) toResponse(wire *chatResponse) *llm.Response {
resp := &llm.Response{Model: m.p.name + "/" + m.id, Raw: wire} resp := &llm.Response{Model: m.p.name + "/" + m.id, Raw: wire}
if wire.Usage != nil { if wire.Usage != nil {
resp.Usage = llm.Usage{ resp.Usage = wire.Usage.toUsage()
InputTokens: wire.Usage.PromptTokens,
OutputTokens: wire.Usage.CompletionTokens,
}
} }
if len(wire.Choices) == 0 { if len(wire.Choices) == 0 {
resp.FinishReason = llm.FinishOther resp.FinishReason = llm.FinishOther
+1 -4
View File
@@ -104,10 +104,7 @@ func (s *stream) handleChunk(data []byte) error {
return apiErr return apiErr
} }
if chunk.Usage != nil { if chunk.Usage != nil {
s.usage = llm.Usage{ s.usage = chunk.Usage.toUsage()
InputTokens: chunk.Usage.PromptTokens,
OutputTokens: chunk.Usage.CompletionTokens,
}
} }
// Why the guard: the include_usage chunk arrives with an EMPTY choices // Why the guard: the include_usage chunk arrives with an EMPTY choices
// array; indexing choices[0] unconditionally would panic on it. // array; indexing choices[0] unconditionally would panic on it.
+26 -3
View File
@@ -125,9 +125,32 @@ type wireRespMessage struct {
} }
type wireUsage struct { type wireUsage struct {
PromptTokens int `json:"prompt_tokens"` PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"` CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
PromptTokensDetails *wirePromptDetail `json:"prompt_tokens_details"`
CompletionTokensDetails *wireOutputDetail `json:"completion_tokens_details"`
}
type wirePromptDetail struct {
CachedTokens int `json:"cached_tokens"`
}
type wireOutputDetail struct {
ReasoningTokens int `json:"reasoning_tokens"`
}
// toUsage maps wire usage (with optional detail objects — absent on many
// compat servers) onto the canonical Usage.
func (u *wireUsage) toUsage() llm.Usage {
out := llm.Usage{InputTokens: u.PromptTokens, OutputTokens: u.CompletionTokens}
if u.PromptTokensDetails != nil {
out.CacheReadTokens = u.PromptTokensDetails.CachedTokens
}
if u.CompletionTokensDetails != nil {
out.ReasoningTokens = u.CompletionTokensDetails.ReasoningTokens
}
return out
} }
type errorEnvelope struct { type errorEnvelope struct {
+47
View File
@@ -18,6 +18,7 @@ type Registry struct {
mu sync.RWMutex mu sync.RWMutex
providers map[string]llm.Provider providers map[string]llm.Provider
aliases map[string]string aliases map[string]string
resolvers []Resolver
schemes map[string]SchemeFactory schemes map[string]SchemeFactory
// envErrs records LLM_* entries that failed to load so the failure // envErrs records LLM_* entries that failed to load so the failure
// surfaces when (and only when) that provider name is actually used. // surfaces when (and only when) that provider name is actually used.
@@ -28,6 +29,22 @@ type Registry struct {
envLookup func(string) string envLookup func(string) string
} }
// Resolver dynamically resolves a bare alias token to a spec string —
// the hook for DB- or config-backed tiers that change at runtime. Checked
// after static aliases, in registration order; the returned spec is
// expanded recursively (chains, nested aliases) with cycle detection.
type Resolver interface {
// Resolve returns the spec for name, or ok=false when this resolver
// does not handle it.
Resolve(name string) (spec string, ok bool)
}
// ResolverFunc adapts a function to the Resolver interface.
type ResolverFunc func(name string) (string, bool)
// Resolve implements Resolver.
func (f ResolverFunc) Resolve(name string) (string, bool) { return f(name) }
// SchemeFactory builds a provider instance from an env DSN. name is the // SchemeFactory builds a provider instance from an env DSN. name is the
// registry name the provider will be registered under (e.g. "m1" for // registry name the provider will be registered under (e.g. "m1" for
// LLM_M1); dsn carries the scheme, credential, and host. // LLM_M1); dsn carries the scheme, credential, and host.
@@ -49,6 +66,27 @@ type ChainConfig struct {
// Classify overrides the default error classifier (llm.Classify). // Classify overrides the default error classifier (llm.Classify).
Classify func(error) llm.ErrorClass Classify func(error) llm.ErrorClass
// Observer, when non-nil, receives one event per failover decision
// (failed attempt, bench, benched-skip) — the hook for persisting
// failover logs. Called synchronously; keep it fast or hand off.
Observer func(FailoverEvent)
}
// FailoverEvent describes one failover decision in a chain.
type FailoverEvent struct {
// Target is the "provider/model" key the event concerns.
Target string
// Err is the failure (nil for benched-skip events).
Err error
// Class is the error classification (meaningful when Err != nil).
Class llm.ErrorClass
// Attempt is the 0-based attempt number on this target in this request.
Attempt int
// Benched reports that this failure benched the target.
Benched bool
// Skipped reports the target was skipped because it is benched.
Skipped bool
} }
// DefaultTransientRetries is the default number of same-target retries // DefaultTransientRetries is the default number of same-target retries
@@ -180,6 +218,15 @@ func (r *Registry) RegisterAlias(name, spec string) {
r.aliases[name] = spec r.aliases[name] = spec
} }
// RegisterResolver appends a dynamic alias resolver (e.g. database-backed
// tiers). Resolvers are consulted in registration order, after static
// aliases.
func (r *Registry) RegisterResolver(res Resolver) {
r.mu.Lock()
defer r.mu.Unlock()
r.resolvers = append(r.resolvers, res)
}
// RegisterScheme adds or replaces an env-DSN scheme factory, letting // RegisterScheme adds or replaces an env-DSN scheme factory, letting
// consumers wire custom provider kinds into LLM_* definitions. // consumers wire custom provider kinds into LLM_* definitions.
func (r *Registry) RegisterScheme(scheme string, factory SchemeFactory) { func (r *Registry) RegisterScheme(scheme string, factory SchemeFactory) {