diff --git a/v2/failover.go b/v2/failover.go index 9c9fb41..cdffdd0 100644 --- a/v2/failover.go +++ b/v2/failover.go @@ -27,6 +27,12 @@ var ( // DefaultFailoverBackoff is the default exponential-with-jitter backoff. DefaultFailoverBackoff = defaultBackoff + // defaultFailoverObserver is the package-level observer applied to chains + // built without an explicit WithFailoverObserver (e.g. the transparent + // comma-Parse path). Kept unexported and behind defaultsMu so reads/writes + // are race-safe under -race. mort sets this at boot to persist failover events. + defaultFailoverObserver FailoverObserver + defaultsMu sync.Mutex ) @@ -62,6 +68,31 @@ func SetFailoverDefaults(maxRetries int, cooldown time.Duration) { DefaultFailoverCooldown = cooldown } +// SetFailoverObserver sets the package-level default observer notified on +// failover decisions for chains built without an explicit WithFailoverObserver. +// +// Why: the transparent comma-Parse path builds chains via NewFailoverModel with +// no options, so without a package default no observer ever fires; mort sets +// this once at boot to persist failover events from every chain. +// What: stores the observer under defaultsMu; pass nil to disable. +// Test: set an observer, build a no-option chain, assert it fires on failover. +func SetFailoverObserver(obs FailoverObserver) { + defaultsMu.Lock() + defer defaultsMu.Unlock() + defaultFailoverObserver = obs +} + +// DefaultFailoverObserver returns the current package-level default observer. +// +// Why: lets tests assert/restore the default without reaching into the unexported var. +// What: reads defaultFailoverObserver under defaultsMu. +// Test: set via SetFailoverObserver, assert this returns a non-nil func. +func DefaultFailoverObserver() FailoverObserver { + defaultsMu.Lock() + defer defaultsMu.Unlock() + return defaultFailoverObserver +} + // --------------------------------------------------------------------------- // Global model health (process-wide bench registry) // --------------------------------------------------------------------------- @@ -285,6 +316,11 @@ func defaultFailoverConfig() failoverConfig { maxRetries: DefaultFailoverMaxRetries, cooldown: DefaultFailoverCooldown, backoff: DefaultFailoverBackoff, + // Seed the package-level default observer. An explicit + // WithFailoverObserver applied after this in NewFailoverModel/ParseChain + // overrides it for that chain. Read under the same defaultsMu we already + // hold (a single Lock above), so no re-lock / deadlock. + observer: defaultFailoverObserver, } } diff --git a/v2/failover_test.go b/v2/failover_test.go index e039e5e..30879ed 100644 --- a/v2/failover_test.go +++ b/v2/failover_test.go @@ -241,6 +241,134 @@ func TestFailover_Observer(t *testing.T) { } } +// TestFailover_DefaultObserverFiresOnTransparentChain verifies that a chain +// built via NewFailoverModel with NO options still notifies a package-level +// default observer set via SetFailoverObserver. This is the transparent +// comma-Parse path: defaultFailoverConfig() must seed the default observer. +func TestFailover_DefaultObserverFiresOnTransparentChain(t *testing.T) { + resetHealthForTest() + t.Cleanup(func() { + SetFailoverObserver(nil) + resetHealthForTest() + }) + + var mu sync.Mutex + var events []FailoverEvent + SetFailoverObserver(func(ctx context.Context, ev FailoverEvent) { + mu.Lock() + events = append(events, ev) + mu.Unlock() + }) + + a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, &openai.Error{StatusCode: 401} + }) + b := newMockProvider(provider.Response{Text: "from-b"}) + + // NO options: the only way the observer can fire is via the package default. + fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}) + _, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatal(err) + } + + mu.Lock() + defer mu.Unlock() + if len(events) == 0 { + t.Fatal("expected default observer to fire on a transparently-built chain") + } + if events[0].Model != "p/a" || events[0].Kind != ErrAuthDead { + t.Errorf("unexpected event: %+v", events[0]) + } +} + +// TestFailover_DefaultObserverFiresOnParseChain verifies the comma-Parse seam: +// a chain built through the registry's ParseChain (no per-call observer) fires +// the package-level default observer. +func TestFailover_DefaultObserverFiresOnParseChain(t *testing.T) { + resetHealthForTest() + t.Cleanup(func() { + SetFailoverObserver(nil) + resetHealthForTest() + }) + + var mu sync.Mutex + var events []FailoverEvent + SetFailoverObserver(func(ctx context.Context, ev FailoverEvent) { + mu.Lock() + events = append(events, ev) + mu.Unlock() + }) + + r, alpha, _ := testRegistry(nil) + // alpha returns a 401 (auth-dead) so the chain fails over and emits an event. + alpha.err = &openai.Error{StatusCode: 401} + + m, err := r.Parse("alpha/model-a,beta/model-b") + if err != nil { + t.Fatalf("parse failed: %v", err) + } + if _, ok := m.provider.(*failoverProvider); !ok { + t.Fatalf("expected a failover provider, got %T", m.provider) + } + + _, err = m.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatal(err) + } + + mu.Lock() + defer mu.Unlock() + if len(events) == 0 { + t.Fatal("expected default observer to fire on a comma-Parse'd chain") + } +} + +// TestFailover_ExplicitObserverOverridesDefault verifies WithFailoverObserver +// still wins: when both a package default and an explicit observer are present, +// only the explicit one fires for that chain. +func TestFailover_ExplicitObserverOverridesDefault(t *testing.T) { + resetHealthForTest() + t.Cleanup(func() { + SetFailoverObserver(nil) + resetHealthForTest() + }) + + var mu sync.Mutex + var defaultCalls, explicitCalls int + SetFailoverObserver(func(ctx context.Context, ev FailoverEvent) { + mu.Lock() + defaultCalls++ + mu.Unlock() + }) + explicit := func(ctx context.Context, ev FailoverEvent) { + mu.Lock() + explicitCalls++ + mu.Unlock() + } + + a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, &openai.Error{StatusCode: 401} + }) + b := newMockProvider(provider.Response{Text: "from-b"}) + + fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, + WithFailoverObserver(explicit)) + _, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatal(err) + } + + mu.Lock() + defer mu.Unlock() + if explicitCalls == 0 { + t.Error("explicit observer should fire") + } + if defaultCalls != 0 { + t.Errorf("default observer must NOT fire when an explicit one is set; got %d calls", defaultCalls) + } +} + func TestFailover_ControlAPI(t *testing.T) { resetHealthForTest() if IsBenched("x/y") { diff --git a/v2/parse_test.go b/v2/parse_test.go index b03b07d..a0eb450 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -12,10 +12,16 @@ import ( // verify that Parse resolved to the correct model without network calls. type recordingProvider struct { lastModel string + // err, when non-nil, is returned from Complete so failover tests can drive + // a comma-Parse'd chain through a failover decision. Defaults to nil (success). + err error } func (p *recordingProvider) Complete(_ context.Context, req provider.Request) (provider.Response, error) { p.lastModel = req.Model + if p.err != nil { + return provider.Response{}, p.err + } return provider.Response{Text: "ok"}, nil }