feat: conversion-driven extensions — resolvers, DefineTool, hooks, ops controls
Phase 9a (ADR-0014): Registry.RegisterResolver for dynamic tiers; DefineTool[Args] typed tools; Usage cache/reasoning detail fields wired through anthropic/openai/google; WithPromptCaching (Anthropic cache_control); agent supervision hooks (WithMaxStepsFunc, WithSteer, WithCompactor, WithToolErrorLimits + ErrToolLoop); health Bench/Unbench/Snapshot; ChainConfig.Observer failover events. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
+114
-10
@@ -27,6 +27,11 @@ const DefaultMaxSteps = 10
|
||||
// carrying the transcript so far.
|
||||
var ErrMaxSteps = errors.New("agent: max steps reached without a final answer")
|
||||
|
||||
// ErrToolLoop reports that the loop tripped a tool-error guard
|
||||
// (consecutive all-error steps or identical repeated calls; see
|
||||
// WithToolErrorLimits). Run returns it alongside the partial *Result.
|
||||
var ErrToolLoop = errors.New("agent: tool-error guard tripped")
|
||||
|
||||
// Skill is the contract skills satisfy (defined here so agent does not
|
||||
// depend on the skill package; package skill provides implementations).
|
||||
// Instructions are appended to the agent's system prompt; Tools (optional,
|
||||
@@ -67,13 +72,17 @@ type Result struct {
|
||||
// it later. Agents are safe to share across goroutines only after
|
||||
// configuration is complete.
|
||||
type Agent struct {
|
||||
model llm.Model
|
||||
system string
|
||||
toolboxes []*llm.Toolbox
|
||||
skills []Skill
|
||||
maxSteps int
|
||||
reqOpts []llm.Option
|
||||
observers []func(Step)
|
||||
model llm.Model
|
||||
system string
|
||||
toolboxes []*llm.Toolbox
|
||||
skills []Skill
|
||||
maxSteps int
|
||||
maxStepsFunc func() int
|
||||
compactor func(ctx context.Context, msgs []llm.Message) ([]llm.Message, error)
|
||||
maxConsecutiveToolErrors int
|
||||
maxSameCallRepeats int
|
||||
reqOpts []llm.Option
|
||||
observers []func(Step)
|
||||
}
|
||||
|
||||
// Option configures an Agent at construction.
|
||||
@@ -99,6 +108,34 @@ func WithMaxSteps(n int) Option {
|
||||
return func(a *Agent) { a.maxSteps = n }
|
||||
}
|
||||
|
||||
// WithMaxStepsFunc makes the step ceiling dynamic: the function is
|
||||
// consulted before every step, so a supervisor can extend (or shrink) a
|
||||
// running agent's budget. It overrides WithMaxSteps while non-nil; a
|
||||
// non-positive return falls back to the static value.
|
||||
func WithMaxStepsFunc(fn func() int) Option {
|
||||
return func(a *Agent) { a.maxStepsFunc = fn }
|
||||
}
|
||||
|
||||
// WithCompactor installs a context-compaction hook, called with the full
|
||||
// message slice before every model call; whatever it returns is sent
|
||||
// instead (e.g. summarize the middle of a long transcript). A compactor
|
||||
// error is non-fatal: the original messages are used.
|
||||
func WithCompactor(fn func(ctx context.Context, msgs []llm.Message) ([]llm.Message, error)) Option {
|
||||
return func(a *Agent) { a.compactor = fn }
|
||||
}
|
||||
|
||||
// WithToolErrorLimits installs loop guards: maxConsecutiveErrors bounds
|
||||
// successive steps whose tool results were ALL errors, and
|
||||
// maxSameCallRepeats bounds identical (name + arguments) tool calls within
|
||||
// one run. Either guard tripping ends the run with ErrToolLoop and the
|
||||
// partial result. Zero disables a guard.
|
||||
func WithToolErrorLimits(maxConsecutiveErrors, maxSameCallRepeats int) Option {
|
||||
return func(a *Agent) {
|
||||
a.maxConsecutiveToolErrors = maxConsecutiveErrors
|
||||
a.maxSameCallRepeats = maxSameCallRepeats
|
||||
}
|
||||
}
|
||||
|
||||
// WithRequestOptions sets default request options (temperature, max
|
||||
// tokens, ...) applied to every step of every run.
|
||||
func WithRequestOptions(opts ...llm.Option) Option {
|
||||
@@ -134,6 +171,7 @@ type runConfig struct {
|
||||
history []llm.Message
|
||||
reqOpts []llm.Option
|
||||
onStep []func(Step)
|
||||
steer func() []llm.Message
|
||||
}
|
||||
|
||||
// WithHistory seeds the run with prior conversation messages (e.g. a
|
||||
@@ -153,6 +191,15 @@ func OnStep(fn func(Step)) RunOption {
|
||||
return func(rc *runConfig) { rc.onStep = append(rc.onStep, fn) }
|
||||
}
|
||||
|
||||
// WithSteer installs a steering source for this run: the function is
|
||||
// drained before every step and any returned messages are appended to the
|
||||
// conversation — the mechanism for a supervisor nudging a running agent
|
||||
// ("wrap up", "focus on X"). It is called from Run's goroutine; the
|
||||
// function owns its own synchronization.
|
||||
func WithSteer(fn func() []llm.Message) RunOption {
|
||||
return func(rc *runConfig) { rc.steer = fn }
|
||||
}
|
||||
|
||||
// systemPrompt composes the agent's system prompt with each skill's
|
||||
// instructions, in attachment order.
|
||||
func (a *Agent) systemPrompt() string {
|
||||
@@ -227,8 +274,34 @@ func (a *Agent) Run(ctx context.Context, input string, opts ...RunOption) (*Resu
|
||||
reqOpts := append(append([]llm.Option(nil), a.reqOpts...), rc.reqOpts...)
|
||||
system := a.systemPrompt()
|
||||
|
||||
for stepIdx := range a.maxSteps {
|
||||
req := llm.Request{System: system, Messages: msgs, Tools: ordered}
|
||||
// Loop-guard state (WithToolErrorLimits).
|
||||
consecutiveErrorSteps := 0
|
||||
callCounts := make(map[string]int)
|
||||
|
||||
maxSteps := func() int {
|
||||
if a.maxStepsFunc != nil {
|
||||
if n := a.maxStepsFunc(); n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return a.maxSteps
|
||||
}
|
||||
|
||||
for stepIdx := 0; stepIdx < maxSteps(); stepIdx++ {
|
||||
// Steering: drain supervisor nudges into the conversation.
|
||||
if rc.steer != nil {
|
||||
msgs = append(msgs, rc.steer()...)
|
||||
}
|
||||
|
||||
sendMsgs := msgs
|
||||
if a.compactor != nil {
|
||||
// Compaction failures are non-fatal: send the original.
|
||||
if compacted, err := a.compactor(ctx, msgs); err == nil && compacted != nil {
|
||||
sendMsgs = compacted
|
||||
}
|
||||
}
|
||||
|
||||
req := llm.Request{System: system, Messages: sendMsgs, Tools: ordered}
|
||||
resp, err := a.model.Generate(ctx, req, reqOpts...)
|
||||
if err != nil {
|
||||
result.Messages = msgs
|
||||
@@ -249,11 +322,19 @@ func (a *Agent) Run(ctx context.Context, input string, opts ...RunOption) (*Resu
|
||||
}
|
||||
|
||||
results := make([]llm.ToolResult, 0, len(resp.ToolCalls))
|
||||
repeatTripped := ""
|
||||
for _, call := range resp.ToolCalls {
|
||||
if err := ctx.Err(); err != nil {
|
||||
result.Messages = msgs
|
||||
return result, err
|
||||
}
|
||||
if a.maxSameCallRepeats > 0 {
|
||||
sig := call.Name + "\x00" + string(call.Arguments)
|
||||
callCounts[sig]++
|
||||
if callCounts[sig] > a.maxSameCallRepeats {
|
||||
repeatTripped = call.Name
|
||||
}
|
||||
}
|
||||
tool, ok := byName[call.Name]
|
||||
if !ok {
|
||||
results = append(results, llm.ToolResult{
|
||||
@@ -272,10 +353,33 @@ func (a *Agent) Run(ctx context.Context, input string, opts ...RunOption) (*Resu
|
||||
result.Steps = append(result.Steps, step)
|
||||
a.notify(rc, step)
|
||||
msgs = append(msgs, llm.ToolResultsMessage(results...))
|
||||
|
||||
if repeatTripped != "" {
|
||||
result.Messages = msgs
|
||||
return result, fmt.Errorf("%w: %q called identically more than %d times",
|
||||
ErrToolLoop, repeatTripped, a.maxSameCallRepeats)
|
||||
}
|
||||
allErrors := len(results) > 0
|
||||
for _, r := range results {
|
||||
if !r.IsError {
|
||||
allErrors = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allErrors {
|
||||
consecutiveErrorSteps++
|
||||
if a.maxConsecutiveToolErrors > 0 && consecutiveErrorSteps >= a.maxConsecutiveToolErrors {
|
||||
result.Messages = msgs
|
||||
return result, fmt.Errorf("%w: %d consecutive steps with only failing tool calls",
|
||||
ErrToolLoop, consecutiveErrorSteps)
|
||||
}
|
||||
} else {
|
||||
consecutiveErrorSteps = 0
|
||||
}
|
||||
}
|
||||
|
||||
result.Messages = msgs
|
||||
return result, fmt.Errorf("%w (max %d)", ErrMaxSteps, a.maxSteps)
|
||||
return result, fmt.Errorf("%w (max %d)", ErrMaxSteps, maxSteps())
|
||||
}
|
||||
|
||||
// notify fans a step out to agent observers and run callbacks; observer
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake"
|
||||
)
|
||||
|
||||
// TestMaxStepsFuncExtendsBudget: a supervisor raising the ceiling mid-run
|
||||
// lets the loop continue past the static budget.
|
||||
func TestMaxStepsFuncExtendsBudget(t *testing.T) {
|
||||
fp := fake.New("fp")
|
||||
fp.Enqueue("test-model",
|
||||
toolCallReply("c1", "add", `{"a":1,"b":1}`),
|
||||
toolCallReply("c2", "add", `{"a":2,"b":2}`),
|
||||
toolCallReply("c3", "add", `{"a":3,"b":3}`),
|
||||
fake.Reply("done"),
|
||||
)
|
||||
|
||||
var ceiling atomic.Int64
|
||||
ceiling.Store(2)
|
||||
a := New(newModel(t, fp), "",
|
||||
WithToolbox(adderToolbox(t)),
|
||||
WithMaxSteps(2),
|
||||
WithMaxStepsFunc(func() int { return int(ceiling.Load()) }),
|
||||
WithStepObserver(func(s Step) {
|
||||
if s.Index == 1 {
|
||||
ceiling.Store(10) // the "critic" extends the budget
|
||||
}
|
||||
}),
|
||||
)
|
||||
res, err := a.Run(context.Background(), "go")
|
||||
if err != nil {
|
||||
t.Fatalf("Run: %v (budget should have been extended)", err)
|
||||
}
|
||||
if res.Output != "done" || len(res.Steps) != 4 {
|
||||
t.Errorf("output=%q steps=%d", res.Output, len(res.Steps))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSteerInjectsMessages: steering messages appear in the conversation
|
||||
// before the next model call.
|
||||
func TestSteerInjectsMessages(t *testing.T) {
|
||||
fp := fake.New("fp")
|
||||
fp.Enqueue("test-model",
|
||||
toolCallReply("c1", "add", `{"a":1,"b":1}`),
|
||||
fake.Reply("ok"),
|
||||
)
|
||||
|
||||
var pending []llm.Message
|
||||
pending = append(pending, llm.UserText("SUPERVISOR: wrap it up"))
|
||||
a := New(newModel(t, fp), "", WithToolbox(adderToolbox(t)))
|
||||
_, err := a.Run(context.Background(), "go", WithSteer(func() []llm.Message {
|
||||
out := pending
|
||||
pending = nil
|
||||
return out
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("Run: %v", err)
|
||||
}
|
||||
first := fp.Calls()[0].Request.Messages
|
||||
if len(first) != 2 || !strings.Contains(first[1].Text(), "SUPERVISOR") {
|
||||
t.Errorf("first call messages = %+v, want steered message", first)
|
||||
}
|
||||
// Drained: second call must not duplicate it.
|
||||
second := fp.Calls()[1].Request.Messages
|
||||
count := 0
|
||||
for _, m := range second {
|
||||
if strings.Contains(m.Text(), "SUPERVISOR") {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("steer message appears %d times in second call, want 1", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactorShrinksOutboundContext: the model sees the compacted view;
|
||||
// the canonical transcript keeps everything.
|
||||
func TestCompactorShrinksOutboundContext(t *testing.T) {
|
||||
fp := fake.New("fp")
|
||||
fp.Enqueue("test-model", fake.Reply("answer"))
|
||||
|
||||
history := []llm.Message{
|
||||
llm.UserText("old 1"), llm.AssistantText("old reply 1"),
|
||||
llm.UserText("old 2"), llm.AssistantText("old reply 2"),
|
||||
}
|
||||
a := New(newModel(t, fp), "", WithCompactor(func(_ context.Context, msgs []llm.Message) ([]llm.Message, error) {
|
||||
// Keep only the last message, prefixed by a synthetic summary.
|
||||
return append([]llm.Message{llm.UserText("[summary of earlier conversation]")}, msgs[len(msgs)-1]), nil
|
||||
}))
|
||||
res, err := a.Run(context.Background(), "new question", WithHistory(history))
|
||||
if err != nil {
|
||||
t.Fatalf("Run: %v", err)
|
||||
}
|
||||
sent := fp.Calls()[0].Request.Messages
|
||||
if len(sent) != 2 || !strings.Contains(sent[0].Text(), "summary") {
|
||||
t.Errorf("sent = %+v, want compacted view", sent)
|
||||
}
|
||||
if len(res.Messages) != 6 {
|
||||
t.Errorf("transcript = %d messages, want full uncompacted history", len(res.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactorErrorIsNonFatal: a failing compactor falls back to the
|
||||
// original messages.
|
||||
func TestCompactorErrorIsNonFatal(t *testing.T) {
|
||||
fp := fake.New("fp")
|
||||
fp.Enqueue("test-model", fake.Reply("fine"))
|
||||
|
||||
a := New(newModel(t, fp), "", WithCompactor(func(context.Context, []llm.Message) ([]llm.Message, error) {
|
||||
return nil, errors.New("summarizer down")
|
||||
}))
|
||||
res, err := a.Run(context.Background(), "go")
|
||||
if err != nil || res.Output != "fine" {
|
||||
t.Errorf("res=%v err=%v", res, err)
|
||||
}
|
||||
if len(fp.Calls()[0].Request.Messages) != 1 {
|
||||
t.Error("original messages must be sent when compaction fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsecutiveToolErrorGuard: steps whose tools ALL fail trip the guard.
|
||||
func TestConsecutiveToolErrorGuard(t *testing.T) {
|
||||
fp := fake.New("fp", fake.WithDefault(func(string, llm.Request) fake.Step {
|
||||
return toolCallReply("c", "bomb", `{}`)
|
||||
}))
|
||||
bomb := llm.NewToolbox("danger", llm.Tool{
|
||||
Name: "bomb",
|
||||
Handler: func(context.Context, json.RawMessage) (any, error) { return nil, errors.New("always fails") },
|
||||
})
|
||||
|
||||
a := New(newModel(t, fp), "", WithToolbox(bomb), WithToolErrorLimits(2, 0), WithMaxSteps(10))
|
||||
res, err := a.Run(context.Background(), "go")
|
||||
if !errors.Is(err, ErrToolLoop) {
|
||||
t.Fatalf("err = %v, want ErrToolLoop", err)
|
||||
}
|
||||
if len(res.Steps) != 2 {
|
||||
t.Errorf("steps = %d, want guard to trip after 2", len(res.Steps))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSameCallRepeatGuard: identical (name+args) calls beyond the limit
|
||||
// trip the guard; varied calls do not.
|
||||
func TestSameCallRepeatGuard(t *testing.T) {
|
||||
fp := fake.New("fp", fake.WithDefault(func(string, llm.Request) fake.Step {
|
||||
return toolCallReply("c", "add", `{"a":1,"b":1}`)
|
||||
}))
|
||||
|
||||
a := New(newModel(t, fp), "", WithToolbox(adderToolbox(t)), WithToolErrorLimits(0, 3), WithMaxSteps(10))
|
||||
_, err := a.Run(context.Background(), "go")
|
||||
if !errors.Is(err, ErrToolLoop) || !strings.Contains(err.Error(), `"add"`) {
|
||||
t.Fatalf("err = %v, want repeat-guard ErrToolLoop naming add", err)
|
||||
}
|
||||
|
||||
// Varied arguments never trip it.
|
||||
n := 0
|
||||
fp2 := fake.New("fp", fake.WithDefault(func(string, llm.Request) fake.Step {
|
||||
n++
|
||||
if n > 4 {
|
||||
return fake.Reply("done")
|
||||
}
|
||||
return toolCallReply("c", "add", `{"a":1,"b":`+string(rune('0'+n))+`}`)
|
||||
}))
|
||||
a2 := New(newModel(t, fp2), "", WithToolbox(adderToolbox(t)), WithToolErrorLimits(0, 3), WithMaxSteps(10))
|
||||
if _, err := a2.Run(context.Background(), "go"); err != nil {
|
||||
t.Errorf("varied calls must not trip the guard: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -91,10 +91,17 @@ func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func
|
||||
var zero T
|
||||
var failures []error
|
||||
|
||||
observe := func(ev FailoverEvent) {
|
||||
if c.cfg.Observer != nil {
|
||||
c.cfg.Observer(ev)
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range c.targets {
|
||||
if !c.tracker.Available(t.key) {
|
||||
until := c.tracker.BackedOffUntil(t.key)
|
||||
failures = append(failures, fmt.Errorf("%s: skipped (backed off until %s)", t.key, until.Format("15:04:05.000")))
|
||||
observe(FailoverEvent{Target: t.key, Skipped: true})
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -119,6 +126,7 @@ func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func
|
||||
|
||||
class := c.cfg.classify(err)
|
||||
if class == llm.ClassPermanent {
|
||||
observe(FailoverEvent{Target: t.key, Err: err, Class: class, Attempt: attemptN})
|
||||
if errors.Is(err, llm.ErrModelNotFound) || errors.Is(err, llm.ErrUnsupported) || c.cfg.AdvanceOnPermanent {
|
||||
// Not a health problem (or policy says keep going):
|
||||
// advance without penalizing the target.
|
||||
@@ -134,6 +142,7 @@ func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func
|
||||
// attempts remain — but advance as soon as the tracker benches
|
||||
// it (a freshly backed-off target is not worth more retries).
|
||||
benched := c.tracker.ReportFailure(t.key)
|
||||
observe(FailoverEvent{Target: t.key, Err: err, Class: class, Attempt: attemptN, Benched: benched})
|
||||
if !benched && attemptN < retries {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# ADR-0014: Conversion-driven extensions (resolvers, typed tools, hooks, ops controls)
|
||||
|
||||
**Status:** Accepted — 2026-06-10
|
||||
|
||||
## Context
|
||||
|
||||
Executing the mort conversion (docs/mort-migration.md) surfaced seven
|
||||
capabilities mort's orchestration layer needs from its model substrate.
|
||||
Each is general-purpose — none encodes anything mort-specific — and each
|
||||
was promised in the migration blueprint.
|
||||
|
||||
## Decision
|
||||
|
||||
1. **`Registry.RegisterResolver`** — dynamic alias resolution for DB/
|
||||
config-backed tiers. Checked after static aliases (static wins), in
|
||||
registration order, without holding the registry lock (resolvers do
|
||||
I/O); output expands recursively under the existing cycle guard.
|
||||
2. **`DefineTool[Args]`** — typed tools: parameters from `SchemaFor[Args]`,
|
||||
arguments unmarshaled before the handler. Schema failure panics
|
||||
(startup-time programming error, mirroring NewToolbox).
|
||||
3. **`Usage` detail fields** — CacheReadTokens / CacheWriteTokens /
|
||||
ReasoningTokens, populated where providers report them (Anthropic cache
|
||||
fields; OpenAI prompt/completion detail objects; Google cached-content
|
||||
+ thoughts). Input/Output remain totals; details are portions.
|
||||
4. **`WithPromptCaching`** — Anthropic top-level `cache_control`
|
||||
auto-placement; a no-op for providers that cache implicitly or not at
|
||||
all.
|
||||
5. **Agent loop hooks** — `WithMaxStepsFunc` (dynamic ceiling, consulted
|
||||
every step, for supervisor budget adjustment), `WithSteer` (per-run
|
||||
message injection drained before each step), `WithCompactor`
|
||||
(pre-Generate transcript transformation; errors fall back to the
|
||||
original — losing a request to a broken summarizer is worse than a long
|
||||
prompt), `WithToolErrorLimits` (consecutive all-error steps and
|
||||
identical-call repeats end the run with `ErrToolLoop` + partial result).
|
||||
The canonical transcript in `Result.Messages` is always the
|
||||
uncompacted truth; compaction affects only what is sent.
|
||||
6. **Manual health controls** — `Tracker.Bench/Unbench/Snapshot` for
|
||||
ops surfaces (".failover bench" commands, dashboards).
|
||||
7. **`ChainConfig.Observer`** — one synchronous callback per failover
|
||||
decision (failed attempt with classification, bench, benched-skip).
|
||||
This is a hook, not an observability stack; persistence/metrics remain
|
||||
the consumer's business (anti-creep guardrail intact).
|
||||
|
||||
## Consequences
|
||||
|
||||
- mort's tier resolver, `.failover` admin surface, failover-event log,
|
||||
run-critic budget control, steering, and compaction all rebase onto
|
||||
public API.
|
||||
- The agent loop gains supervision points without changing its
|
||||
never-panic, partial-result-on-error contract.
|
||||
@@ -17,3 +17,4 @@ One decision per file, append-only; supersede rather than rewrite.
|
||||
| [0011](0011-google-provider.md) | Google provider on the official Gen AI SDK | Accepted |
|
||||
| [0012](0012-agent-loop.md) | Agent run loop | Accepted |
|
||||
| [0013](0013-skill-model.md) | Skill model — additive instruction+tool bundles | Accepted |
|
||||
| [0014](0014-conversion-driven-extensions.md) | Conversion-driven extensions (resolvers, typed tools, hooks, ops controls) | Accepted |
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
package majordomo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake"
|
||||
)
|
||||
|
||||
// TestRegisterResolver: dynamic tiers resolve after static aliases, expand
|
||||
// recursively, and respect cycle detection.
|
||||
func TestRegisterResolver(t *testing.T) {
|
||||
r := newTestRegistry(t)
|
||||
r.RegisterProvider(fake.New("fp"))
|
||||
r.RegisterAlias("static", "fp/a")
|
||||
|
||||
tiers := map[string]string{
|
||||
"db-tier": "fp/x,fp/y",
|
||||
"db-nested": "db-tier,fp/z",
|
||||
"db-cycle": "db-cycle",
|
||||
}
|
||||
r.RegisterResolver(ResolverFunc(func(name string) (string, bool) {
|
||||
spec, ok := tiers[name]
|
||||
return spec, ok
|
||||
}))
|
||||
|
||||
m, err := r.Parse("static,db-nested")
|
||||
if err != nil {
|
||||
t.Fatalf("Parse: %v", err)
|
||||
}
|
||||
want := []string{"fp/a", "fp/x", "fp/y", "fp/z"}
|
||||
if got := targetsOf(t, m); strings.Join(got, ",") != strings.Join(want, ",") {
|
||||
t.Errorf("targets = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
if _, err := r.Parse("db-cycle"); !errors.Is(err, ErrAliasCycle) {
|
||||
t.Errorf("cycle error = %v, want ErrAliasCycle", err)
|
||||
}
|
||||
|
||||
// Static aliases shadow resolvers.
|
||||
tiers["static"] = "fp/wrong"
|
||||
m, _ = r.Parse("static")
|
||||
if got := targetsOf(t, m); got[0] != "fp/a" {
|
||||
t.Errorf("static alias must win over resolver, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefineTool: schema from Args, decoded handler arguments.
|
||||
func TestDefineTool(t *testing.T) {
|
||||
type addArgs struct {
|
||||
A int `json:"a" description:"first addend"`
|
||||
B int `json:"b"`
|
||||
}
|
||||
tool := DefineTool("add", "Add two ints", func(_ context.Context, args addArgs) (any, error) {
|
||||
return args.A + args.B, nil
|
||||
})
|
||||
|
||||
var schema map[string]any
|
||||
if err := json.Unmarshal(tool.Parameters, &schema); err != nil {
|
||||
t.Fatalf("schema: %v", err)
|
||||
}
|
||||
props := schema["properties"].(map[string]any)
|
||||
if props["a"].(map[string]any)["description"] != "first addend" {
|
||||
t.Errorf("schema = %v", schema)
|
||||
}
|
||||
|
||||
res := llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "1", Name: "add", Arguments: json.RawMessage(`{"a":2,"b":40}`)})
|
||||
if res.IsError || res.Content != "42" {
|
||||
t.Errorf("result = %+v", res)
|
||||
}
|
||||
res = llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "2", Name: "add", Arguments: json.RawMessage(`{"a":"nope"}`)})
|
||||
if !res.IsError || !strings.Contains(res.Content, "invalid arguments") {
|
||||
t.Errorf("bad-args result = %+v", res)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChainObserver: failover decisions emit events (attempt, bench, skip).
|
||||
func TestChainObserver(t *testing.T) {
|
||||
var events []FailoverEvent
|
||||
r := newTestRegistry(t, WithChainConfig(ChainConfig{
|
||||
Observer: func(ev FailoverEvent) { events = append(events, ev) },
|
||||
}))
|
||||
fp := fake.New("fp")
|
||||
r.RegisterProvider(fp)
|
||||
fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")))
|
||||
fp.Enqueue("b", fake.Reply("ok"), fake.Reply("ok"))
|
||||
|
||||
m, _ := r.Parse("fp/a,fp/b")
|
||||
if _, err := generate(t, m); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("events = %+v, want 2 failed attempts", events)
|
||||
}
|
||||
if events[0].Target != "fp/a" || events[0].Attempt != 0 || events[0].Benched {
|
||||
t.Errorf("event 0 = %+v", events[0])
|
||||
}
|
||||
if !events[1].Benched {
|
||||
t.Errorf("event 1 = %+v, want Benched", events[1])
|
||||
}
|
||||
|
||||
// Second request: skipped-while-benched event.
|
||||
events = nil
|
||||
if _, err := generate(t, m); err != nil {
|
||||
t.Fatalf("Generate #2: %v", err)
|
||||
}
|
||||
if len(events) != 1 || !events[0].Skipped || events[0].Target != "fp/a" {
|
||||
t.Errorf("events = %+v, want one skip event", events)
|
||||
}
|
||||
}
|
||||
|
||||
// TestManualBenchControls: ops surfaces can bench/unbench/inspect.
|
||||
func TestManualBenchControls(t *testing.T) {
|
||||
clock := newFakeClock()
|
||||
r := newTestRegistry(t, WithClock(clock.Now))
|
||||
fp := fake.New("fp")
|
||||
r.RegisterProvider(fp)
|
||||
fp.Enqueue("b", fake.Reply("from-b"))
|
||||
|
||||
r.Health().Bench("fp/a", clock.Now().Add(time.Hour))
|
||||
if r.Health().Available("fp/a") {
|
||||
t.Fatal("manual bench must take effect")
|
||||
}
|
||||
snap := r.Health().Snapshot()
|
||||
if len(snap) != 1 || snap[0].Key != "fp/a" {
|
||||
t.Errorf("snapshot = %+v", snap)
|
||||
}
|
||||
|
||||
m, _ := r.Parse("fp/a,fp/b")
|
||||
resp, err := generate(t, m)
|
||||
if err != nil || resp.Text() != "from-b" {
|
||||
t.Fatalf("resp=%v err=%v (benched head must be skipped)", resp, err)
|
||||
}
|
||||
|
||||
r.Health().Unbench("fp/a")
|
||||
if !r.Health().Available("fp/a") {
|
||||
t.Error("unbench must re-admit")
|
||||
}
|
||||
if len(r.Health().Snapshot()) != 0 {
|
||||
t.Error("snapshot must be empty after unbench")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPromptCachingOptionIsCarried: the request flag round-trips (the
|
||||
// anthropic wire mapping is asserted in its own package).
|
||||
func TestPromptCachingOptionIsCarried(t *testing.T) {
|
||||
r := newTestRegistry(t)
|
||||
fp := fake.New("fp")
|
||||
r.RegisterProvider(fp)
|
||||
|
||||
m, _ := r.Parse("fp/x")
|
||||
if _, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}, WithPromptCaching()); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if !fp.Calls()[0].Request.PromptCache {
|
||||
t.Error("PromptCache flag must reach the provider")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUsageDetailAccumulation: Usage.Add sums the detail fields.
|
||||
func TestUsageDetailAccumulation(t *testing.T) {
|
||||
u := Usage{InputTokens: 10, OutputTokens: 5, CacheReadTokens: 4, CacheWriteTokens: 2, ReasoningTokens: 3}
|
||||
u.Add(Usage{InputTokens: 1, OutputTokens: 1, CacheReadTokens: 1, CacheWriteTokens: 1, ReasoningTokens: 1})
|
||||
if u.CacheReadTokens != 5 || u.CacheWriteTokens != 3 || u.ReasoningTokens != 4 {
|
||||
t.Errorf("usage = %+v", u)
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,12 @@ import (
|
||||
// suitable for WithSchema and tool parameters.
|
||||
func SchemaFor[T any]() (json.RawMessage, error) { return llm.SchemaFor[T]() }
|
||||
|
||||
// DefineTool re-exports llm.DefineTool: a typed tool whose parameter schema
|
||||
// is derived from Args and whose handler receives decoded arguments.
|
||||
func DefineTool[Args any](name, description string, fn func(ctx context.Context, args Args) (any, error)) Tool {
|
||||
return llm.DefineTool(name, description, fn)
|
||||
}
|
||||
|
||||
// Generate performs a structured-output request and unmarshals the result
|
||||
// into T: the schema is derived from T (llm.SchemaFor), injected via the
|
||||
// provider's native structured-output mechanism, and the response text is
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -133,6 +135,57 @@ func (t *Tracker) ReportFailure(key string) (backedOff bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Bench manually benches a key until the given time, regardless of its
|
||||
// failure history (ops surfaces: ".failover bench"). A zero time unbenches.
|
||||
func (t *Tracker) Bench(key string, until time.Time) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
e, ok := t.entries[key]
|
||||
if !ok {
|
||||
e = &entry{}
|
||||
t.entries[key] = e
|
||||
}
|
||||
e.until = until
|
||||
}
|
||||
|
||||
// Unbench clears a key entirely: bench window, failure count, and backoff
|
||||
// history.
|
||||
func (t *Tracker) Unbench(key string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
delete(t.entries, key)
|
||||
}
|
||||
|
||||
// Status describes one tracked key (diagnostics).
|
||||
type Status struct {
|
||||
Key string
|
||||
// Until is the end of the current bench window (zero = not benched).
|
||||
Until time.Time
|
||||
// ConsecutiveFailures since the last success or bench trigger.
|
||||
ConsecutiveFailures int
|
||||
// Benches is the consecutive backoff count driving the exponent.
|
||||
Benches int
|
||||
}
|
||||
|
||||
// Snapshot returns the status of every currently-benched key, sorted by
|
||||
// key, evaluated at the tracker's clock.
|
||||
func (t *Tracker) Snapshot() []Status {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
now := t.cfg.Clock()
|
||||
out := make([]Status, 0, len(t.entries))
|
||||
for key, e := range t.entries {
|
||||
if now.Before(e.until) {
|
||||
out = append(out, Status{
|
||||
Key: key, Until: e.until,
|
||||
ConsecutiveFailures: e.consecutiveFailures, Benches: e.backoffs,
|
||||
})
|
||||
}
|
||||
}
|
||||
slices.SortFunc(out, func(a, b Status) int { return strings.Compare(a.Key, b.Key) })
|
||||
return out
|
||||
}
|
||||
|
||||
// BackedOffUntil returns the end of the key's current backoff window, or the
|
||||
// zero time when the key is not backed off. Useful for diagnostics and error
|
||||
// messages.
|
||||
|
||||
@@ -43,6 +43,11 @@ type Request struct {
|
||||
// Providers map it to their native knob (OpenAI reasoning_effort,
|
||||
// Ollama think levels) and ignore it where no mapping exists.
|
||||
ReasoningEffort string
|
||||
|
||||
// PromptCache opts the request into the provider's prompt caching
|
||||
// (Anthropic cache_control; ignored by providers that cache
|
||||
// automatically or not at all).
|
||||
PromptCache bool
|
||||
}
|
||||
|
||||
// Option mutates a Request before it is sent. Options passed to Generate or
|
||||
@@ -100,6 +105,12 @@ func WithReasoningEffort(level string) Option {
|
||||
return func(r *Request) { r.ReasoningEffort = level }
|
||||
}
|
||||
|
||||
// WithPromptCaching opts into provider prompt caching where it is an
|
||||
// explicit feature (Anthropic); a no-op elsewhere.
|
||||
func WithPromptCaching() Option {
|
||||
return func(r *Request) { r.PromptCache = true }
|
||||
}
|
||||
|
||||
// Apply returns a copy of the request with all options applied. Providers
|
||||
// and wrappers call this once at the top of Generate/Stream.
|
||||
func (r Request) Apply(opts ...Option) Request {
|
||||
|
||||
+16
-1
@@ -18,10 +18,22 @@ const (
|
||||
FinishOther FinishReason = "other"
|
||||
)
|
||||
|
||||
// Usage reports token accounting for one request.
|
||||
// Usage reports token accounting for one request. InputTokens and
|
||||
// OutputTokens are always totals; the detail fields break out portions of
|
||||
// those totals where the provider reports them (0 = not reported).
|
||||
type Usage struct {
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
|
||||
// CacheReadTokens is the portion of InputTokens served from the
|
||||
// provider's prompt cache.
|
||||
CacheReadTokens int
|
||||
// CacheWriteTokens is the portion of InputTokens written to the
|
||||
// provider's prompt cache.
|
||||
CacheWriteTokens int
|
||||
// ReasoningTokens is the portion of OutputTokens spent on
|
||||
// thinking/reasoning.
|
||||
ReasoningTokens int
|
||||
}
|
||||
|
||||
// Total returns input plus output tokens.
|
||||
@@ -31,6 +43,9 @@ func (u Usage) Total() int { return u.InputTokens + u.OutputTokens }
|
||||
func (u *Usage) Add(o Usage) {
|
||||
u.InputTokens += o.InputTokens
|
||||
u.OutputTokens += o.OutputTokens
|
||||
u.CacheReadTokens += o.CacheReadTokens
|
||||
u.CacheWriteTokens += o.CacheWriteTokens
|
||||
u.ReasoningTokens += o.ReasoningTokens
|
||||
}
|
||||
|
||||
// Response is the canonical generation result.
|
||||
|
||||
+34
@@ -22,6 +22,40 @@ type Tool struct {
|
||||
Handler func(ctx context.Context, args json.RawMessage) (any, error)
|
||||
}
|
||||
|
||||
// DefineTool builds a typed tool: the parameter schema is derived from
|
||||
// Args (see SchemaFor) and the raw JSON arguments are unmarshaled into an
|
||||
// Args value before the handler runs.
|
||||
//
|
||||
// weather := llm.DefineTool("get_weather", "Current weather for a city",
|
||||
// func(ctx context.Context, args struct {
|
||||
// City string `json:"city" description:"city name"`
|
||||
// }) (any, error) {
|
||||
// return lookup(args.City)
|
||||
// })
|
||||
//
|
||||
// Schema derivation failures panic: tools are defined at startup and an
|
||||
// unschematizable Args type is a programming error worth failing loudly on.
|
||||
func DefineTool[Args any](name, description string, fn func(ctx context.Context, args Args) (any, error)) Tool {
|
||||
schema, err := SchemaFor[Args]()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("llm: DefineTool(%q): %v", name, err))
|
||||
}
|
||||
return Tool{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Parameters: schema,
|
||||
Handler: func(ctx context.Context, raw json.RawMessage) (any, error) {
|
||||
var args Args
|
||||
if len(raw) > 0 {
|
||||
if err := json.Unmarshal(raw, &args); err != nil {
|
||||
return nil, fmt.Errorf("invalid arguments for %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
return fn(ctx, args)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToolCall is a model's request to invoke a tool.
|
||||
type ToolCall struct {
|
||||
// ID is the provider-assigned call id; majordomo synthesizes one for
|
||||
|
||||
@@ -100,6 +100,7 @@ func WithTopP(p float64) Option { return llm.WithTop
|
||||
func WithMaxTokens(n int) Option { return llm.WithMaxTokens(n) }
|
||||
func WithStopSequences(stops ...string) Option { return llm.WithStopSequences(stops...) }
|
||||
func WithReasoningEffort(level string) Option { return llm.WithReasoningEffort(level) }
|
||||
func WithPromptCaching() Option { return llm.WithPromptCaching() }
|
||||
|
||||
// WithModelCapabilities re-exports llm.WithCapabilities for Provider.Model
|
||||
// calls made through this package.
|
||||
|
||||
@@ -97,12 +97,23 @@ func (r *Registry) expand(spec string, visiting []string) ([]element, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Bare token: must be a registered alias.
|
||||
// Bare token: a registered alias, or a dynamic resolver hit.
|
||||
r.mu.RLock()
|
||||
target, isAlias := r.aliases[raw]
|
||||
_, isProvider := r.providers[raw]
|
||||
resolvers := slices.Clone(r.resolvers)
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !isAlias {
|
||||
// Resolvers run without the lock — they may call back into the
|
||||
// registry (and DB-backed ones block on I/O).
|
||||
for _, res := range resolvers {
|
||||
if spec, ok := res.Resolve(raw); ok && spec != "" {
|
||||
target, isAlias = spec, true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !isAlias {
|
||||
if isProvider {
|
||||
return nil, fmt.Errorf("%q is a provider, not an alias — use %q", raw, raw+"/<model-id>")
|
||||
|
||||
+14
@@ -1,5 +1,19 @@
|
||||
# progress
|
||||
|
||||
## 2026-06-10 — Phase 9a: conversion-driven library extensions
|
||||
|
||||
**Landed (ADR-0014):** RegisterResolver (dynamic DB-backed tiers, static
|
||||
aliases win, recursive + cycle-guarded), DefineTool[Args] (typed tools
|
||||
over SchemaFor), Usage cache/reasoning detail fields populated by
|
||||
anthropic/openai/google, WithPromptCaching (Anthropic top-level
|
||||
cache_control), agent hooks (WithMaxStepsFunc, WithSteer, WithCompactor —
|
||||
non-fatal on error, canonical transcript stays uncompacted —
|
||||
WithToolErrorLimits with ErrToolLoop), health Bench/Unbench/Snapshot,
|
||||
ChainConfig.Observer failover events (attempt/bench/skip). Full hermetic
|
||||
coverage for each.
|
||||
|
||||
**Next:** Phase 9b — the mort conversion branch.
|
||||
|
||||
## 2026-06-10 — Phase 8: live validation against real Ollama Cloud
|
||||
|
||||
**All six checks PASS** (examples/live harness, OLLAMA_API_KEY from .env):
|
||||
|
||||
@@ -24,6 +24,13 @@ type wireRequest struct {
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
OutputConfig *wireOutputConfig `json:"output_config,omitempty"`
|
||||
// CacheControl is the top-level auto-placement form of prompt caching:
|
||||
// the API puts the breakpoint on the last cacheable block.
|
||||
CacheControl *wireCacheControl `json:"cache_control,omitempty"`
|
||||
}
|
||||
|
||||
type wireCacheControl struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type wireMessage struct {
|
||||
@@ -109,8 +116,10 @@ type wireUsage struct {
|
||||
// real total input is input + cache_creation + cache_read.
|
||||
func (u wireUsage) toUsage() llm.Usage {
|
||||
return llm.Usage{
|
||||
InputTokens: u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens,
|
||||
OutputTokens: u.OutputTokens,
|
||||
InputTokens: u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens,
|
||||
OutputTokens: u.OutputTokens,
|
||||
CacheReadTokens: u.CacheReadInputTokens,
|
||||
CacheWriteTokens: u.CacheCreationInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,6 +166,11 @@ func buildWireRequest(modelID string, req llm.Request, defaultMax int, stream bo
|
||||
Schema: req.Schema,
|
||||
}}
|
||||
}
|
||||
if req.PromptCache {
|
||||
// Top-level auto-placement: the API puts the cache breakpoint on
|
||||
// the last cacheable block.
|
||||
wr.CacheControl = &wireCacheControl{Type: "ephemeral"}
|
||||
}
|
||||
return wr
|
||||
}
|
||||
|
||||
|
||||
@@ -364,8 +364,10 @@ func (m *model) toResponse(resp *genai.GenerateContentResponse) *llm.Response {
|
||||
out := &llm.Response{Model: m.qualified(), Raw: resp}
|
||||
if resp.UsageMetadata != nil {
|
||||
out.Usage = llm.Usage{
|
||||
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
|
||||
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount + resp.UsageMetadata.ThoughtsTokenCount),
|
||||
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
|
||||
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount + resp.UsageMetadata.ThoughtsTokenCount),
|
||||
CacheReadTokens: int(resp.UsageMetadata.CachedContentTokenCount),
|
||||
ReasoningTokens: int(resp.UsageMetadata.ThoughtsTokenCount),
|
||||
}
|
||||
}
|
||||
if len(resp.Candidates) == 0 {
|
||||
|
||||
@@ -78,8 +78,10 @@ func (s *stream) Next() (llm.StreamEvent, error) {
|
||||
|
||||
if chunk.UsageMetadata != nil {
|
||||
s.usage = llm.Usage{
|
||||
InputTokens: int(chunk.UsageMetadata.PromptTokenCount),
|
||||
OutputTokens: int(chunk.UsageMetadata.CandidatesTokenCount + chunk.UsageMetadata.ThoughtsTokenCount),
|
||||
InputTokens: int(chunk.UsageMetadata.PromptTokenCount),
|
||||
OutputTokens: int(chunk.UsageMetadata.CandidatesTokenCount + chunk.UsageMetadata.ThoughtsTokenCount),
|
||||
CacheReadTokens: int(chunk.UsageMetadata.CachedContentTokenCount),
|
||||
ReasoningTokens: int(chunk.UsageMetadata.ThoughtsTokenCount),
|
||||
}
|
||||
}
|
||||
if len(chunk.Candidates) == 0 {
|
||||
|
||||
@@ -130,10 +130,7 @@ func (m *model) apiError(httpResp *http.Response) error {
|
||||
func (m *model) toResponse(wire *chatResponse) *llm.Response {
|
||||
resp := &llm.Response{Model: m.p.name + "/" + m.id, Raw: wire}
|
||||
if wire.Usage != nil {
|
||||
resp.Usage = llm.Usage{
|
||||
InputTokens: wire.Usage.PromptTokens,
|
||||
OutputTokens: wire.Usage.CompletionTokens,
|
||||
}
|
||||
resp.Usage = wire.Usage.toUsage()
|
||||
}
|
||||
if len(wire.Choices) == 0 {
|
||||
resp.FinishReason = llm.FinishOther
|
||||
|
||||
@@ -104,10 +104,7 @@ func (s *stream) handleChunk(data []byte) error {
|
||||
return apiErr
|
||||
}
|
||||
if chunk.Usage != nil {
|
||||
s.usage = llm.Usage{
|
||||
InputTokens: chunk.Usage.PromptTokens,
|
||||
OutputTokens: chunk.Usage.CompletionTokens,
|
||||
}
|
||||
s.usage = chunk.Usage.toUsage()
|
||||
}
|
||||
// Why the guard: the include_usage chunk arrives with an EMPTY choices
|
||||
// array; indexing choices[0] unconditionally would panic on it.
|
||||
|
||||
+26
-3
@@ -125,9 +125,32 @@ type wireRespMessage struct {
|
||||
}
|
||||
|
||||
type wireUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *wirePromptDetail `json:"prompt_tokens_details"`
|
||||
CompletionTokensDetails *wireOutputDetail `json:"completion_tokens_details"`
|
||||
}
|
||||
|
||||
type wirePromptDetail struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
type wireOutputDetail struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
}
|
||||
|
||||
// toUsage maps wire usage (with optional detail objects — absent on many
|
||||
// compat servers) onto the canonical Usage.
|
||||
func (u *wireUsage) toUsage() llm.Usage {
|
||||
out := llm.Usage{InputTokens: u.PromptTokens, OutputTokens: u.CompletionTokens}
|
||||
if u.PromptTokensDetails != nil {
|
||||
out.CacheReadTokens = u.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if u.CompletionTokensDetails != nil {
|
||||
out.ReasoningTokens = u.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type errorEnvelope struct {
|
||||
|
||||
+47
@@ -18,6 +18,7 @@ type Registry struct {
|
||||
mu sync.RWMutex
|
||||
providers map[string]llm.Provider
|
||||
aliases map[string]string
|
||||
resolvers []Resolver
|
||||
schemes map[string]SchemeFactory
|
||||
// envErrs records LLM_* entries that failed to load so the failure
|
||||
// surfaces when (and only when) that provider name is actually used.
|
||||
@@ -28,6 +29,22 @@ type Registry struct {
|
||||
envLookup func(string) string
|
||||
}
|
||||
|
||||
// Resolver dynamically resolves a bare alias token to a spec string —
|
||||
// the hook for DB- or config-backed tiers that change at runtime. Checked
|
||||
// after static aliases, in registration order; the returned spec is
|
||||
// expanded recursively (chains, nested aliases) with cycle detection.
|
||||
type Resolver interface {
|
||||
// Resolve returns the spec for name, or ok=false when this resolver
|
||||
// does not handle it.
|
||||
Resolve(name string) (spec string, ok bool)
|
||||
}
|
||||
|
||||
// ResolverFunc adapts a function to the Resolver interface.
|
||||
type ResolverFunc func(name string) (string, bool)
|
||||
|
||||
// Resolve implements Resolver.
|
||||
func (f ResolverFunc) Resolve(name string) (string, bool) { return f(name) }
|
||||
|
||||
// SchemeFactory builds a provider instance from an env DSN. name is the
|
||||
// registry name the provider will be registered under (e.g. "m1" for
|
||||
// LLM_M1); dsn carries the scheme, credential, and host.
|
||||
@@ -49,6 +66,27 @@ type ChainConfig struct {
|
||||
|
||||
// Classify overrides the default error classifier (llm.Classify).
|
||||
Classify func(error) llm.ErrorClass
|
||||
|
||||
// Observer, when non-nil, receives one event per failover decision
|
||||
// (failed attempt, bench, benched-skip) — the hook for persisting
|
||||
// failover logs. Called synchronously; keep it fast or hand off.
|
||||
Observer func(FailoverEvent)
|
||||
}
|
||||
|
||||
// FailoverEvent describes one failover decision in a chain.
|
||||
type FailoverEvent struct {
|
||||
// Target is the "provider/model" key the event concerns.
|
||||
Target string
|
||||
// Err is the failure (nil for benched-skip events).
|
||||
Err error
|
||||
// Class is the error classification (meaningful when Err != nil).
|
||||
Class llm.ErrorClass
|
||||
// Attempt is the 0-based attempt number on this target in this request.
|
||||
Attempt int
|
||||
// Benched reports that this failure benched the target.
|
||||
Benched bool
|
||||
// Skipped reports the target was skipped because it is benched.
|
||||
Skipped bool
|
||||
}
|
||||
|
||||
// DefaultTransientRetries is the default number of same-target retries
|
||||
@@ -180,6 +218,15 @@ func (r *Registry) RegisterAlias(name, spec string) {
|
||||
r.aliases[name] = spec
|
||||
}
|
||||
|
||||
// RegisterResolver appends a dynamic alias resolver (e.g. database-backed
|
||||
// tiers). Resolvers are consulted in registration order, after static
|
||||
// aliases.
|
||||
func (r *Registry) RegisterResolver(res Resolver) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.resolvers = append(r.resolvers, res)
|
||||
}
|
||||
|
||||
// RegisterScheme adds or replaces an env-DSN scheme factory, letting
|
||||
// consumers wire custom provider kinds into LLM_* definitions.
|
||||
func (r *Registry) RegisterScheme(scheme string, factory SchemeFactory) {
|
||||
|
||||
Reference in New Issue
Block a user