feat(failover): model failover chains via comma-separated specs

Parse("a,b,c") now returns one composite *llm.Model that tries each model
in order, retrying transient failures, benching dead models, and failing
over to the next. Comma-free specs are completely unchanged.

- classify.go: Classify(err) ErrKind + IsTransient(err) error classifier
  mapping anthropic (typed Is*Err helpers + RequestError status),
  openai-go (*openai.Error status), openaicompat.FeatureUnsupportedError,
  context errors, and ollama "HTTP <code>" strings to
  transient/auth-dead/request-specific/unknown.
- failover.go: failoverProvider (satisfies provider.Provider) wrapped into a
  *Model via NewClient. Process-wide mutex-guarded modelHealth bench
  registry keyed by concrete spec, with cooldowns and a control API
  (ListBenched/BenchModel/UnbenchModel/IsBenched). NewFailoverModel +
  ParseChain constructors, FailoverOption config, FailoverObserver (carries
  the full request), and configurable package-level defaults.
- parse.go: comma-aware Parse splits into a failover chain; alias/resolver
  targets that expand to comma chains are routed through the comma-aware
  path and flattened.

All access to global health is mutex-guarded; tests reset it via
resetHealthForTest and pass under go test -race.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-01 00:30:08 +02:00
parent 67c3ebe067
commit ae8e194fad
6 changed files with 1335 additions and 5 deletions
+588
View File
@@ -0,0 +1,588 @@
package llm
import (
"context"
"errors"
"fmt"
"log/slog"
"math/rand"
"strings"
"sync"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// ---------------------------------------------------------------------------
// Package-level defaults (mort configures these at boot via SetFailoverDefaults)
// ---------------------------------------------------------------------------
var (
	// DefaultFailoverMaxRetries is the number of attempts per chain entry on
	// transient errors before benching and moving to the next entry.
	// defaultFailoverConfig copies it into every newly built failover config.
	DefaultFailoverMaxRetries = 3
	// DefaultFailoverCooldown is how long a model stays benched after a
	// qualifying failure.
	DefaultFailoverCooldown = 5 * time.Minute
	// DefaultFailoverBackoff is the default exponential-with-jitter backoff.
	// It is called with the 1-based number of the attempt that just failed and
	// returns how long to wait before the next attempt.
	DefaultFailoverBackoff = defaultBackoff
	// defaultsMu guards the defaults above when they are accessed through
	// SetFailoverDefaults and defaultFailoverConfig. Direct assignments to the
	// exported variables do not take this lock.
	defaultsMu sync.Mutex
)
// defaultBackoff returns an exponential backoff with full jitter.
//
// Why: spreads retries to avoid thundering-herd against a recovering provider.
// What: base 200ms doubling per attempt, capped at 10s, with uniform jitter.
// attempt is 1-based; values below 1 are treated as 1. The delay is grown by
// repeated doubling that stops once the cap is reached, so arbitrarily large
// attempt numbers cannot overflow the duration (a plain left shift turns
// negative from attempt 37 onward, which makes rand.Int63n panic).
// Test: failover retry tests inject a fast backoff; this is the production default.
func defaultBackoff(attempt int) time.Duration {
	const (
		base     = 200 * time.Millisecond
		maxDelay = 10 * time.Second
	)
	d := base
	for i := 1; i < attempt && d < maxDelay; i++ {
		d *= 2
	}
	if d > maxDelay {
		d = maxDelay
	}
	// Full jitter in [0, d].
	return time.Duration(rand.Int63n(int64(d) + 1))
}
// SetFailoverDefaults overrides the package-level failover defaults used when
// no per-model options are supplied (e.g. comma-spec Parse).
//
// Why: mort wants to tune retries/cooldown once at boot without threading
// options through every Parse call.
// What: sets DefaultFailoverMaxRetries and DefaultFailoverCooldown under a lock.
// maxRetries below 1 is raised to 1, matching WithFailoverMaxRetries: the
// retry loop in Complete runs while attempt <= maxRetries, so a smaller value
// would skip every model without ever calling a provider.
// Test: set defaults, build a comma model, assert its cfg reflects them.
func SetFailoverDefaults(maxRetries int, cooldown time.Duration) {
	if maxRetries < 1 {
		maxRetries = 1
	}
	defaultsMu.Lock()
	defer defaultsMu.Unlock()
	DefaultFailoverMaxRetries = maxRetries
	DefaultFailoverCooldown = cooldown
}
// ---------------------------------------------------------------------------
// Global model health (process-wide bench registry)
// ---------------------------------------------------------------------------
// modelHealth tracks which concrete models are temporarily disabled (benched).
//
// Why: bench decisions must persist across requests and across all failover
// chains in the process, so a model that's down isn't retried by every chain.
// What: a mutex-guarded map keyed by specKey to its disabled state.
// Test: failover tests reset it via resetHealthForTest and assert via IsBenched.
type modelHealth struct {
	// mu guards disabled; every method locks it before touching the map.
	mu sync.Mutex
	// disabled maps a spec key to its bench record. Entries are removed on
	// success, on unbench, and whenever an expired record is observed.
	disabled map[string]disabledState
}
// disabledState is the per-model record stored in modelHealth.disabled.
type disabledState struct {
	// until is the instant the bench expires; a record is treated as expired
	// once the current time is strictly after it.
	until time.Time
	// consecutiveFails is incremented by recordTransientFailure and benchNow
	// and starts again from zero once the record has been deleted.
	consecutiveFails int
	// manual is set by benchManual and cleared by the automatic bench paths.
	manual bool
}
// globalHealth is the process-wide singleton shared by every failover chain.
// NewFailoverModel and Registry.ParseChain both hand it to the providers they
// build, and the exported control API (ListBenched, BenchModel, UnbenchModel,
// IsBenched) operates on it.
var globalHealth = &modelHealth{disabled: map[string]disabledState{}}
// benchThreshold is the number of consecutive transient failures (each after
// exhausting retries) required before a model is benched. Auth-dead benches
// immediately regardless.
//
// With the current value of 1, the first call to recordTransientFailure for a
// model already benches it.
const benchThreshold = 1
// resetHealthForTest clears all bench state. Test-only.
func resetHealthForTest() {
globalHealth.mu.Lock()
defer globalHealth.mu.Unlock()
globalHealth.disabled = map[string]disabledState{}
}
// isBenched reports whether key has a bench record that has not yet expired
// at now. A record whose expiry lies strictly before now is deleted as a side
// effect, so expired benches are pruned lazily on lookup.
func (h *modelHealth) isBenched(key string, now time.Time) bool {
	h.mu.Lock()
	defer h.mu.Unlock()

	state, found := h.disabled[key]
	switch {
	case !found:
		return false
	case now.After(state.until):
		// Expired: drop the record so the map does not accumulate stale keys.
		delete(h.disabled, key)
		return false
	default:
		return true
	}
}
// recordSuccess forgets everything recorded for key: the bench (manual or
// automatic) and the consecutive-failure count are both dropped.
func (h *modelHealth) recordSuccess(key string) {
	h.mu.Lock()
	delete(h.disabled, key)
	h.mu.Unlock()
}
// recordTransientFailure bumps key's consecutive-failure count and, once the
// count reaches benchThreshold, benches the model until now+cooldown.
//
// It returns benched=true together with the expiry when the model is benched
// by this call, and (false, zero time) while the count is still below the
// threshold. An automatic bench clears the manual flag.
func (h *modelHealth) recordTransientFailure(key string, cooldown time.Duration, now time.Time) (benched bool, until time.Time) {
	h.mu.Lock()
	defer h.mu.Unlock()

	state := h.disabled[key]
	state.consecutiveFails++

	if state.consecutiveFails < benchThreshold {
		// Not enough failures yet: remember the count only.
		h.disabled[key] = state
		return false, time.Time{}
	}

	state.until = now.Add(cooldown)
	state.manual = false
	h.disabled[key] = state
	return true, state.until
}
// benchNow benches key until now+cooldown without consulting benchThreshold
// (used for auth-dead errors). The failure count is incremented, the manual
// flag is cleared, and the new expiry is returned.
func (h *modelHealth) benchNow(key string, cooldown time.Duration, now time.Time) time.Time {
	expiry := now.Add(cooldown)

	h.mu.Lock()
	defer h.mu.Unlock()

	prev := h.disabled[key]
	h.disabled[key] = disabledState{
		until:            expiry,
		consecutiveFails: prev.consecutiveFails + 1,
		manual:           false,
	}
	return expiry
}
// benchManual benches key until the given instant and flags the record as
// manual. Any failure count already stored for key is left untouched.
func (h *modelHealth) benchManual(key string, until time.Time) {
	h.mu.Lock()
	defer h.mu.Unlock()

	state := h.disabled[key]
	state.until, state.manual = until, true
	h.disabled[key] = state
}
// unbench deletes whatever is recorded for key and reports whether that record
// was an active bench at now (present and not yet expired).
func (h *modelHealth) unbench(key string, now time.Time) bool {
	h.mu.Lock()
	defer h.mu.Unlock()

	state, found := h.disabled[key]
	// The record is removed in every case; only the return value differs.
	delete(h.disabled, key)
	return found && !now.After(state.until)
}
// list returns a snapshot of every model whose bench has not expired at now,
// deleting expired records along the way. The result follows map iteration
// order, so its ordering is unspecified; it is nil when nothing is benched.
func (h *modelHealth) list(now time.Time) []BenchedModel {
	h.mu.Lock()
	defer h.mu.Unlock()

	var snapshot []BenchedModel
	for key, state := range h.disabled {
		if !now.After(state.until) {
			snapshot = append(snapshot, BenchedModel{
				Model:            key,
				Until:            state.until,
				ConsecutiveFails: state.consecutiveFails,
				Manual:           state.manual,
			})
			continue
		}
		// Expired: prune while we are already holding the lock.
		delete(h.disabled, key)
	}
	return snapshot
}
// ---------------------------------------------------------------------------
// Control API (admin commands / UI drive these)
// ---------------------------------------------------------------------------
// BenchedModel is a snapshot of a benched model's state.
type BenchedModel struct {
	// Model is the key the bench is stored under in the health registry.
	Model string
	// Until is the instant the bench expires.
	Until time.Time
	// ConsecutiveFails is the failure count stored with the bench.
	ConsecutiveFails int
	// Manual is true when the bench was set through BenchModel.
	Manual bool
}
// ListBenched returns all currently-benched models across the process.
//
// Why: admin tooling needs to display which models are sidelined and why.
// What: takes a snapshot of the global health map as of the current time;
// expired entries are pruned rather than returned.
// Test: BenchModel then ListBenched returns it with Manual=true.
func ListBenched() []BenchedModel {
	now := time.Now()
	return globalHealth.list(now)
}
// BenchModel manually benches a model until the given time.
//
// Why: operators sometimes need to force a model offline (incident, cost).
// What: stores a bench flagged as manual under spec in the global health
// registry.
// Test: BenchModel then IsBenched returns true and ListBenched shows Manual.
func BenchModel(spec string, until time.Time) {
	globalHealth.benchManual(spec, until)
}
// UnbenchModel clears a model's bench state, returning whether it was benched.
//
// Why: operators need to bring a model back early after manual or auto bench.
// What: removes the global health entry for spec and reports whether it was an
// active (unexpired) bench at the time of the call.
// Test: bench then UnbenchModel returns true; a second call returns false.
func UnbenchModel(spec string) bool {
	wasBenched := globalHealth.unbench(spec, time.Now())
	return wasBenched
}
// IsBenched reports whether a model is currently benched.
//
// Why: callers/tests want a quick health check for a concrete model.
// What: looks spec up in the global health registry as of the current time;
// an expired bench reads as false.
// Test: BenchModel makes it true; an expired bench reads false.
func IsBenched(spec string) bool {
	benched := globalHealth.isBenched(spec, time.Now())
	return benched
}
// ---------------------------------------------------------------------------
// Observer
// ---------------------------------------------------------------------------
// FailoverEvent describes a single failover decision for an observer.
type FailoverEvent struct {
	// Model is the spec key of the chain entry that failed.
	Model string
	// Err is the error returned by that entry's provider.
	Err error
	// Kind is the classification Classify assigned to Err.
	Kind ErrKind
	// Attempt is the 1-based attempt number on which the decision was made.
	Attempt int
	// Benched reports whether the model was benched as part of this decision.
	Benched bool
	// BenchedFor carries the configured cooldown. Complete fills it in for
	// auth-dead and retries-exhausted events (even when Benched is false) and
	// leaves it zero for request-specific errors.
	BenchedFor time.Duration
	// NextModel is the spec key of the next live entry, or "" for the last one.
	NextModel string
	// Request is the request as received by the failover provider, i.e.
	// before the per-entry model name is substituted.
	Request provider.Request
}
// FailoverObserver receives a FailoverEvent for each failover decision. mort
// uses this to persist the full prompt chain on failover.
//
// The observer is invoked synchronously from the request path, so a slow
// observer delays the move to the next model.
type FailoverObserver func(ctx context.Context, ev FailoverEvent)
// ---------------------------------------------------------------------------
// Config + options
// ---------------------------------------------------------------------------
// failoverConfig is the resolved per-model failover policy: package defaults
// overlaid with any FailoverOption values.
type failoverConfig struct {
	// maxRetries is the number of attempts per chain entry in Complete.
	maxRetries int
	// cooldown is how long a model is benched by an automatic bench.
	cooldown time.Duration
	// backoff maps the 1-based attempt that just failed to the wait before
	// the next attempt.
	backoff func(attempt int) time.Duration
	// observer, when non-nil, is notified of failover decisions.
	observer FailoverObserver
}
// defaultFailoverConfig snapshots the package-level defaults into a fresh
// failoverConfig (observer unset).
//
// The defaults are exported variables and can be assigned directly, so they
// are sanitized here rather than trusted: maxRetries below 1 is raised to 1
// (otherwise the retry loop in Complete never calls a provider), and a nil
// DefaultFailoverBackoff falls back to defaultBackoff (otherwise the first
// retry would call a nil function).
func defaultFailoverConfig() failoverConfig {
	defaultsMu.Lock()
	defer defaultsMu.Unlock()

	cfg := failoverConfig{
		maxRetries: DefaultFailoverMaxRetries,
		cooldown:   DefaultFailoverCooldown,
		backoff:    DefaultFailoverBackoff,
	}
	if cfg.maxRetries < 1 {
		cfg.maxRetries = 1
	}
	if cfg.backoff == nil {
		cfg.backoff = defaultBackoff
	}
	return cfg
}
// FailoverOption configures a failover model. Options are applied in order on
// top of the package-level defaults, so a later option overrides an earlier
// one that sets the same field.
type FailoverOption func(*failoverConfig)
// WithFailoverMaxRetries sets attempts per entry on transient errors. Values
// below 1 are raised to 1 so every entry is tried at least once.
func WithFailoverMaxRetries(n int) FailoverOption {
	attempts := n
	if attempts < 1 {
		attempts = 1
	}
	return func(cfg *failoverConfig) {
		cfg.maxRetries = attempts
	}
}
// WithFailoverCooldown sets how long a model stays benched after failure. The
// duration is stored as given, without validation.
func WithFailoverCooldown(d time.Duration) FailoverOption {
	return func(cfg *failoverConfig) {
		cfg.cooldown = d
	}
}
// WithFailoverBackoff sets the retry backoff function. A nil fn is ignored,
// leaving whatever backoff the config already holds.
func WithFailoverBackoff(fn func(attempt int) time.Duration) FailoverOption {
	return func(cfg *failoverConfig) {
		if fn == nil {
			return
		}
		cfg.backoff = fn
	}
}
// WithFailoverObserver sets an observer notified on every failover decision.
// Passing nil removes any observer already set on the config.
func WithFailoverObserver(obs FailoverObserver) FailoverOption {
	return func(cfg *failoverConfig) {
		cfg.observer = obs
	}
}
// ---------------------------------------------------------------------------
// Composite provider
// ---------------------------------------------------------------------------
// failoverEntry is one concrete model in a failover chain.
type failoverEntry struct {
	provider provider.Provider // provider that serves this entry
	model    string            // bare model name sent to the provider
	specKey  string            // global health key (full concrete spec)
}
// failoverProvider is the composite provider behind a failover *Model. Its
// Complete and Stream methods walk entries in order, consulting and updating
// health as they go.
type failoverProvider struct {
	// entries is the flattened chain, in the order models are tried.
	entries []failoverEntry
	// cfg is the resolved retry/cooldown/backoff/observer policy.
	cfg failoverConfig
	// health is the bench registry consulted and updated per request; the
	// constructors in this file always set it to globalHealth.
	health *modelHealth
}
// bareModel drops everything up to and including the first "/" in spec,
// yielding the model name without its leading "provider/" prefix. Only the
// first slash is considered, so any later slashes stay in the result. A spec
// containing no slash is returned as-is.
func bareModel(spec string) string {
	_, name, hasPrefix := strings.Cut(spec, "/")
	if !hasPrefix {
		return spec
	}
	return name
}
// NewFailoverModel builds a composite *Model that tries each sub-model in order,
// retrying/benching per the configured policy and failing over on error.
//
// Why: callers hold *Model and the base Complete handler is hardwired to one
// provider, so failover (which must switch providers) is implemented as a
// composite provider wrapped back into a *Model.
// What: starts from the package defaults, applies opts in order, skips nil
// models, splices in the entries of any sub-model that is itself a failover
// model (its own config is not carried over), and otherwise adds one entry
// keyed by the sub-model's model string. The chain is wrapped with
// NewClient(...).Model("failover").
// Test: failover_test.go exercises success, failover, bench, abort, and flatten.
//
// NOTE(review): the provider-facing name here is bareModel(m.model), while
// Registry.ParseChain passes m.model through unchanged. If m.model is already
// the bare name, a model whose own name contains "/" would be truncated here —
// confirm which form Model.model holds.
func NewFailoverModel(models []*Model, opts ...FailoverOption) *Model {
	cfg := defaultFailoverConfig()
	for _, apply := range opts {
		apply(&cfg)
	}

	var chain []failoverEntry
	for _, sub := range models {
		if sub == nil {
			continue
		}
		switch p := sub.provider.(type) {
		case *failoverProvider:
			// Flatten nested failover models so cooldowns/keys stay flat.
			chain = append(chain, p.entries...)
		default:
			chain = append(chain, failoverEntry{
				provider: sub.provider,
				model:    bareModel(sub.model),
				specKey:  sub.model,
			})
		}
	}

	composite := &failoverProvider{
		entries: chain,
		cfg:     cfg,
		health:  globalHealth,
	}
	return NewClient(composite).Model("failover")
}
// ParseChain parses each spec and combines them into one failover model.
//
// Why: lets callers build a failover chain from a slice of specs (full
// resolution per entry) without manually wiring providers.
// What: delegates to DefaultRegistry.ParseChain, which parses each spec, keeps
// the original spec string as the bench key, flattens nested failover models,
// and returns a composite *Model.
// Test: parse_test.go covers comma-spec parsing through the registry.
func ParseChain(specs []string, opts ...FailoverOption) (*Model, error) {
	model, err := DefaultRegistry.ParseChain(specs, opts...)
	if err != nil {
		return nil, err
	}
	return model, nil
}
// ParseChain is the registry-scoped form of ParseChain.
//
// Each spec is trimmed; blank specs are skipped. The remaining specs are
// resolved with r.Parse, and the first parse failure aborts the whole chain
// with a wrapped error. A spec that resolves to a failover model contributes
// its entries directly; any other spec becomes a single entry whose bench key
// is the trimmed spec string. An error is returned when no entry remains.
func (r *Registry) ParseChain(specs []string, opts ...FailoverOption) (*Model, error) {
	cfg := defaultFailoverConfig()
	for _, apply := range opts {
		apply(&cfg)
	}

	var chain []failoverEntry
	for _, raw := range specs {
		spec := strings.TrimSpace(raw)
		if spec == "" {
			continue
		}

		parsed, err := r.Parse(spec)
		if err != nil {
			return nil, fmt.Errorf("failover chain: parse %q: %w", spec, err)
		}

		nested, isChain := parsed.provider.(*failoverProvider)
		if !isChain {
			chain = append(chain, failoverEntry{
				provider: parsed.provider,
				model:    parsed.model,
				specKey:  spec,
			})
			continue
		}
		// A sub-spec was itself a comma/failover spec — splice its entries.
		chain = append(chain, nested.entries...)
	}

	if len(chain) == 0 {
		return nil, fmt.Errorf("failover chain: no valid specs")
	}

	composite := &failoverProvider{entries: chain, cfg: cfg, health: globalHealth}
	return NewClient(composite).Model("failover"), nil
}
// reqWithModel returns a shallow copy of req with its Model set to model.
//
// Why: each sub-provider must receive its own bare model name; the incoming
// req carries the placeholder "failover" model from the composite *Model.
// What: req arrives by value, so assigning the copy leaves the caller's request
// untouched; slices and pointers inside it are still shared with the original.
// NOTE(review): sharing is only safe if providers treat the request as
// read-only — confirm against the provider implementations.
// Test: TestFailover_PassesModelNameToProvider asserts the provider sees the bare name.
func reqWithModel(req provider.Request, model string) provider.Request {
	out := req
	out.Model = model
	return out
}
// Complete implements provider.Provider with ordered failover.
//
// What: builds the live set (entries not currently benched; if every entry is
// benched the cooldowns are ignored and the full chain is used), then tries
// each live entry in order with its own model name substituted into req:
//   - success clears the entry's health record and returns the response;
//   - a request-specific error moves to the next entry without benching;
//   - an auth-dead error benches the entry immediately and moves on;
//   - any other error is retried with backoff up to maxRetries attempts, after
//     which the failure is recorded (benching per benchThreshold) and the next
//     entry is tried.
//
// The call stops at once, without benching or failing over, when the error is
// a context.Canceled or when ctx itself is done. The ctx check matters for
// caller deadlines: once ctx has expired every remaining provider call fails
// too, and treating those failures as provider faults would bench the whole
// chain for a problem that is the caller's.
//
// Returns the first successful response, the aborting error, or an error
// joining one cause per failed entry. An empty chain yields an explicit error.
func (f *failoverProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	if len(f.entries) == 0 {
		return provider.Response{}, errors.New("failover: chain has no models")
	}

	// Defensive: a zero/negative retry count would skip every entry without
	// calling a provider, and a nil backoff would panic on the first retry.
	maxRetries := f.cfg.maxRetries
	if maxRetries < 1 {
		maxRetries = 1
	}
	backoff := f.cfg.backoff
	if backoff == nil {
		backoff = defaultBackoff
	}

	now := time.Now()
	// 1. Build the live set (not currently benched). Best-effort: if all are
	// benched, ignore cooldowns rather than hard-fail.
	var live []failoverEntry
	for _, e := range f.entries {
		if !f.health.isBenched(e.specKey, now) {
			live = append(live, e)
		}
	}
	if len(live) == 0 {
		live = f.entries
	}

	var causes []error
chain:
	for i, entry := range live {
		nextModel := ""
		if i+1 < len(live) {
			nextModel = live[i+1].specKey
		}
		for attempt := 1; attempt <= maxRetries; attempt++ {
			resp, err := entry.provider.Complete(ctx, reqWithModel(req, entry.model))
			if err == nil {
				f.health.recordSuccess(entry.specKey)
				return resp, nil
			}
			// Caller aborted or ran out of time: stop everything, no failover,
			// no bench.
			if errors.Is(err, context.Canceled) || ctx.Err() != nil {
				return provider.Response{}, err
			}
			kind := Classify(err)
			switch kind {
			case ErrRequestSpecific:
				f.emit(ctx, FailoverEvent{
					Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt,
					NextModel: nextModel, Request: req,
				})
				slog.Warn("failover: request-specific error, trying next model",
					"model", entry.specKey, "kind", "request_specific",
					"status", statusOf(err), "attempt", attempt, "next", nextModel)
				causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
				continue chain
			case ErrAuthDead:
				until := f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now())
				f.emit(ctx, FailoverEvent{
					Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt,
					Benched: true, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req,
				})
				slog.Warn("failover: auth/model-dead error, benching model",
					"model", entry.specKey, "kind", "auth_dead", "status", statusOf(err),
					"attempt", attempt, "benched", true, "cooldown", f.cfg.cooldown,
					"until", until, "next", nextModel)
				causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
				continue chain
			default: // ErrTransient or ErrUnknown -> retry, then bench.
				if attempt >= maxRetries {
					benched, until := f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now())
					f.emit(ctx, FailoverEvent{
						Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt,
						Benched: benched, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req,
					})
					slog.Warn("failover: transient error, retries exhausted",
						"model", entry.specKey, "kind", kindString(kind), "status", statusOf(err),
						"attempt", attempt, "benched", benched, "cooldown", f.cfg.cooldown,
						"until", until, "next", nextModel)
					causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
					continue chain
				}
				// Sleep before retrying (respect ctx).
				select {
				case <-ctx.Done():
					return provider.Response{}, ctx.Err()
				case <-time.After(backoff(attempt)):
				}
			}
		}
	}
	return provider.Response{}, fmt.Errorf("failover: all %d models in chain failed: %w",
		len(live), errors.Join(causes...))
}
// Stream implements provider.Provider with ordered failover and no per-entry
// retries: each live entry gets a single Stream call, and any error it returns
// moves on to the next entry.
//
// What: builds the live set exactly like Complete (benched entries skipped,
// cooldowns ignored if everything is benched). On success the entry's health
// record is cleared. On error the entry is benched immediately for auth-dead
// errors, recorded as a transient failure for transient/unknown errors, and
// left alone for request-specific errors; the observer is notified with
// Attempt 1, mirroring the events Complete emits.
//
// The call stops at once, without benching or failing over, when the error is
// a context.Canceled or when ctx itself is done, so an expired caller deadline
// does not bench the whole chain.
//
// NOTE(review): the intent is to fail over only on the initial Stream error
// (before any event) and never replay a partially-consumed stream. This
// method cannot tell whether a sub-provider already wrote to events before
// returning its error; if a provider can fail mid-stream, the next entry
// would write into the same channel — confirm against the provider contract.
func (f *failoverProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	if len(f.entries) == 0 {
		return errors.New("failover(stream): chain has no models")
	}

	now := time.Now()
	var live []failoverEntry
	for _, e := range f.entries {
		if !f.health.isBenched(e.specKey, now) {
			live = append(live, e)
		}
	}
	if len(live) == 0 {
		live = f.entries
	}

	var causes []error
	for i, entry := range live {
		nextModel := ""
		if i+1 < len(live) {
			nextModel = live[i+1].specKey
		}
		err := entry.provider.Stream(ctx, reqWithModel(req, entry.model), events)
		if err == nil {
			f.health.recordSuccess(entry.specKey)
			return nil
		}
		// Caller aborted or ran out of time: no failover, no bench.
		if errors.Is(err, context.Canceled) || ctx.Err() != nil {
			return err
		}
		kind := Classify(err)
		ev := FailoverEvent{
			Model: entry.specKey, Err: err, Kind: kind, Attempt: 1,
			NextModel: nextModel, Request: req,
		}
		switch kind {
		case ErrAuthDead:
			f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now())
			ev.Benched, ev.BenchedFor = true, f.cfg.cooldown
		case ErrTransient, ErrUnknown:
			ev.Benched, _ = f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now())
			ev.BenchedFor = f.cfg.cooldown
		}
		f.emit(ctx, ev)
		slog.Warn("failover(stream): error, trying next model",
			"model", entry.specKey, "kind", kindString(kind), "status", statusOf(err), "next", nextModel)
		causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
	}
	return fmt.Errorf("failover(stream): all %d models in chain failed: %w", len(live), errors.Join(causes...))
}
// emit forwards ev to the configured observer, if there is one. The observer
// runs synchronously on the calling goroutine.
func (f *failoverProvider) emit(ctx context.Context, ev FailoverEvent) {
	notify := f.cfg.observer
	if notify == nil {
		return
	}
	notify(ctx, ev)
}
// kindString maps an ErrKind to the label used in log output. Any kind other
// than transient, auth-dead, or request-specific is reported as "unknown".
func kindString(k ErrKind) string {
	if k == ErrTransient {
		return "transient"
	}
	if k == ErrAuthDead {
		return "auth_dead"
	}
	if k == ErrRequestSpecific {
		return "request_specific"
	}
	return "unknown"
}
// statusOf best-effort extracts an HTTP status code for logging, or 0. It is a
// thin alias for extractStatus.
func statusOf(err error) int {
	code := extractStatus(err)
	return code
}