dcd004289f
Phase 1 of the majordomo build: - llm/ canonical contract (messages, parts, tools, capabilities, streaming, Model/Provider, error classification) - health/ clock-injected tracker (threshold bench, exponential capped cooldown, reset-on-success) - root Registry + Parse (verbatim model ids, inline recursive alias expansion with cycle detection, chain dedup), LLM_* env-DSN providers (go-llm parity: lazy fallback + eager LoadEnv), health-aware chain executor behind the Model interface - provider/fake scriptable test provider; hermetic test suite incl. the trailing-thinking chain and foreman:// env loading - ADRs 0001-0008, CLAUDE.md, README (honest matrix), CI workflow, docs/phase-1-design.md Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
208 lines
6.1 KiB
Go
208 lines
6.1 KiB
Go
package majordomo
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
|
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/fake"
|
|
)
|
|
|
|
// transientErr returns a 503 "overloaded" APIError attributed to the fake
// provider "fp" and the given model. The chain tests script it as the
// transient failure that a target may retry.
func transientErr(model string) error {
	e := &llm.APIError{
		Provider: "fp",
		Model:    model,
		Status:   http.StatusServiceUnavailable,
		Message:  "overloaded",
	}
	return e
}
|
|
|
|
func authErr(model string) error {
|
|
return &llm.APIError{Provider: "fp", Model: model, Status: http.StatusUnauthorized, Message: "bad key"}
|
|
}
|
|
|
|
// notFoundErr returns a 404 "no such model" APIError attributed to the fake
// provider "fp" and the given model. The chain tests script it as the
// model-not-found case.
func notFoundErr(model string) error {
	e := &llm.APIError{
		Provider: "fp",
		Model:    model,
		Status:   http.StatusNotFound,
		Message:  "no such model",
	}
	return e
}
|
|
|
|
// TestChainSingleTransientRecoversViaRetry: a single transient failure on the
// head is absorbed by retrying the same target. The request succeeds from
// fp/a and never reaches fp/b.
func TestChainSingleTransientRecoversViaRetry(t *testing.T) {
	reg := newTestRegistry(t)
	prov := fake.New("fp")
	reg.RegisterProvider(prov)
	// First call to "a" blips, the second (the retry) answers.
	prov.Enqueue("a", fake.Fail(transientErr("a")), fake.Reply("recovered"))

	chain, err := reg.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	req := Request{Messages: []Message{UserText("hi")}}
	resp, err := chain.Generate(context.Background(), req)
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	if got := resp.Text(); got != "recovered" {
		t.Errorf("text = %q, want recovered (same-target retry)", got)
	}
	if calls := prov.CallCount("a"); calls != 2 {
		t.Errorf("target a saw %d calls, want 2 (initial + retry)", calls)
	}
	if calls := prov.CallCount("b"); calls != 0 {
		t.Errorf("target b saw %d calls, want 0", calls)
	}
}
|
|
|
|
// TestChainRepeatedTransientFailsOver: when the head fails transiently on both
// its initial attempt and its retry, it is benched and the chain answers from
// the next element. A follow-up request skips the benched head entirely.
func TestChainRepeatedTransientFailsOver(t *testing.T) {
	reg := newTestRegistry(t)
	prov := fake.New("fp")
	reg.RegisterProvider(prov)
	prov.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")))
	prov.Enqueue("b", fake.Reply("from-b"), fake.Reply("from-b"))

	chain, err := reg.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}
	ctx := context.Background()

	first, err := chain.Generate(ctx, Request{Messages: []Message{UserText("hi")}})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	if got := first.Text(); got != "from-b" {
		t.Errorf("text = %q, want from-b", got)
	}

	// Two consecutive transient failures reach the default threshold, so the
	// head must now be backed off and skipped on the next request.
	if reg.Health().Available("fp/a") {
		t.Error("fp/a should be backed off after two consecutive transient failures")
	}

	second, err := chain.Generate(ctx, Request{Messages: []Message{UserText("again")}})
	if err != nil {
		t.Fatalf("Generate #2: %v", err)
	}
	if got := second.Text(); got != "from-b" {
		t.Errorf("second response = %q, want from-b (head skipped)", got)
	}
	// No new calls to "a": the second request went straight to "b".
	if calls := prov.CallCount("a"); calls != 2 {
		t.Errorf("backed-off target a saw %d calls, want 2", calls)
	}
}
|
|
|
|
// TestChainPermanentAuthFailsFast: failing over cannot fix bad credentials.
// A 401 from the head is returned to the caller without trying fp/b. Because
// it is a permanent error, it must not count against fp/a's health.
func TestChainPermanentAuthFailsFast(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)
	fp.Enqueue("a", fake.Fail(authErr("a")))

	// Check the Parse error: a nil model would otherwise panic below and hide
	// the real cause.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}
	_, err = m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}})
	if err == nil {
		t.Fatal("want error")
	}
	var apiErr *llm.APIError
	if !errors.As(err, &apiErr) || apiErr.Status != http.StatusUnauthorized {
		t.Errorf("error = %v, want the 401 APIError", err)
	}
	if got := fp.CallCount("b"); got != 0 {
		t.Errorf("target b saw %d calls, want 0 (fail-fast)", got)
	}
	if !r.Health().Available("fp/a") {
		t.Error("permanent errors must not penalize health")
	}
}
|
|
|
|
// TestChainModelNotFoundAdvances: a 404 from the head advances the chain to
// fp/b, which serves the request, and leaves fp/a's health untouched.
func TestChainModelNotFoundAdvances(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)
	fp.Enqueue("a", fake.Fail(notFoundErr("a")))
	fp.Enqueue("b", fake.Reply("from-b"))

	// Check the Parse error: a nil model would otherwise panic below and hide
	// the real cause.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}
	resp, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	if resp.Text() != "from-b" {
		t.Errorf("text = %q, want from-b", resp.Text())
	}
	if !r.Health().Available("fp/a") {
		t.Error("model-not-found must not penalize health")
	}
}
|
|
|
|
// TestChainExhaustedJoinsErrors: when every element fails, the returned error
// matches ErrChainExhausted and its message names each target tried (fp/a,
// fp/b) together with why each failed.
func TestChainExhaustedJoinsErrors(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)
	fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")))
	fp.Enqueue("b", fake.Fail(notFoundErr("b")))

	// Check the Parse error: a nil model would otherwise panic below and hide
	// the real cause.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}
	_, err = m.Generate(context.Background(), Request{Messages: []Message{UserText("hi")}})
	if !errors.Is(err, ErrChainExhausted) {
		t.Fatalf("error = %v, want ErrChainExhausted", err)
	}
	msg := err.Error() // loop-invariant; render the joined error once
	for _, frag := range []string{"fp/a", "fp/b", "overloaded", "no such model"} {
		if !strings.Contains(msg, frag) {
			t.Errorf("joined error %q should mention %q", msg, frag)
		}
	}
}
|
|
|
|
// TestChainStream: with the head scripted to fail transiently, Stream must be
// served by the next element (fp/b). Its text deltas must concatenate to the
// scripted reply, and the stream must deliver a final Response event before
// io.EOF.
func TestChainStream(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)
	fp.Enqueue("a", fake.Fail(transientErr("a")), fake.Fail(transientErr("a")))
	fp.Enqueue("b", fake.Reply("streamed"))

	// Check the Parse error: a nil model would otherwise panic below and hide
	// the real cause.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}
	s, err := m.Stream(context.Background(), Request{Messages: []Message{UserText("hi")}})
	if err != nil {
		t.Fatalf("Stream: %v", err)
	}
	defer s.Close()

	// Drain the stream: accumulate deltas and keep the last Response seen.
	var text strings.Builder
	var final *Response
	for {
		ev, err := s.Next()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			t.Fatalf("Next: %v", err)
		}
		text.WriteString(ev.TextDelta)
		if ev.Response != nil {
			final = ev.Response
		}
	}
	if got := text.String(); got != "streamed" {
		t.Errorf("streamed text = %q, want streamed", got)
	}
	if final == nil {
		t.Fatal("missing final response event")
	}
}
|
|
|
|
// TestChainContextCancellation: a canceled context aborts immediately. With a
// context canceled before the call, Generate must return an error matching
// context.Canceled.
func TestChainContextCancellation(t *testing.T) {
	r := newTestRegistry(t)
	fp := fake.New("fp")
	r.RegisterProvider(fp)

	// Check the Parse error: a nil model would otherwise panic below and hide
	// the real cause.
	m, err := r.Parse("fp/a,fp/b")
	if err != nil {
		t.Fatalf("Parse: %v", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancel up front so the chain only ever sees a dead context
	_, err = m.Generate(ctx, Request{Messages: []Message{UserText("hi")}})
	if !errors.Is(err, context.Canceled) {
		t.Errorf("error = %v, want context.Canceled", err)
	}
}
|