package majordomo

import (
	"context"
	"errors"
	"io"
	"net/http"
	"strings"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
	"gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake"
)

// transientErr returns a 503 APIError attributed to the fake "fp" provider
// for the given model. The tests below rely on the chain treating it as a
// transient (retry / fail-over) failure.
func transientErr(model string) error {
	return &llm.APIError{
		Provider: "fp",
		Model:    model,
		Status:   http.StatusServiceUnavailable,
		Message:  "overloaded",
	}
}

// authErr returns a 401 APIError attributed to the fake "fp" provider for
// the given model. The tests below rely on it being treated as permanent.
func authErr(model string) error {
	return &llm.APIError{
		Provider: "fp",
		Model:    model,
		Status:   http.StatusUnauthorized,
		Message:  "bad key",
	}
}

// notFoundErr returns a 404 APIError attributed to the fake "fp" provider
// for the given model. The tests below rely on it advancing the chain
// without a health penalty.
func notFoundErr(model string) error {
	return &llm.APIError{
		Provider: "fp",
		Model:    model,
		Status:   http.StatusNotFound,
		Message:  "no such model",
	}
}

// TestChainSingleTransientRecoversViaRetry: one blip, same target succeeds
// on the retry — the request never fails over.
func TestChainSingleTransientRecoversViaRetry(t *testing.T) {
	reg := newTestRegistry(t)
	provider := fake.New("fp")
	reg.RegisterProvider(provider)

	// The first call to "a" fails transiently; the second (the retry) succeeds.
	provider.Enqueue("a", fake.Fail(transientErr("a")), fake.Reply("recovered"))

	chain, err := reg.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	req := Request{Messages: []Message{UserText("hi")}}
	out, err := chain.Generate(context.Background(), req)
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	if text := out.Text(); text != "recovered" {
		t.Errorf("text = %q, want recovered (same-target retry)", text)
	}

	// The head absorbed the blip itself, so the fallback is never touched.
	if calls := provider.CallCount("a"); calls != 2 {
		t.Errorf("target a saw %d calls, want 2 (initial + retry)", calls)
	}
	if calls := provider.CallCount("b"); calls != 0 {
		t.Errorf("target b saw %d calls, want 0", calls)
	}
}

// TestChainRepeatedTransientFailsOver: the head exhausts its retry, gets
// benched, and the chain advances to the next element.
func TestChainRepeatedTransientFailsOver(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a"))) fp.Enqueue("b", fake.Reply("from-b"), fake.Reply("from-b")) m, err := r.Parse("fp/a,fp/b") if err != nil { t.Fatalf("Parse: %v", err) } resp, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}) if err != nil { t.Fatalf("Generate: %v", err) } if resp.Text() != "from-b" { t.Errorf("text = %q, want from-b", resp.Text()) } // Two consecutive transient failures hit the default threshold: the // head is now backed off and skipped on the next request. if r.Health().Available("fp/a") { t.Error("fp/a should be backed off after two consecutive transient failures") } resp2, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("again")}}) if err != nil { t.Fatalf("Generate #2: %v", err) } if resp2.Text() != "from-b" { t.Errorf("second response = %q, want from-b (head skipped)", resp2.Text()) } if got := fp.CallCount("a"); got != 2 { t.Errorf("backed-off target a saw %d calls, want 2", got) } } // TestChainPermanentAuthFailsFast: failing over cannot fix bad credentials. func TestChainPermanentAuthFailsFast(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(authErr("a"))) m, _ := r.Parse("fp/a,fp/b") _, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}) if err == nil { t.Fatal("want error") } var apiErr *llm.APIError if !errors.As(err, &apiErr) || apiErr.Status != http.StatusUnauthorized { t.Errorf("error = %v, want the 401 APIError", err) } if got := fp.CallCount("b"); got != 0 { t.Errorf("target b saw %d calls, want 0 (fail-fast)", got) } if !r.Health().Available("fp/a") { t.Error("permanent errors must not penalize health") } } // TestChainModelNotFoundAdvances: 404 advances without a health penalty. 
func TestChainModelNotFoundAdvances(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(notFoundErr("a"))) fp.Enqueue("b", fake.Reply("from-b")) m, _ := r.Parse("fp/a,fp/b") resp, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}) if err != nil { t.Fatalf("Generate: %v", err) } if resp.Text() != "from-b" { t.Errorf("text = %q, want from-b", resp.Text()) } if !r.Health().Available("fp/a") { t.Error("model-not-found must not penalize health") } } // TestChainExhaustedJoinsErrors: when everything fails the error names what // was tried and why each failed. func TestChainExhaustedJoinsErrors(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a"))) fp.Enqueue("b", fake.Fail(notFoundErr("b"))) m, _ := r.Parse("fp/a,fp/b") _, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}}) if !errors.Is(err, ErrChainExhausted) { t.Fatalf("error = %v, want ErrChainExhausted", err) } for _, frag := range []string{"fp/a", "fp/b", "overloaded", "no such model"} { if !strings.Contains(err.Error(), frag) { t.Errorf("joined error %q should mention %q", err.Error(), frag) } } } func TestChainStream(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp") r.RegisterProvider(fp) fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a"))) fp.Enqueue("b", fake.Reply("streamed")) m, _ := r.Parse("fp/a,fp/b") s, err := m.Stream(context.Background(), Request{Messages: []Message{UserText("hi")}}) if err != nil { t.Fatalf("Stream: %v", err) } defer s.Close() var text string var final *Response for { ev, err := s.Next() if errors.Is(err, io.EOF) { break } if err != nil { t.Fatalf("Next: %v", err) } text += ev.TextDelta if ev.Response != nil { final = ev.Response } } if text != "streamed" { t.Errorf("streamed text = %q, want streamed", text) } if final == 
nil { t.Fatal("missing final response event") } } // TestChainContextCancellation: a canceled context aborts immediately. func TestChainContextCancellation(t *testing.T) { r := newTestRegistry(t) fp := fake.New("fp") r.RegisterProvider(fp) ctx, cancel := context.WithCancel(context.Background()) cancel() m, _ := r.Parse("fp/a,fp/b") _, err := m.Generate(ctx, Request{Messages: []Message{UserText("hi")}}) if !errors.Is(err, context.Canceled) { t.Errorf("error = %v, want context.Canceled", err) } }