0147a79d18
Phase 9a (ADR-0014): Registry.RegisterResolver for dynamic tiers; DefineTool[Args] typed tools; Usage cache/reasoning detail fields wired through anthropic/openai/google; WithPromptCaching (Anthropic cache_control); agent supervision hooks (WithMaxStepsFunc, WithSteer, WithCompactor, WithToolErrorLimits + ErrToolLoop); health Bench/Unbench/Snapshot; ChainConfig.Observer failover events. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
173 lines
5.3 KiB
Go
173 lines
5.3 KiB
Go
package majordomo
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
|
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake"
|
|
)
|
|
|
|
// TestRegisterResolver: dynamic tiers resolve after static aliases, expand
|
|
// recursively, and respect cycle detection.
|
|
func TestRegisterResolver(t *testing.T) {
|
|
r := newTestRegistry(t)
|
|
r.RegisterProvider(fake.New("fp"))
|
|
r.RegisterAlias("static", "fp/a")
|
|
|
|
tiers := map[string]string{
|
|
"db-tier": "fp/x,fp/y",
|
|
"db-nested": "db-tier,fp/z",
|
|
"db-cycle": "db-cycle",
|
|
}
|
|
r.RegisterResolver(ResolverFunc(func(name string) (string, bool) {
|
|
spec, ok := tiers[name]
|
|
return spec, ok
|
|
}))
|
|
|
|
m, err := r.Parse("static,db-nested")
|
|
if err != nil {
|
|
t.Fatalf("Parse: %v", err)
|
|
}
|
|
want := []string{"fp/a", "fp/x", "fp/y", "fp/z"}
|
|
if got := targetsOf(t, m); strings.Join(got, ",") != strings.Join(want, ",") {
|
|
t.Errorf("targets = %v, want %v", got, want)
|
|
}
|
|
|
|
if _, err := r.Parse("db-cycle"); !errors.Is(err, ErrAliasCycle) {
|
|
t.Errorf("cycle error = %v, want ErrAliasCycle", err)
|
|
}
|
|
|
|
// Static aliases shadow resolvers.
|
|
tiers["static"] = "fp/wrong"
|
|
m, _ = r.Parse("static")
|
|
if got := targetsOf(t, m); got[0] != "fp/a" {
|
|
t.Errorf("static alias must win over resolver, got %v", got)
|
|
}
|
|
}
|
|
|
|
// TestDefineTool: schema from Args, decoded handler arguments.
|
|
func TestDefineTool(t *testing.T) {
|
|
type addArgs struct {
|
|
A int `json:"a" description:"first addend"`
|
|
B int `json:"b"`
|
|
}
|
|
tool := DefineTool("add", "Add two ints", func(_ context.Context, args addArgs) (any, error) {
|
|
return args.A + args.B, nil
|
|
})
|
|
|
|
var schema map[string]any
|
|
if err := json.Unmarshal(tool.Parameters, &schema); err != nil {
|
|
t.Fatalf("schema: %v", err)
|
|
}
|
|
props := schema["properties"].(map[string]any)
|
|
if props["a"].(map[string]any)["description"] != "first addend" {
|
|
t.Errorf("schema = %v", schema)
|
|
}
|
|
|
|
res := llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "1", Name: "add", Arguments: json.RawMessage(`{"a":2,"b":40}`)})
|
|
if res.IsError || res.Content != "42" {
|
|
t.Errorf("result = %+v", res)
|
|
}
|
|
res = llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "2", Name: "add", Arguments: json.RawMessage(`{"a":"nope"}`)})
|
|
if !res.IsError || !strings.Contains(res.Content, "invalid arguments") {
|
|
t.Errorf("bad-args result = %+v", res)
|
|
}
|
|
}
|
|
|
|
// TestChainObserver: failover decisions emit events (attempt, bench, skip).
|
|
func TestChainObserver(t *testing.T) {
|
|
var events []FailoverEvent
|
|
r := newTestRegistry(t, WithChainConfig(ChainConfig{
|
|
Observer: func(ev FailoverEvent) { events = append(events, ev) },
|
|
}))
|
|
fp := fake.New("fp")
|
|
r.RegisterProvider(fp)
|
|
fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")))
|
|
fp.Enqueue("b", fake.Reply("ok"), fake.Reply("ok"))
|
|
|
|
m, _ := r.Parse("fp/a,fp/b")
|
|
if _, err := generate(t, m); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if len(events) != 2 {
|
|
t.Fatalf("events = %+v, want 2 failed attempts", events)
|
|
}
|
|
if events[0].Target != "fp/a" || events[0].Attempt != 0 || events[0].Benched {
|
|
t.Errorf("event 0 = %+v", events[0])
|
|
}
|
|
if !events[1].Benched {
|
|
t.Errorf("event 1 = %+v, want Benched", events[1])
|
|
}
|
|
|
|
// Second request: skipped-while-benched event.
|
|
events = nil
|
|
if _, err := generate(t, m); err != nil {
|
|
t.Fatalf("Generate #2: %v", err)
|
|
}
|
|
if len(events) != 1 || !events[0].Skipped || events[0].Target != "fp/a" {
|
|
t.Errorf("events = %+v, want one skip event", events)
|
|
}
|
|
}
|
|
|
|
// TestManualBenchControls: ops surfaces can bench/unbench/inspect.
|
|
func TestManualBenchControls(t *testing.T) {
|
|
clock := newFakeClock()
|
|
r := newTestRegistry(t, WithClock(clock.Now))
|
|
fp := fake.New("fp")
|
|
r.RegisterProvider(fp)
|
|
fp.Enqueue("b", fake.Reply("from-b"))
|
|
|
|
r.Health().Bench("fp/a", clock.Now().Add(time.Hour))
|
|
if r.Health().Available("fp/a") {
|
|
t.Fatal("manual bench must take effect")
|
|
}
|
|
snap := r.Health().Snapshot()
|
|
if len(snap) != 1 || snap[0].Key != "fp/a" {
|
|
t.Errorf("snapshot = %+v", snap)
|
|
}
|
|
|
|
m, _ := r.Parse("fp/a,fp/b")
|
|
resp, err := generate(t, m)
|
|
if err != nil || resp.Text() != "from-b" {
|
|
t.Fatalf("resp=%v err=%v (benched head must be skipped)", resp, err)
|
|
}
|
|
|
|
r.Health().Unbench("fp/a")
|
|
if !r.Health().Available("fp/a") {
|
|
t.Error("unbench must re-admit")
|
|
}
|
|
if len(r.Health().Snapshot()) != 0 {
|
|
t.Error("snapshot must be empty after unbench")
|
|
}
|
|
}
|
|
|
|
// TestPromptCachingOptionIsCarried: the request flag round-trips (the
|
|
// anthropic wire mapping is asserted in its own package).
|
|
func TestPromptCachingOptionIsCarried(t *testing.T) {
|
|
r := newTestRegistry(t)
|
|
fp := fake.New("fp")
|
|
r.RegisterProvider(fp)
|
|
|
|
m, _ := r.Parse("fp/x")
|
|
if _, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}, WithPromptCaching()); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if !fp.Calls()[0].Request.PromptCache {
|
|
t.Error("PromptCache flag must reach the provider")
|
|
}
|
|
}
|
|
|
|
// TestUsageDetailAccumulation: Usage.Add sums the detail fields.
|
|
func TestUsageDetailAccumulation(t *testing.T) {
|
|
u := Usage{InputTokens: 10, OutputTokens: 5, CacheReadTokens: 4, CacheWriteTokens: 2, ReasoningTokens: 3}
|
|
u.Add(Usage{InputTokens: 1, OutputTokens: 1, CacheReadTokens: 1, CacheWriteTokens: 1, ReasoningTokens: 1})
|
|
if u.CacheReadTokens != 5 || u.CacheWriteTokens != 3 || u.ReasoningTokens != 4 {
|
|
t.Errorf("usage = %+v", u)
|
|
}
|
|
}
|