package llm

import (
	"context"
	"errors"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/openai/openai-go"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// fastBackoff is a near-zero backoff so retry tests don't sleep.
func fastBackoff(int) time.Duration { return time.Microsecond }

// testFailoverOpts returns the option set shared by these tests (max retries
// of 2, fastBackoff, one-minute cooldown) followed by any extra options, so
// extras are appended after the shared options.
func testFailoverOpts(extra ...FailoverOption) []FailoverOption {
	base := []FailoverOption{
		WithFailoverMaxRetries(2),
		WithFailoverBackoff(fastBackoff),
		WithFailoverCooldown(time.Minute),
	}
	return append(base, extra...)
}

// modelFor builds a *Model around a mock provider with a concrete model name,
// mimicking what Parse produces (so specKey resolution works).
func modelFor(p provider.Provider, name string) *Model {
	return &Model{provider: p, model: name}
}

// TestFailover_FirstSucceeds verifies that when the first model answers, its
// response is returned and the second model is never called.
func TestFailover_FirstSucceeds(t *testing.T) {
	resetHealthForTest()
	a := newMockProvider(provider.Response{Text: "from-a"})
	b := newMockProvider(provider.Response{Text: "from-b"})

	fo := NewFailoverModel([]*Model{modelFor(a, "anthropic/a"), modelFor(b, "openai/b")}, testFailoverOpts()...)
	resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.Text != "from-a" {
		t.Errorf("expected from-a, got %q", resp.Text)
	}

	// b must not have been called.
	b.mu.Lock()
	n := len(b.Requests)
	b.mu.Unlock()
	if n != 0 {
		t.Errorf("expected b untouched, got %d calls", n)
	}
}

// TestFailover_FailsOverToSecond verifies that a 400 from the first model
// fails over to the second without benching the first.
func TestFailover_FailsOverToSecond(t *testing.T) {
	resetHealthForTest()
	// a always returns a request-specific error (400) -> fail over, no retry-bench loop noise.
	a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, &openai.Error{StatusCode: 400}
	})
	b := newMockProvider(provider.Response{Text: "from-b"})

	fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
	resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.Text != "from-b" {
		t.Errorf("expected from-b, got %q", resp.Text)
	}

	// 400 is request-specific: a must NOT be benched.
	if IsBenched("p/a") {
		t.Error("p/a should not be benched on a 400")
	}
}

// TestFailover_PassesModelNameToProvider verifies that the request handed to
// the provider carries the bare model name rather than the full spec.
func TestFailover_PassesModelNameToProvider(t *testing.T) {
	resetHealthForTest()
	a := newMockProvider(provider.Response{Text: "ok"})

	fo := NewFailoverModel([]*Model{modelFor(a, "anthropic/claude-x")}, testFailoverOpts()...)
	_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatal(err)
	}
	if got := a.lastRequest().Model; got != "claude-x" {
		t.Errorf("provider received model %q, want bare model name claude-x", got)
	}
}

// TestFailover_AuthDeadBenchesImmediately verifies that a 401 benches the
// model after a single call and the chain moves on to the next model.
func TestFailover_AuthDeadBenchesImmediately(t *testing.T) {
	resetHealthForTest()
	a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, &openai.Error{StatusCode: 401}
	})
	b := newMockProvider(provider.Response{Text: "from-b"})

	fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
	resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatal(err)
	}
	if resp.Text != "from-b" {
		t.Errorf("expected from-b, got %q", resp.Text)
	}
	if !IsBenched("p/a") {
		t.Error("p/a should be benched after auth-dead error")
	}

	// a should have been called exactly once (no retries on auth-dead).
	a.mu.Lock()
	n := len(a.Requests)
	a.mu.Unlock()
	if n != 1 {
		t.Errorf("auth-dead should not retry; a called %d times", n)
	}
}

// TestFailover_TransientRetriesThenBenches verifies that a 503 is retried up
// to the configured limit, after which the model is benched and the chain
// fails over.
func TestFailover_TransientRetriesThenBenches(t *testing.T) {
	resetHealthForTest()
	var calls int
	var mu sync.Mutex
	a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		mu.Lock()
		calls++
		mu.Unlock()
		return provider.Response{}, &openai.Error{StatusCode: 503}
	})
	b := newMockProvider(provider.Response{Text: "from-b"})

	// maxRetries=2 means 2 attempts total per entry. The options are spelled
	// out here (rather than via testFailoverOpts) because the attempt-count
	// assertion below depends on that exact value.
	fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")},
		WithFailoverMaxRetries(2), WithFailoverBackoff(fastBackoff), WithFailoverCooldown(time.Minute))
	resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatal(err)
	}
	if resp.Text != "from-b" {
		t.Errorf("expected from-b, got %q", resp.Text)
	}

	mu.Lock()
	n := calls
	mu.Unlock()
	if n != 2 {
		t.Errorf("expected 2 attempts on transient model, got %d", n)
	}
	if !IsBenched("p/a") {
		t.Error("p/a should be benched after exhausting retries")
	}
}

// TestFailover_AllFail verifies that Complete returns an error when every
// model in the chain fails.
func TestFailover_AllFail(t *testing.T) {
	resetHealthForTest()
	a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, &openai.Error{StatusCode: 400}
	})
	b := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, &openai.Error{StatusCode: 400}
	})

	fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
	_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err == nil {
		t.Fatal("expected error when all models fail")
	}
	// NOTE(review): a bare "2" substring is a loose check; tighten it if the
	// aggregate error format is considered stable.
	if !strings.Contains(err.Error(), "2") {
		t.Errorf("error should mention all 2 models failed: %v", err)
	}
}

// TestFailover_ContextCanceledAborts verifies that context.Canceled is
// returned as-is: no failover to the next model and no benching.
func TestFailover_ContextCanceledAborts(t *testing.T) {
	resetHealthForTest()
	a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, context.Canceled
	})
	b := newMockProvider(provider.Response{Text: "from-b"})

	fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
	_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if !errors.Is(err, context.Canceled) {
		t.Errorf("expected context.Canceled to abort, got %v", err)
	}

	// b must not be tried.
	b.mu.Lock()
	n := len(b.Requests)
	b.mu.Unlock()
	if n != 0 {
		t.Errorf("canceled should not fail over; b called %d times", n)
	}
	if IsBenched("p/a") {
		t.Error("canceled should not bench")
	}
}

// TestFailover_AllBenchedBestEffort verifies that when every model is benched
// the chain still tries them in order instead of failing outright.
func TestFailover_AllBenchedBestEffort(t *testing.T) {
	resetHealthForTest()
	// Manually bench both, then ensure Complete still tries (best-effort) and succeeds.
	BenchModel("p/a", time.Now().Add(time.Hour))
	BenchModel("p/b", time.Now().Add(time.Hour))

	a := newMockProvider(provider.Response{Text: "from-a"})
	b := newMockProvider(provider.Response{Text: "from-b"})

	fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
	resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatalf("best-effort should still try benched models: %v", err)
	}
	if resp.Text != "from-a" {
		t.Errorf("expected from-a, got %q", resp.Text)
	}
}

// TestFailover_Observer verifies that an observer installed with
// WithFailoverObserver receives an event describing the failed model, the
// error kind, the bench decision, the next model, and the request.
func TestFailover_Observer(t *testing.T) {
	resetHealthForTest()
	a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, &openai.Error{StatusCode: 401}
	})
	b := newMockProvider(provider.Response{Text: "from-b"})

	var mu sync.Mutex
	var events []FailoverEvent
	obs := func(ctx context.Context, ev FailoverEvent) {
		mu.Lock()
		events = append(events, ev)
		mu.Unlock()
	}

	// The observer goes through the helper's variadic so it is appended after
	// the shared options.
	fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")},
		testFailoverOpts(WithFailoverObserver(obs))...)
	_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatal(err)
	}

	mu.Lock()
	defer mu.Unlock()
	if len(events) == 0 {
		t.Fatal("expected observer to be called")
	}
	ev := events[0]
	if ev.Model != "p/a" || ev.Kind != ErrAuthDead || !ev.Benched {
		t.Errorf("unexpected event: %+v", ev)
	}
	if ev.NextModel != "p/b" {
		t.Errorf("expected NextModel p/b, got %q", ev.NextModel)
	}
	if len(ev.Request.Messages) == 0 {
		t.Error("observer event should carry the full request")
	}
}

// TestFailover_DefaultObserverFiresOnTransparentChain verifies that a chain
// built via NewFailoverModel with NO options still notifies a package-level
// default observer set via SetFailoverObserver. This is the transparent
// comma-Parse path: defaultFailoverConfig() must seed the default observer.
func TestFailover_DefaultObserverFiresOnTransparentChain(t *testing.T) {
	resetHealthForTest()
	t.Cleanup(func() {
		SetFailoverObserver(nil)
		resetHealthForTest()
	})

	var mu sync.Mutex
	var events []FailoverEvent
	SetFailoverObserver(func(ctx context.Context, ev FailoverEvent) {
		mu.Lock()
		events = append(events, ev)
		mu.Unlock()
	})

	a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, &openai.Error{StatusCode: 401}
	})
	b := newMockProvider(provider.Response{Text: "from-b"})

	// NO options: the only way the observer can fire is via the package default.
	fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")})
	_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatal(err)
	}

	mu.Lock()
	defer mu.Unlock()
	if len(events) == 0 {
		t.Fatal("expected default observer to fire on a transparently-built chain")
	}
	if events[0].Model != "p/a" || events[0].Kind != ErrAuthDead {
		t.Errorf("unexpected event: %+v", events[0])
	}
}

// TestFailover_DefaultObserverFiresOnParseChain verifies the comma-Parse seam:
// a chain built through the registry's ParseChain (no per-call observer) fires
// the package-level default observer.
func TestFailover_DefaultObserverFiresOnParseChain(t *testing.T) {
	resetHealthForTest()
	t.Cleanup(func() {
		SetFailoverObserver(nil)
		resetHealthForTest()
	})

	var mu sync.Mutex
	var events []FailoverEvent
	SetFailoverObserver(func(ctx context.Context, ev FailoverEvent) {
		mu.Lock()
		events = append(events, ev)
		mu.Unlock()
	})

	r, alpha, _ := testRegistry(nil)
	// alpha returns a 401 (auth-dead) so the chain fails over and emits an event.
	alpha.err = &openai.Error{StatusCode: 401}

	m, err := r.Parse("alpha/model-a,beta/model-b")
	if err != nil {
		t.Fatalf("parse failed: %v", err)
	}
	if _, ok := m.provider.(*failoverProvider); !ok {
		t.Fatalf("expected a failover provider, got %T", m.provider)
	}

	_, err = m.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatal(err)
	}

	mu.Lock()
	defer mu.Unlock()
	if len(events) == 0 {
		t.Fatal("expected default observer to fire on a comma-Parse'd chain")
	}
}

// TestFailover_ExplicitObserverOverridesDefault verifies WithFailoverObserver
// still wins: when both a package default and an explicit observer are present,
// only the explicit one fires for that chain.
func TestFailover_ExplicitObserverOverridesDefault(t *testing.T) {
	resetHealthForTest()
	t.Cleanup(func() {
		SetFailoverObserver(nil)
		resetHealthForTest()
	})

	var mu sync.Mutex
	var defaultCalls, explicitCalls int
	SetFailoverObserver(func(ctx context.Context, ev FailoverEvent) {
		mu.Lock()
		defaultCalls++
		mu.Unlock()
	})
	explicit := func(ctx context.Context, ev FailoverEvent) {
		mu.Lock()
		explicitCalls++
		mu.Unlock()
	}

	a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, &openai.Error{StatusCode: 401}
	})
	b := newMockProvider(provider.Response{Text: "from-b"})

	fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, WithFailoverObserver(explicit))
	_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatal(err)
	}

	mu.Lock()
	defer mu.Unlock()
	if explicitCalls == 0 {
		t.Error("explicit observer should fire")
	}
	if defaultCalls != 0 {
		t.Errorf("default observer must NOT fire when an explicit one is set; got %d calls", defaultCalls)
	}
}

// TestFailover_ControlAPI walks the manual bench controls end to end:
// IsBenched, BenchModel, ListBenched, and UnbenchModel.
func TestFailover_ControlAPI(t *testing.T) {
	resetHealthForTest()
	if IsBenched("x/y") {
		t.Error("should start unbenched")
	}

	until := time.Now().Add(time.Hour)
	BenchModel("x/y", until)
	if !IsBenched("x/y") {
		t.Error("should be benched")
	}

	list := ListBenched()
	if len(list) != 1 || list[0].Model != "x/y" || !list[0].Manual {
		t.Errorf("unexpected list: %+v", list)
	}

	if !UnbenchModel("x/y") {
		t.Error("UnbenchModel should report it was benched")
	}
	if IsBenched("x/y") {
		t.Error("should be unbenched now")
	}
	if UnbenchModel("x/y") {
		t.Error("UnbenchModel on non-benched should return false")
	}
}

// TestFailover_ExpiredBenchIsLive verifies that a bench whose deadline is in
// the past does not count as benched.
func TestFailover_ExpiredBenchIsLive(t *testing.T) {
	resetHealthForTest()
	// Bench in the past -> should be considered live again.
	BenchModel("p/a", time.Now().Add(-time.Hour))
	if IsBenched("p/a") {
		t.Error("expired bench should not count as benched")
	}
}

// TestNewFailoverModel_Flattens verifies that NewFailoverModel flattens nested
// failover models into a single ordered entry list. ParseChain itself is
// exercised through a registry-backed seam in parse_test.go.
func TestNewFailoverModel_Flattens(t *testing.T) {
	resetHealthForTest()
	a := newMockProvider(provider.Response{Text: "a"})
	b := newMockProvider(provider.Response{Text: "b"})
	c := newMockProvider(provider.Response{Text: "c"})

	inner := NewFailoverModel([]*Model{modelFor(b, "p/b"), modelFor(c, "p/c")}, testFailoverOpts()...)
	outer := NewFailoverModel([]*Model{modelFor(a, "p/a"), inner}, testFailoverOpts()...)

	fp, ok := outer.provider.(*failoverProvider)
	if !ok {
		t.Fatalf("expected *failoverProvider, got %T", outer.provider)
	}
	if len(fp.entries) != 3 {
		// Fatal rather than Error: the indexing below would panic on a
		// shorter slice and hide the real failure.
		t.Fatalf("expected flattened 3 entries, got %d", len(fp.entries))
	}

	want := []string{"p/a", "p/b", "p/c"}
	for i := range want {
		if got := fp.entries[i].specKey; got != want[i] {
			t.Errorf("entry %d specKey = %q, want %q", i, got, want[i])
		}
	}
}

// TestFailover_ManualBenchSurvivesAutomaticDowngrade is a regression test:
// an active manual bench (long window) must NOT be cleared or shortened by the
// automatic recordTransientFailure/benchNow paths. The best-effort all-benched
// failover loop can re-try a manually benched model; if it fails, the automatic
// logic previously overwrote manual=true -> false and shortened until to the
// short auto cooldown. The operator's intent must win.
func TestFailover_ManualBenchSurvivesAutomaticDowngrade(t *testing.T) {
	resetHealthForTest()

	const key = "p/manual"
	longUntil := time.Now().Add(time.Hour)
	shortCooldown := time.Minute
	now := time.Now()

	// Operator manually benches the model for a long window.
	BenchModel(key, longUntil)

	// assertManualPreserved checks that key is still listed as a manual bench
	// with the operator's original deadline.
	assertManualPreserved := func(stage string) {
		t.Helper()
		list := ListBenched()
		var got *BenchedModel
		for i := range list {
			if list[i].Model == key {
				got = &list[i]
				break
			}
		}
		if got == nil {
			t.Fatalf("%s: %q missing from ListBenched; manual bench was cleared", stage, key)
		}
		if !got.Manual {
			t.Errorf("%s: Manual=false, want true (automatic logic downgraded manual bench)", stage)
		}
		if !got.Until.Equal(longUntil) {
			t.Errorf("%s: Until=%v, want %v (automatic logic shortened operator window)", stage, got.Until, longUntil)
		}
	}

	// (c) Active manual bench hit by the automatic transient path.
	benched, until := globalHealth.recordTransientFailure(key, shortCooldown, now)
	if !benched {
		t.Errorf("recordTransientFailure: benched=false, want true (model stays benched)")
	}
	if !until.Equal(longUntil) {
		t.Errorf("recordTransientFailure: until=%v, want %v (long manual window preserved)", until, longUntil)
	}
	assertManualPreserved("after recordTransientFailure")

	// (c) Active manual bench hit by the automatic auth-dead path.
	until = globalHealth.benchNow(key, shortCooldown, now)
	if !until.Equal(longUntil) {
		t.Errorf("benchNow: until=%v, want %v (long manual window preserved)", until, longUntil)
	}
	assertManualPreserved("after benchNow")

	// (d) Expired manual bench: automatic logic IS allowed to take over.
	resetHealthForTest()
	expired := time.Now().Add(-time.Hour)
	BenchModel(key, expired)
	until = globalHealth.benchNow(key, shortCooldown, now)
	if want := now.Add(shortCooldown); !until.Equal(want) {
		t.Errorf("benchNow on expired manual: until=%v, want auto cooldown %v", until, want)
	}

	// (a) No prior state and (b) prior automatic bench: automatic cooldown applies.
	resetHealthForTest()
	benched, until = globalHealth.recordTransientFailure(key, shortCooldown, now) // (a)
	if !benched || !until.Equal(now.Add(shortCooldown)) {
		t.Errorf("recordTransientFailure(no prior): benched=%v until=%v, want true %v", benched, until, now.Add(shortCooldown))
	}
	until = globalHealth.benchNow(key, shortCooldown, now) // (b) prior automatic state
	if want := now.Add(shortCooldown); !until.Equal(want) {
		t.Errorf("benchNow(prior auto): until=%v, want %v", until, want)
	}
	if list := ListBenched(); len(list) != 1 || list[0].Manual {
		t.Errorf("automatic bench should not be Manual: %+v", list)
	}
}