Files
majordomo/chain_test.go
T
steve dcd004289f feat: foundations — canonical types, Parse grammar, env DSNs, health, chains
Phase 1 of the majordomo build:
- llm/ canonical contract (messages, parts, tools, capabilities, streaming,
  Model/Provider, error classification)
- health/ clock-injected tracker (threshold bench, exponential capped
  cooldown, reset-on-success)
- root Registry + Parse (verbatim model ids, inline recursive alias
  expansion with cycle detection, chain dedup), LLM_* env-DSN providers
  (go-llm parity: lazy fallback + eager LoadEnv), health-aware chain
  executor behind the Model interface
- provider/fake scriptable test provider; hermetic test suite incl. the
  trailing-thinking chain and foreman:// env loading
- ADRs 0001-0008, CLAUDE.md, README (honest matrix), CI workflow,
  docs/phase-1-design.md

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-10 12:35:34 +02:00

208 lines
6.1 KiB
Go

package majordomo
import (
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake"
)
// transientErr builds the 503 Service Unavailable APIError, attributed to
// provider "fp" and the given model, that the chain tests script as a
// transient failure.
func transientErr(model string) error {
	apiErr := &llm.APIError{
		Provider: "fp",
		Model:    model,
		Status:   http.StatusServiceUnavailable,
		Message:  "overloaded",
	}
	return apiErr
}
// authErr builds the 401 Unauthorized APIError, attributed to provider "fp"
// and the given model, that the chain tests script as a credentials failure.
func authErr(model string) error {
	apiErr := &llm.APIError{
		Provider: "fp",
		Model:    model,
		Status:   http.StatusUnauthorized,
		Message:  "bad key",
	}
	return apiErr
}
// notFoundErr builds the 404 Not Found APIError, attributed to provider "fp"
// and the given model, that the chain tests script as a missing-model failure.
func notFoundErr(model string) error {
	apiErr := &llm.APIError{
		Provider: "fp",
		Model:    model,
		Status:   http.StatusNotFound,
		Message:  "no such model",
	}
	return apiErr
}
// TestChainSingleTransientRecoversViaRetry scripts one transient failure
// followed by a success on target "a" and checks that the retry lands on the
// same target, so the request never fails over to "b".
func TestChainSingleTransientRecoversViaRetry(t *testing.T) {
	reg := newTestRegistry(t)
	provider := fake.New("fp")
	reg.RegisterProvider(provider)
	provider.Enqueue("a", fake.Fail(transientErr("a")), fake.Reply("recovered"))

	model, err := reg.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	req := Request{Messages: []Message{UserText("hi")}}
	resp, err := model.Generate(context.Background(), req)
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	if got := resp.Text(); got != "recovered" {
		t.Errorf("text = %q, want recovered (same-target retry)", got)
	}

	// The head sees the failed attempt plus the retry; the fallback sees nothing.
	if calls := provider.CallCount("a"); calls != 2 {
		t.Errorf("target a saw %d calls, want 2 (initial + retry)", calls)
	}
	if calls := provider.CallCount("b"); calls != 0 {
		t.Errorf("target b saw %d calls, want 0", calls)
	}
}
// TestChainRepeatedTransientFailsOver scripts two transient failures on the
// head so its retry is exhausted, then checks that the chain answers from the
// next element, that the head is reported unavailable afterwards, and that a
// second request skips the head entirely.
func TestChainRepeatedTransientFailsOver(t *testing.T) {
	reg := newTestRegistry(t)
	provider := fake.New("fp")
	reg.RegisterProvider(provider)
	provider.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")))
	provider.Enqueue("b", fake.Reply("from-b"), fake.Reply("from-b"))

	model, err := reg.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	ctx := context.Background()

	first, err := model.Generate(ctx, Request{Messages: []Message{UserText("hi")}})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	if got := first.Text(); got != "from-b" {
		t.Errorf("text = %q, want from-b", got)
	}

	// Two transient failures in a row reach the default threshold, so the
	// head must now be backed off and skipped by the following request.
	if reg.Health().Available("fp/a") {
		t.Error("fp/a should be backed off after two consecutive transient failures")
	}

	second, err := model.Generate(ctx, Request{Messages: []Message{UserText("again")}})
	if err != nil {
		t.Fatalf("Generate #2: %v", err)
	}
	if got := second.Text(); got != "from-b" {
		t.Errorf("second response = %q, want from-b (head skipped)", got)
	}

	// The count stays at the two calls from the first request.
	if calls := provider.CallCount("a"); calls != 2 {
		t.Errorf("backed-off target a saw %d calls, want 2", calls)
	}
}
// TestChainPermanentAuthFailsFast: failing over cannot fix bad credentials.
// A 401 from the head must be returned to the caller as the APIError itself,
// the next target must never be called, and the head's health must be left
// untouched.
func TestChainPermanentAuthFailsFast(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)
	fp.Enqueue("a", fake.Fail(authErr("a")))

	// Stop on a Parse error so a bad chain spec is reported as such instead
	// of surfacing as an unrelated failure inside Generate.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	_, err = m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}})
	if err == nil {
		t.Fatal("want error")
	}
	var apiErr *llm.APIError
	if !errors.As(err, &apiErr) || apiErr.Status != http.StatusUnauthorized {
		t.Errorf("error = %v, want the 401 APIError", err)
	}
	if got := fp.CallCount("b"); got != 0 {
		t.Errorf("target b saw %d calls, want 0 (fail-fast)", got)
	}
	if !r.Health().Available("fp/a") {
		t.Error("permanent errors must not penalize health")
	}
}
// TestChainModelNotFoundAdvances: a 404 from the head advances the chain to
// the next element without a health penalty on the head.
func TestChainModelNotFoundAdvances(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)
	fp.Enqueue("a", fake.Fail(notFoundErr("a")))
	fp.Enqueue("b", fake.Reply("from-b"))

	// Stop on a Parse error so a bad chain spec is reported as such instead
	// of surfacing as an unrelated failure inside Generate.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	resp, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	if resp.Text() != "from-b" {
		t.Errorf("text = %q, want from-b", resp.Text())
	}
	if !r.Health().Available("fp/a") {
		t.Error("model-not-found must not penalize health")
	}
}
// TestChainExhaustedJoinsErrors: when everything fails the error names what
// was tried and why each failed. The head fails twice with a transient error
// and the second target fails with a 404; the result must match
// ErrChainExhausted and mention both targets and both failure messages.
func TestChainExhaustedJoinsErrors(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)
	fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")))
	fp.Enqueue("b", fake.Fail(notFoundErr("b")))

	// Stop on a Parse error so a bad chain spec is reported as such instead
	// of surfacing as an unrelated failure inside Generate.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	_, err = m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}})
	// errors.Is is false for a nil error, so this also covers "no error at
	// all" and makes the err.Error() call below safe.
	if !errors.Is(err, ErrChainExhausted) {
		t.Fatalf("error = %v, want ErrChainExhausted", err)
	}
	msg := err.Error()
	for _, frag := range []string{"fp/a", "fp/b", "overloaded", "no such model"} {
		if !strings.Contains(msg, frag) {
			t.Errorf("joined error %q should mention %q", msg, frag)
		}
	}
}
// TestChainStream drives the chain through Stream: the head is scripted to
// fail twice with a transient error and the second target to reply
// "streamed". The test drains the stream until io.EOF and checks that the
// concatenated text deltas equal the reply and that some event carried a
// final Response.
func TestChainStream(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)
	fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")))
	fp.Enqueue("b", fake.Reply("streamed"))

	// Stop on a Parse error so a bad chain spec is reported as such instead
	// of surfacing as an unrelated failure inside Stream.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	s, err := m.Stream(context.Background(), Request{Messages: []Message{UserText("hi")}})
	if err != nil {
		t.Fatalf("Stream: %v", err)
	}
	defer s.Close()

	var text strings.Builder
	var final *Response
	for {
		ev, err := s.Next()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			t.Fatalf("Next: %v", err)
		}
		text.WriteString(ev.TextDelta)
		if ev.Response != nil {
			final = ev.Response
		}
	}
	if got := text.String(); got != "streamed" {
		t.Errorf("streamed text = %q, want streamed", got)
	}
	if final == nil {
		t.Fatal("missing final response event")
	}
}
// TestChainContextCancellation: a canceled context aborts immediately.
// Generate is called with an already-canceled context and no scripted
// replies; the returned error must match context.Canceled.
func TestChainContextCancellation(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	// Stop on a Parse error so a bad chain spec is reported as such instead
	// of surfacing as an unrelated failure inside Generate.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	_, err = m.Generate(ctx, Request{Messages: []Message{UserText("hi")}})
	if !errors.Is(err, context.Canceled) {
		t.Errorf("error = %v, want context.Canceled", err)
	}
}