package majordomo import ( "context" "encoding/json" "errors" "strings" "testing" "time" "gitea.stevedudenhoeffer.com/steve/majordomo/llm" "gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake" ) // TestRegisterResolver: dynamic tiers resolve after static aliases, expand // recursively, and respect cycle detection. func TestRegisterResolver(t *testing.T) { r := newTestRegistry(t) r.RegisterProvider(fake.New("fp")) r.RegisterAlias("static", "fp/a") tiers := map[string]string{ "db-tier": "fp/x,fp/y", "db-nested": "db-tier,fp/z", "db-cycle": "db-cycle", } r.RegisterResolver(ResolverFunc(func(name string) (string, bool) { spec, ok := tiers[name] return spec, ok })) m, err := r.Parse("static,db-nested") if err != nil { t.Fatalf("Parse: %v", err) } want := []string{"fp/a", "fp/x", "fp/y", "fp/z"} if got := targetsOf(t, m); strings.Join(got, ",") != strings.Join(want, ",") { t.Errorf("targets = %v, want %v", got, want) } if _, err := r.Parse("db-cycle"); !errors.Is(err, ErrAliasCycle) { t.Errorf("cycle error = %v, want ErrAliasCycle", err) } // Static aliases shadow resolvers. tiers["static"] = "fp/wrong" m, _ = r.Parse("static") if got := targetsOf(t, m); got[0] != "fp/a" { t.Errorf("static alias must win over resolver, got %v", got) } } // TestDefineTool: schema from Args, decoded handler arguments. func TestDefineTool(t *testing.T) { type addArgs struct { A int `json:"a" description:"first addend"` B int `json:"b"` } tool := DefineTool("add", "Add two ints", func(_ context.Context, args addArgs) (any, error) { return args.A + args.B, nil }) var schema map[string]any if err := json.Unmarshal(tool.Parameters, &schema); err != nil { t.Fatalf("schema: %v", err) } props := schema["properties"].(map[string]any) if props["a"].(map[string]any)["description"] != "first addend" { t.Errorf("schema = %v", schema) } res := llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "1", Name: "add", Arguments: json.RawMessage(`{"a":2,"b":40}`)}) if res.IsError || res.Content != "42" { t.Errorf("result = %+v", res) } res = llm.ExecuteTool(context.Background(), tool, ToolCall{ID: "2", Name: "add", Arguments: json.RawMessage(`{"a":"nope"}`)}) if !res.IsError || !strings.Contains(res.Content, "invalid arguments") { t.Errorf("bad-args result = %+v", res) } } // TestChainObserver: failover decisions emit events (attempt, bench, skip). func TestChainObserver(t *testing.T) { var events []FailoverEvent r := newTestRegistry(t, WithChainConfig(ChainConfig{ Observer: func(ev FailoverEvent) { events = append(events, ev) }, })) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a"))) fp.Enqueue("b", fake.Reply("ok"), fake.Reply("ok")) m, _ := r.Parse("fp/a,fp/b") if _, err := generate(t, m); err != nil { t.Fatalf("Generate: %v", err) } if len(events) != 2 { t.Fatalf("events = %+v, want 2 failed attempts", events) } if events[0].Target != "fp/a" || events[0].Attempt != 0 || events[0].Benched { t.Errorf("event 0 = %+v", events[0]) } if !events[1].Benched { t.Errorf("event 1 = %+v, want Benched", events[1]) } // Second request: skipped-while-benched event. events = nil if _, err := generate(t, m); err != nil { t.Fatalf("Generate #2: %v", err) } if len(events) != 1 || !events[0].Skipped || events[0].Target != "fp/a" { t.Errorf("events = %+v, want one skip event", events) } } // TestManualBenchControls: ops surfaces can bench/unbench/inspect. func TestManualBenchControls(t *testing.T) { clock := newFakeClock() r := newTestRegistry(t, WithClock(clock.Now)) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("b", fake.Reply("from-b")) r.Health().Bench("fp/a", clock.Now().Add(time.Hour)) if r.Health().Available("fp/a") { t.Fatal("manual bench must take effect") } snap := r.Health().Snapshot() if len(snap) != 1 || snap[0].Key != "fp/a" { t.Errorf("snapshot = %+v", snap) } m, _ := r.Parse("fp/a,fp/b") resp, err := generate(t, m) if err != nil || resp.Text() != "from-b" { t.Fatalf("resp=%v err=%v (benched head must be skipped)", resp, err) } r.Health().Unbench("fp/a") if !r.Health().Available("fp/a") { t.Error("unbench must re-admit") } if len(r.Health().Snapshot()) != 0 { t.Error("snapshot must be empty after unbench") } } // TestPromptCachingOptionIsCarried: the request flag round-trips (the // anthropic wire mapping is asserted in its own package). func TestPromptCachingOptionIsCarried(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp") r.RegisterProvider(fp) m, _ := r.Parse("fp/x") if _, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}, WithPromptCaching()); err != nil { t.Fatalf("Generate: %v", err) } if !fp.Calls()[0].Request.PromptCache { t.Error("PromptCache flag must reach the provider") } } // TestUsageDetailAccumulation: Usage.Add sums the detail fields. func TestUsageDetailAccumulation(t *testing.T) { u := Usage{InputTokens: 10, OutputTokens: 5, CacheReadTokens: 4, CacheWriteTokens: 2, ReasoningTokens: 3} u.Add(Usage{InputTokens: 1, OutputTokens: 1, CacheReadTokens: 1, CacheWriteTokens: 1, ReasoningTokens: 1}) if u.CacheReadTokens != 5 || u.CacheWriteTokens != 3 || u.ReasoningTokens != 4 { t.Errorf("usage = %+v", u) } }