package llm import ( "context" "errors" "fmt" "log/slog" "math/rand" "strings" "sync" "time" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" ) // --------------------------------------------------------------------------- // Package-level defaults (mort configures these at boot via SetFailoverDefaults) // --------------------------------------------------------------------------- var ( // DefaultFailoverMaxRetries is the number of attempts per chain entry on // transient errors before benching and moving to the next entry. DefaultFailoverMaxRetries = 3 // DefaultFailoverCooldown is how long a model stays benched after a // qualifying failure. DefaultFailoverCooldown = 5 * time.Minute // DefaultFailoverBackoff is the default exponential-with-jitter backoff. DefaultFailoverBackoff = defaultBackoff defaultsMu sync.Mutex ) // defaultBackoff returns an exponential backoff with full jitter. // // Why: spreads retries to avoid thundering-herd against a recovering provider. // What: base 200ms doubling per attempt, capped at 10s, with uniform jitter. // Test: failover retry tests inject a fast backoff; this is the production default. func defaultBackoff(attempt int) time.Duration { if attempt < 1 { attempt = 1 } base := 200 * time.Millisecond d := base << (attempt - 1) if d > 10*time.Second { d = 10 * time.Second } // Full jitter in [0, d]. return time.Duration(rand.Int63n(int64(d) + 1)) } // SetFailoverDefaults overrides the package-level failover defaults used when // no per-model options are supplied (e.g. comma-spec Parse). // // Why: mort wants to tune retries/cooldown once at boot without threading // options through every Parse call. // What: sets DefaultFailoverMaxRetries and DefaultFailoverCooldown under a lock. // Test: set defaults, build a comma model, assert its cfg reflects them. func SetFailoverDefaults(maxRetries int, cooldown time.Duration) { defaultsMu.Lock() defer defaultsMu.Unlock() DefaultFailoverMaxRetries = maxRetries DefaultFailoverCooldown = cooldown } // --------------------------------------------------------------------------- // Global model health (process-wide bench registry) // --------------------------------------------------------------------------- // modelHealth tracks which concrete models are temporarily disabled (benched). // // Why: bench decisions must persist across requests and across all failover // chains in the process, so a model that's down isn't retried by every chain. // What: a mutex-guarded map keyed by specKey to its disabled state. // Test: failover tests reset it via resetHealthForTest and assert via IsBenched. type modelHealth struct { mu sync.Mutex disabled map[string]disabledState } type disabledState struct { until time.Time consecutiveFails int manual bool } // globalHealth is the process-wide singleton shared by every failover chain. var globalHealth = &modelHealth{disabled: map[string]disabledState{}} // benchThreshold is the number of consecutive transient failures (each after // exhausting retries) required before a model is benched. Auth-dead benches // immediately regardless. const benchThreshold = 1 // resetHealthForTest clears all bench state. Test-only. func resetHealthForTest() { globalHealth.mu.Lock() defer globalHealth.mu.Unlock() globalHealth.disabled = map[string]disabledState{} } // isBenched reports whether key is currently benched (and not expired). func (h *modelHealth) isBenched(key string, now time.Time) bool { h.mu.Lock() defer h.mu.Unlock() st, ok := h.disabled[key] if !ok { return false } if now.After(st.until) { delete(h.disabled, key) return false } return true } // recordSuccess clears any failure state for key. func (h *modelHealth) recordSuccess(key string) { h.mu.Lock() defer h.mu.Unlock() delete(h.disabled, key) } // recordTransientFailure increments the consecutive failure count and benches // the model once the threshold is reached. Returns whether it is now benched // and for how long. func (h *modelHealth) recordTransientFailure(key string, cooldown time.Duration, now time.Time) (benched bool, until time.Time) { h.mu.Lock() defer h.mu.Unlock() st := h.disabled[key] st.consecutiveFails++ if st.consecutiveFails >= benchThreshold { st.until = now.Add(cooldown) st.manual = false h.disabled[key] = st return true, st.until } h.disabled[key] = st return false, time.Time{} } // benchNow benches a model immediately (used for auth-dead errors). func (h *modelHealth) benchNow(key string, cooldown time.Duration, now time.Time) time.Time { h.mu.Lock() defer h.mu.Unlock() st := h.disabled[key] st.consecutiveFails++ st.until = now.Add(cooldown) st.manual = false h.disabled[key] = st return st.until } // benchManual benches a model until the given time, marking it manual. func (h *modelHealth) benchManual(key string, until time.Time) { h.mu.Lock() defer h.mu.Unlock() st := h.disabled[key] st.until = until st.manual = true h.disabled[key] = st } // unbench removes a model's bench state, reporting whether it was benched. func (h *modelHealth) unbench(key string, now time.Time) bool { h.mu.Lock() defer h.mu.Unlock() st, ok := h.disabled[key] if !ok || now.After(st.until) { delete(h.disabled, key) return false } delete(h.disabled, key) return true } // list returns a snapshot of all currently-benched (non-expired) models. func (h *modelHealth) list(now time.Time) []BenchedModel { h.mu.Lock() defer h.mu.Unlock() var out []BenchedModel for k, st := range h.disabled { if now.After(st.until) { delete(h.disabled, k) continue } out = append(out, BenchedModel{ Model: k, Until: st.until, ConsecutiveFails: st.consecutiveFails, Manual: st.manual, }) } return out } // --------------------------------------------------------------------------- // Control API (admin commands / UI drive these) // --------------------------------------------------------------------------- // BenchedModel is a snapshot of a benched model's state. type BenchedModel struct { Model string Until time.Time ConsecutiveFails int Manual bool } // ListBenched returns all currently-benched models across the process. // // Why: admin tooling needs to display which models are sidelined and why. // What: snapshots the global health map, pruning expired entries. // Test: BenchModel then ListBenched returns it with Manual=true. func ListBenched() []BenchedModel { return globalHealth.list(time.Now()) } // BenchModel manually benches a model until the given time. // // Why: operators sometimes need to force a model offline (incident, cost). // What: records a manual bench in the global health registry. // Test: BenchModel then IsBenched returns true and ListBenched shows Manual. func BenchModel(spec string, until time.Time) { globalHealth.benchManual(spec, until) } // UnbenchModel clears a model's bench state, returning whether it was benched. // // Why: operators need to bring a model back early after manual or auto bench. // What: deletes the global health entry, reporting prior benched state. // Test: bench then UnbenchModel returns true; a second call returns false. func UnbenchModel(spec string) bool { return globalHealth.unbench(spec, time.Now()) } // IsBenched reports whether a model is currently benched. // // Why: callers/tests want a quick health check for a concrete model. // What: consults the global health registry (expired benches read as false). // Test: BenchModel makes it true; an expired bench reads false. func IsBenched(spec string) bool { return globalHealth.isBenched(spec, time.Now()) } // --------------------------------------------------------------------------- // Observer // --------------------------------------------------------------------------- // FailoverEvent describes a single failover decision for an observer. type FailoverEvent struct { Model string Err error Kind ErrKind Attempt int Benched bool BenchedFor time.Duration NextModel string Request provider.Request } // FailoverObserver receives a FailoverEvent for each failover decision. mort // uses this to persist the full prompt chain on failover. type FailoverObserver func(ctx context.Context, ev FailoverEvent) // --------------------------------------------------------------------------- // Config + options // --------------------------------------------------------------------------- type failoverConfig struct { maxRetries int cooldown time.Duration backoff func(attempt int) time.Duration observer FailoverObserver } func defaultFailoverConfig() failoverConfig { defaultsMu.Lock() defer defaultsMu.Unlock() return failoverConfig{ maxRetries: DefaultFailoverMaxRetries, cooldown: DefaultFailoverCooldown, backoff: DefaultFailoverBackoff, } } // FailoverOption configures a failover model. type FailoverOption func(*failoverConfig) // WithFailoverMaxRetries sets attempts per entry on transient errors. func WithFailoverMaxRetries(n int) FailoverOption { return func(c *failoverConfig) { if n < 1 { n = 1 } c.maxRetries = n } } // WithFailoverCooldown sets how long a model stays benched after failure. func WithFailoverCooldown(d time.Duration) FailoverOption { return func(c *failoverConfig) { c.cooldown = d } } // WithFailoverBackoff sets the retry backoff function. func WithFailoverBackoff(fn func(attempt int) time.Duration) FailoverOption { return func(c *failoverConfig) { if fn != nil { c.backoff = fn } } } // WithFailoverObserver sets an observer notified on every failover decision. func WithFailoverObserver(obs FailoverObserver) FailoverOption { return func(c *failoverConfig) { c.observer = obs } } // --------------------------------------------------------------------------- // Composite provider // --------------------------------------------------------------------------- type failoverEntry struct { provider provider.Provider model string // bare model name sent to the provider specKey string // global health key (full concrete spec) } type failoverProvider struct { entries []failoverEntry cfg failoverConfig health *modelHealth } // bareModel strips a leading "provider/" prefix, returning the model name the // underlying provider expects. Specs without a slash are returned unchanged. func bareModel(spec string) string { if i := strings.Index(spec, "/"); i >= 0 { return spec[i+1:] } return spec } // NewFailoverModel builds a composite *Model that tries each sub-model in order, // retrying/benching per the configured policy and failing over on error. // // Why: callers hold *Model and the base Complete handler is hardwired to one // provider, so failover (which must switch providers) is implemented as a // composite provider wrapped back into a *Model. // What: flattens any nested failover sub-models, derives a specKey per entry // from its model string, and returns NewClient(fp).Model("failover"). // Test: failover_test.go exercises success, failover, bench, abort, and flatten. func NewFailoverModel(models []*Model, opts ...FailoverOption) *Model { cfg := defaultFailoverConfig() for _, opt := range opts { opt(&cfg) } var entries []failoverEntry for _, m := range models { if m == nil { continue } // Flatten nested failover models so cooldowns/keys stay flat. if fp, ok := m.provider.(*failoverProvider); ok { entries = append(entries, fp.entries...) continue } entries = append(entries, failoverEntry{ provider: m.provider, model: bareModel(m.model), specKey: m.model, }) } fp := &failoverProvider{ entries: entries, cfg: cfg, health: globalHealth, } return NewClient(fp).Model("failover") } // ParseChain parses each spec and combines them into one failover model. // // Why: lets callers build a failover chain from a slice of specs (full // resolution per entry) without manually wiring providers. // What: Parse each spec, preserve the original spec string as the bench key, // flatten nested failover models, and return a composite *Model. // Test: parse_test.go covers comma-spec parsing through the registry. func ParseChain(specs []string, opts ...FailoverOption) (*Model, error) { return DefaultRegistry.ParseChain(specs, opts...) } // ParseChain is the registry-scoped form of ParseChain. func (r *Registry) ParseChain(specs []string, opts ...FailoverOption) (*Model, error) { cfg := defaultFailoverConfig() for _, opt := range opts { opt(&cfg) } var entries []failoverEntry for _, spec := range specs { spec = strings.TrimSpace(spec) if spec == "" { continue } m, err := r.Parse(spec) if err != nil { return nil, fmt.Errorf("failover chain: parse %q: %w", spec, err) } if fp, ok := m.provider.(*failoverProvider); ok { // A sub-spec was itself a comma/failover spec — splice its entries. entries = append(entries, fp.entries...) continue } entries = append(entries, failoverEntry{ provider: m.provider, model: m.model, specKey: spec, }) } if len(entries) == 0 { return nil, fmt.Errorf("failover chain: no valid specs") } fp := &failoverProvider{entries: entries, cfg: cfg, health: globalHealth} return NewClient(fp).Model("failover"), nil } // reqWithModel returns a shallow copy of req with its Model set to model. // // Why: each sub-provider must receive its own bare model name; the incoming // req carries the placeholder "failover" model from the composite *Model. // What: copies the struct (slices/pointers are shared, which is safe here since // providers treat the request as read-only) and overrides Model. // Test: TestFailover_PassesModelNameToProvider asserts the provider sees the bare name. func reqWithModel(req provider.Request, model string) provider.Request { req.Model = model return req } // Complete implements provider.Provider with ordered failover. func (f *failoverProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { now := time.Now() // 1. Build the live set (not currently benched). Best-effort: if all are // benched, ignore cooldowns rather than hard-fail. var live []failoverEntry for _, e := range f.entries { if !f.health.isBenched(e.specKey, now) { live = append(live, e) } } if len(live) == 0 { live = f.entries } var causes []error for i, entry := range live { nextModel := "" if i+1 < len(live) { nextModel = live[i+1].specKey } for attempt := 1; attempt <= f.cfg.maxRetries; attempt++ { resp, err := entry.provider.Complete(ctx, reqWithModel(req, entry.model)) if err == nil { f.health.recordSuccess(entry.specKey) return resp, nil } // Caller aborted: stop everything, no failover, no bench. if errors.Is(err, context.Canceled) { return provider.Response{}, err } kind := Classify(err) switch kind { case ErrRequestSpecific: f.emit(ctx, FailoverEvent{ Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt, NextModel: nextModel, Request: req, }) slog.Warn("failover: request-specific error, trying next model", "model", entry.specKey, "kind", "request_specific", "status", statusOf(err), "attempt", attempt, "next", nextModel) causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err)) goto nextEntry case ErrAuthDead: until := f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now()) f.emit(ctx, FailoverEvent{ Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt, Benched: true, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req, }) slog.Warn("failover: auth/model-dead error, benching model", "model", entry.specKey, "kind", "auth_dead", "status", statusOf(err), "attempt", attempt, "benched", true, "cooldown", f.cfg.cooldown, "until", until, "next", nextModel) causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err)) goto nextEntry default: // ErrTransient or ErrUnknown -> retry, then bench. if attempt >= f.cfg.maxRetries { benched, until := f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now()) f.emit(ctx, FailoverEvent{ Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt, Benched: benched, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req, }) slog.Warn("failover: transient error, retries exhausted", "model", entry.specKey, "kind", kindString(kind), "status", statusOf(err), "attempt", attempt, "benched", benched, "cooldown", f.cfg.cooldown, "until", until, "next", nextModel) causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err)) goto nextEntry } // Sleep before retrying (respect ctx). select { case <-ctx.Done(): return provider.Response{}, ctx.Err() case <-time.After(f.cfg.backoff(attempt)): } } } nextEntry: } return provider.Response{}, fmt.Errorf("failover: all %d models in chain failed: %w", len(live), errors.Join(causes...)) } // Stream implements provider.Provider. It fails over only on the INITIAL Stream // call error (before any event). Once a stream begins, mid-stream failures are // surfaced as-is — failover does not replay a partially-consumed stream. func (f *failoverProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { now := time.Now() var live []failoverEntry for _, e := range f.entries { if !f.health.isBenched(e.specKey, now) { live = append(live, e) } } if len(live) == 0 { live = f.entries } var causes []error for i, entry := range live { nextModel := "" if i+1 < len(live) { nextModel = live[i+1].specKey } err := entry.provider.Stream(ctx, reqWithModel(req, entry.model), events) if err == nil { f.health.recordSuccess(entry.specKey) return nil } if errors.Is(err, context.Canceled) { return err } kind := Classify(err) switch kind { case ErrAuthDead: f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now()) case ErrTransient, ErrUnknown: f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now()) } slog.Warn("failover(stream): error, trying next model", "model", entry.specKey, "kind", kindString(kind), "status", statusOf(err), "next", nextModel) causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err)) } return fmt.Errorf("failover(stream): all %d models in chain failed: %w", len(live), errors.Join(causes...)) } func (f *failoverProvider) emit(ctx context.Context, ev FailoverEvent) { if f.cfg.observer != nil { f.cfg.observer(ctx, ev) } } func kindString(k ErrKind) string { switch k { case ErrTransient: return "transient" case ErrAuthDead: return "auth_dead" case ErrRequestSpecific: return "request_specific" default: return "unknown" } } // statusOf best-effort extracts an HTTP status code for logging, or 0. func statusOf(err error) int { return extractStatus(err) }