package majordomo import ( "bytes" "context" "errors" "image" "image/color" "image/png" "slices" "strings" "sync" "testing" "time" "gitea.stevedudenhoeffer.com/steve/majordomo/health" "gitea.stevedudenhoeffer.com/steve/majordomo/llm" "gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake" ) // fakeClock is a manually-advanced clock shared with the registry's health // tracker for deterministic backoff tests. type fakeClock struct { mu sync.Mutex now time.Time } func newFakeClock() *fakeClock { return &fakeClock{now: time.Date(2026, 6, 10, 12, 0, 0, 0, time.UTC)} } func (c *fakeClock) Now() time.Time { c.mu.Lock() defer c.mu.Unlock() return c.now } func (c *fakeClock) Advance(d time.Duration) { c.mu.Lock() defer c.mu.Unlock() c.now = c.now.Add(d) } func generate(t *testing.T, m Model) (*Response, error) { t.Helper() return m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}) } // TestCooldownExpiryReadmitsTarget: a benched head is skipped during its // cooldown and tried again after the clock passes it. func TestCooldownExpiryReadmitsTarget(t *testing.T) { clock := newFakeClock() r := newTestRegistry(t, WithClock(clock.Now)) fp := fake.New("fp") r.RegisterProvider(fp) // Bench the head: two consecutive transient failures. fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a"))) fp.Enqueue("b", fake.Reply("b1"), fake.Reply("b2")) fp.Enqueue("a", fake.Reply("a-recovered")) // served only after re-admission m, err := r.Parse("fp/a,fp/b") if err != nil { t.Fatalf("Parse: %v", err) } if resp, err := generate(t, m); err != nil || resp.Text() != "b1" { t.Fatalf("first request: resp=%v err=%v, want b1", resp, err) } // Inside the 5s cooldown: head must be skipped without being called. clock.Advance(4 * time.Second) callsBefore := fp.CallCount("a") if resp, _ := generate(t, m); resp.Text() != "b2" { t.Fatalf("during cooldown: got %q, want b2", resp.Text()) } if fp.CallCount("a") != callsBefore { t.Error("benched target must not be called during cooldown") } // Past the cooldown: head is re-admitted and serves the request. clock.Advance(2 * time.Second) resp, err := generate(t, m) if err != nil { t.Fatalf("after cooldown: %v", err) } if resp.Text() != "a-recovered" { t.Errorf("after cooldown: got %q, want a-recovered (re-admitted head)", resp.Text()) } // The success reset health: no residual bench. if !r.Health().Available("fp/a") { t.Error("success must reset the target's health") } } // TestBackoffGrowsAcrossBenches: the second bench (without an intervening // success) uses the doubled cooldown. func TestBackoffGrowsAcrossBenches(t *testing.T) { clock := newFakeClock() r := newTestRegistry(t, WithClock(clock.Now)) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")), // bench #1 (5s) fake.Fail(transientErr("a")), fake.Fail(transientErr("a")), // bench #2 (10s) ) m, _ := r.Parse("fp/a,fp/b") generate(t, m) // bench #1, served by b clock.Advance(5 * time.Second) generate(t, m) // re-admitted, fails twice more → bench #2 until := r.Health().BackedOffUntil("fp/a") if got, want := until.Sub(clock.Now()), 10*time.Second; got != want { t.Errorf("second bench cooldown = %v, want %v", got, want) } } // TestChainWithInlineAliasElementFailsOver: a chain whose middle element is // a registered alias expands inline, and failover walks through the // expanded targets exactly like literal ones. func TestChainWithInlineAliasElementFailsOver(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp") oc := fake.New("ollama-cloud") r.RegisterProvider(fp) r.RegisterProvider(oc) r.RegisterAlias("thinking", "ollama-cloud/minimax-m3:cloud,ollama-cloud/kimi-k2.6:cloud") // Head fails hard enough to bench; first alias target 404s; second // alias target answers. fp.Enqueue("head", fake.Fail(transientErr("head")), fake.Fail(transientErr("head"))) oc.Enqueue("minimax-m3:cloud", fake.Fail(notFoundErr("minimax-m3:cloud"))) oc.Enqueue("kimi-k2.6:cloud", fake.Reply("from-kimi")) m, err := r.Parse("fp/head,thinking") if err != nil { t.Fatalf("Parse: %v", err) } want := []string{"fp/head", "ollama-cloud/minimax-m3:cloud", "ollama-cloud/kimi-k2.6:cloud"} if got := targetsOf(t, m); !slices.Equal(got, want) { t.Fatalf("targets = %v, want %v", got, want) } resp, err := generate(t, m) if err != nil { t.Fatalf("Generate: %v", err) } if resp.Text() != "from-kimi" { t.Errorf("text = %q, want from-kimi", resp.Text()) } if resp.Model != "ollama-cloud/kimi-k2.6:cloud" { t.Errorf("resp.Model = %q, want the serving target", resp.Model) } } // TestAdvanceOnPermanentPolicy: with the policy flipped, auth errors // advance the chain instead of failing fast — and don't penalize health. func TestAdvanceOnPermanentPolicy(t *testing.T) { r := newTestRegistry(t, WithChainConfig(ChainConfig{AdvanceOnPermanent: true})) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(authErr("a"))) fp.Enqueue("b", fake.Reply("from-b")) m, _ := r.Parse("fp/a,fp/b") resp, err := generate(t, m) if err != nil { t.Fatalf("Generate: %v", err) } if resp.Text() != "from-b" { t.Errorf("text = %q, want from-b", resp.Text()) } if !r.Health().Available("fp/a") { t.Error("permanent errors must not penalize health even when advancing") } } // TestTransientRetriesConfig: negative disables same-target retries; a // custom count is honored. func TestTransientRetriesConfig(t *testing.T) { t.Run("disabled", func(t *testing.T) { r := newTestRegistry(t, WithChainConfig(ChainConfig{TransientRetries: -1})) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(transientErr("a"))) fp.Enqueue("b", fake.Reply("from-b")) m, _ := r.Parse("fp/a,fp/b") if resp, err := generate(t, m); err != nil || resp.Text() != "from-b" { t.Fatalf("resp=%v err=%v", resp, err) } if got := fp.CallCount("a"); got != 1 { t.Errorf("target a saw %d calls, want 1 (retries disabled)", got) } }) t.Run("custom count with higher threshold", func(t *testing.T) { r := newTestRegistry(t, WithChainConfig(ChainConfig{TransientRetries: 3}), WithHealthConfig(health.Config{FailureThreshold: 10}), ) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")), fake.Fail(transientErr("a")), fake.Reply("fourth-attempt")) m, _ := r.Parse("fp/a") resp, err := generate(t, m) if err != nil { t.Fatalf("Generate: %v", err) } if resp.Text() != "fourth-attempt" { t.Errorf("text = %q, want fourth-attempt", resp.Text()) } if got := fp.CallCount("a"); got != 4 { t.Errorf("target a saw %d calls, want 4 (1 + 3 retries)", got) } }) } // TestRetryStopsWhenBenchedMidRequest: with default threshold 2 and a // custom retry budget of 5, the second failed attempt benches the target // and the chain advances instead of burning the remaining retries. func TestRetryStopsWhenBenchedMidRequest(t *testing.T) { r := newTestRegistry(t, WithChainConfig(ChainConfig{TransientRetries: 5})) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")), fake.Reply("should-never-be-reached")) fp.Enqueue("b", fake.Reply("from-b")) m, _ := r.Parse("fp/a,fp/b") resp, err := generate(t, m) if err != nil { t.Fatalf("Generate: %v", err) } if resp.Text() != "from-b" { t.Errorf("text = %q, want from-b", resp.Text()) } if got := fp.CallCount("a"); got != 2 { t.Errorf("target a saw %d calls, want 2 (benched mid-request)", got) } } // TestExhaustionListsSkippedTargets: benched-and-skipped targets appear in // the exhaustion error alongside fresh failures. func TestExhaustionListsSkippedTargets(t *testing.T) { clock := newFakeClock() r := newTestRegistry(t, WithClock(clock.Now)) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a"))) fp.Enqueue("b", fake.Fail(transientErr("b")), fake.Fail(transientErr("b"))) m, _ := r.Parse("fp/a,fp/b") if _, err := generate(t, m); !errors.Is(err, ErrChainExhausted) { t.Fatalf("first pass error = %v, want exhaustion", err) } // Both targets are now benched; the next request fails by skipping. _, err := generate(t, m) if !errors.Is(err, ErrChainExhausted) { t.Fatalf("error = %v, want ErrChainExhausted", err) } for _, frag := range []string{"fp/a", "fp/b", "skipped", "backed off"} { if !strings.Contains(err.Error(), frag) { t.Errorf("error %q should mention %q", err.Error(), frag) } } // No fresh calls were made. if fp.CallCount("a") != 2 || fp.CallCount("b") != 2 { t.Error("skipped targets must not receive calls") } } // TestCustomClassifier: a classifier override changes failover decisions. func TestCustomClassifier(t *testing.T) { sentinel := errors.New("totally-fine-actually") r := newTestRegistry(t, WithChainConfig(ChainConfig{ Classify: func(err error) llm.ErrorClass { if errors.Is(err, sentinel) { return llm.ClassPermanent } return llm.Classify(err) }, })) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(sentinel)) m, _ := r.Parse("fp/a,fp/b") _, err := generate(t, m) if !errors.Is(err, sentinel) { t.Errorf("error = %v, want fail-fast on custom-permanent sentinel", err) } if got := fp.CallCount("b"); got != 0 { t.Errorf("target b saw %d calls, want 0", got) } } // TestSingleTargetGetsChainSemantics: a chain of one retries transients and // benches itself exactly like a multi-element chain. func TestSingleTargetGetsChainSemantics(t *testing.T) { clock := newFakeClock() r := newTestRegistry(t, WithClock(clock.Now)) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("only", fake.Fail(transientErr("only")), fake.Reply("recovered")) m, _ := r.Parse("fp/only") resp, err := generate(t, m) if err != nil { t.Fatalf("Generate: %v", err) } if resp.Text() != "recovered" { t.Errorf("text = %q, want recovered", resp.Text()) } // Now exhaust it: two failures bench the lone target. fp.Enqueue("only", fake.Fail(transientErr("only")), fake.Fail(transientErr("only"))) if _, err := generate(t, m); !errors.Is(err, ErrChainExhausted) { t.Fatalf("error = %v, want ErrChainExhausted", err) } if r.Health().Available("fp/only") { t.Error("lone target should be benched after repeated failures") } } // pngImage encodes a width×height PNG for media tests. func pngImage(t *testing.T, width, height int) []byte { t.Helper() img := image.NewRGBA(image.Rect(0, 0, width, height)) for y := range height { for x := range width { img.Set(x, y, color.RGBA{R: uint8(x), G: uint8(y), B: 128, A: 255}) } } var buf bytes.Buffer if err := png.Encode(&buf, img); err != nil { t.Fatalf("encode png: %v", err) } return buf.Bytes() } // TestChainNormalizesMediaPerTarget: the request's image is downscaled to // the capabilities of the target that actually serves it. func TestChainNormalizesMediaPerTarget(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp", fake.WithModelCapabilities("small-vision", llm.Capabilities{ MaxImagesPerReq: 2, MaxImageDimension: 32, AllowedImageMIME: []string{"image/png"}, }), ) r.RegisterProvider(fp) m, _ := r.Parse("fp/small-vision") _, err := m.Generate(context.Background(), Request{Messages: []Message{ UserParts(Text("describe"), Image("image/png", pngImage(t, 100, 50))), }}) if err != nil { t.Fatalf("Generate: %v", err) } calls := fp.Calls() if len(calls) != 1 { t.Fatalf("calls = %d", len(calls)) } var img llm.ImagePart for _, part := range calls[0].Request.Messages[0].Parts { if ip, ok := part.(llm.ImagePart); ok { img = ip } } if img.Data == nil { t.Fatal("no image reached the provider") } cfg, err := png.DecodeConfig(bytes.NewReader(img.Data)) if err != nil { t.Fatalf("decode delivered image: %v", err) } if cfg.Width != 32 || cfg.Height != 16 { t.Errorf("delivered image = %dx%d, want 32x16 (downscaled to target cap)", cfg.Width, cfg.Height) } } // TestChainAdvancesPastImagelessTarget: a text-only head can't take an // image request; the chain advances to a vision-capable element with no // health penalty. func TestChainAdvancesPastImagelessTarget(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp", fake.WithModelCapabilities("text-only", llm.Capabilities{SupportsTools: true}), fake.WithModelCapabilities("vision", llm.Capabilities{MaxImagesPerReq: 4}), ) r.RegisterProvider(fp) fp.Enqueue("vision", fake.Reply("a tasteful png")) m, _ := r.Parse("fp/text-only,fp/vision") resp, err := m.Generate(context.Background(), Request{Messages: []Message{ UserParts(Text("what is this?"), Image("image/png", pngImage(t, 8, 8))), }}) if err != nil { t.Fatalf("Generate: %v", err) } if resp.Text() != "a tasteful png" { t.Errorf("text = %q", resp.Text()) } if got := fp.CallCount("text-only"); got != 0 { t.Errorf("text-only target saw %d calls, want 0 (normalization rejects pre-send)", got) } if !r.Health().Available("fp/text-only") { t.Error("media rejection must not penalize health") } } // TestHTTP529ClassifiedTransient: Anthropic's "overloaded" status fails // over like any other transient error. func TestHTTP529FailsOver(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp") r.RegisterProvider(fp) overloaded := &llm.APIError{Provider: "fp", Model: "a", Status: 529, Message: "overloaded_error"} fp.Enqueue("a", fake.Fail(overloaded), fake.Fail(overloaded)) fp.Enqueue("b", fake.Reply("from-b")) m, _ := r.Parse("fp/a,fp/b") if resp, err := generate(t, m); err != nil || resp.Text() != "from-b" { t.Fatalf("resp=%v err=%v, want from-b", resp, err) } }