feat(failover): model failover chains via comma-separated specs

Parse("a,b,c") now returns one composite *llm.Model that tries each model
in order, retrying transient failures, benching dead models, and failing
over to the next. Comma-free specs are completely unchanged.

- classify.go: Classify(err) ErrKind + IsTransient(err) error classifier
  mapping anthropic (typed Is*Err helpers + RequestError status),
  openai-go (*openai.Error status), openaicompat.FeatureUnsupportedError,
  context errors, and ollama "HTTP <code>" strings to
  transient/auth-dead/request-specific/unknown.
- failover.go: failoverProvider (satisfies provider.Provider) wrapped into a
  *Model via NewClient. Process-wide mutex-guarded modelHealth bench
  registry keyed by concrete spec, with cooldowns and a control API
  (ListBenched/BenchModel/UnbenchModel/IsBenched). NewFailoverModel +
  ParseChain constructors, FailoverOption config, FailoverObserver (carries
  the full request), and configurable package-level defaults.
- parse.go: comma-aware Parse splits into a failover chain; alias/resolver
  targets that expand to comma chains are routed through the comma-aware
  path and flattened.

All access to global health is mutex-guarded; tests reset it via
resetHealthForTest and pass under go test -race.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-01 00:30:08 +02:00
parent 67c3ebe067
commit ae8e194fad
6 changed files with 1335 additions and 5 deletions
+207
View File
@@ -0,0 +1,207 @@
package llm
import (
"context"
"errors"
"strings"
anth "github.com/liushuangls/go-anthropic/v2"
"github.com/openai/openai-go"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
)
// ErrKind classifies a provider error for failover decision-making.
//
// Why: failover must decide, per error, whether to retry the same model
// (transient), bench it as broken (auth/model dead), or fail over without
// benching (this request's fault). Without a classifier every error looks
// the same and we'd either thrash a dead model or bench a healthy one.
// What: an enum of the four outcomes the failover algorithm distinguishes.
// Test: see classify_test.go — every branch is table-tested with faked SDK errors.
type ErrKind int
const (
// ErrUnknown is an unrecognized error. Failover treats it as transient
// (conservative — retry then fail over), EXCEPT context.Canceled which
// the caller special-cases as an abort.
ErrUnknown ErrKind = iota
// ErrTransient is a temporary failure (429/5xx/timeout): retry, then
// bench-and-fail-over if retries are exhausted.
ErrTransient
// ErrAuthDead is an auth failure or model-not-found (401/403/404): the
// model is unusable; bench immediately and fail over.
ErrAuthDead
// ErrRequestSpecific is the caller's fault for THIS request (400/413/422,
// unsupported feature): fail over to try a more capable model, but do NOT
// bench — the model itself is healthy.
ErrRequestSpecific
)
// classifyStatus maps an HTTP status code to an ErrKind.
//
// Why: openai-go and anthropic RequestError both expose a numeric StatusCode;
// centralizing the mapping keeps the per-SDK branches thin and consistent.
// What: 408/409/429/5xx → transient, 401/403/404 → auth-dead, 400/413/422 →
// request-specific, anything else → unknown.
// Test: covered indirectly via Classify table tests for each SDK.
func classifyStatus(code int) ErrKind {
switch code {
case 408, 409, 429, 500, 502, 503, 504:
return ErrTransient
case 401, 403, 404:
return ErrAuthDead
case 400, 413, 422:
return ErrRequestSpecific
default:
return ErrUnknown
}
}
// Classify inspects a provider error and returns its ErrKind.
//
// Why: the failover composite needs typed, status-code-aware classification to
// retry/bench/skip correctly across the anthropic, openai-compat, and ollama
// providers, each of which surfaces errors differently.
// What: prefers anthropic's typed Is*Err helpers, falls back to numeric status
// codes (openai-go, anthropic RequestError), then the openaicompat
// FeatureUnsupportedError, context errors, and finally an ollama HTTP-string
// fallback; unrecognized errors are ErrUnknown.
// Test: classify_test.go faked SDK errors exercise every branch.
func Classify(err error) ErrKind {
if err == nil {
return ErrUnknown
}
// context.Canceled is reported as ErrUnknown here; the failover algorithm
// special-cases it as an abort before consulting the kind.
if errors.Is(err, context.Canceled) {
return ErrUnknown
}
if errors.Is(err, context.DeadlineExceeded) {
return ErrTransient
}
// FeatureUnsupportedError is a permanent, request-shaped failure.
var featErr *openaicompat.FeatureUnsupportedError
if errors.As(err, &featErr) {
return ErrRequestSpecific
}
// Anthropic APIError: prefer the typed helpers (no StatusCode available).
var apiErr *anth.APIError
if errors.As(err, &apiErr) {
switch {
case apiErr.IsRateLimitErr(), apiErr.IsOverloadedErr(), apiErr.IsApiErr():
return ErrTransient
case apiErr.IsAuthenticationErr(), apiErr.IsPermissionErr(), apiErr.IsNotFoundErr():
return ErrAuthDead
case apiErr.IsTooLargeErr(), apiErr.IsInvalidRequestErr():
return ErrRequestSpecific
default:
return ErrUnknown
}
}
// Anthropic RequestError: status-code based.
var anthReqErr *anth.RequestError
if errors.As(err, &anthReqErr) {
return classifyStatus(anthReqErr.StatusCode)
}
// openai-go (openai/deepseek/moonshot/xai/groq): status-code based.
var oaiErr *openai.Error
if errors.As(err, &oaiErr) {
return classifyStatus(oaiErr.StatusCode)
}
// Ollama: no typed status — fall back to its "ollama: HTTP <code>:" string.
if k := classifyOllamaString(err.Error()); k != ErrUnknown {
return k
}
return ErrUnknown
}
// classifyOllamaString extracts an HTTP status from ollama's error string
// format ("ollama: HTTP <code>: ...") and classifies it.
//
// Why: the ollama provider stringifies errors without a typed status code, so
// failover can only classify by parsing the message.
// What: looks for "HTTP <code>" in the message and maps the code; returns
// ErrUnknown when no recognizable status is present.
// Test: classify_test.go ollama cases cover 5xx/429/401/404/400.
func classifyOllamaString(msg string) ErrKind {
const marker = "HTTP "
idx := strings.Index(msg, marker)
if idx < 0 {
return ErrUnknown
}
rest := msg[idx+len(marker):]
// Read up to 3 leading digits.
end := 0
for end < len(rest) && end < 3 && rest[end] >= '0' && rest[end] <= '9' {
end++
}
if end == 0 {
return ErrUnknown
}
code := 0
for i := 0; i < end; i++ {
code = code*10 + int(rest[i]-'0')
}
return classifyStatus(code)
}
// extractStatus best-effort pulls an HTTP status code out of a provider error
// for structured logging. Returns 0 when none is available.
//
// Why: log lines benefit from the numeric status even though classification
// may use typed helpers; this keeps that detail out of the hot path.
// What: checks anthropic RequestError and openai-go Error StatusCode fields,
// then parses ollama's "HTTP <code>" string; returns 0 otherwise.
// Test: covered indirectly via failover log assertions / manual inspection.
func extractStatus(err error) int {
if err == nil {
return 0
}
var anthReqErr *anth.RequestError
if errors.As(err, &anthReqErr) {
return anthReqErr.StatusCode
}
var oaiErr *openai.Error
if errors.As(err, &oaiErr) {
return oaiErr.StatusCode
}
const marker = "HTTP "
msg := err.Error()
if idx := strings.Index(msg, marker); idx >= 0 {
rest := msg[idx+len(marker):]
code, n := 0, 0
for n < len(rest) && n < 3 && rest[n] >= '0' && rest[n] <= '9' {
code = code*10 + int(rest[n]-'0')
n++
}
if n > 0 {
return code
}
}
return 0
}
// IsTransient reports whether an error should be retried/failed-over rather
// than treated as a hard, model-specific failure.
//
// Why: callers (and failover) want a one-call "is this worth retrying?" check
// that is conservative about unknown errors.
// What: returns true for ErrTransient and ErrUnknown (conservative), false for
// ErrAuthDead and ErrRequestSpecific.
// Test: TestIsTransient asserts 503→true, unknown→true, 401→false, 400→false.
func IsTransient(err error) bool {
switch Classify(err) {
case ErrTransient, ErrUnknown:
return true
default:
return false
}
}
+115
View File
@@ -0,0 +1,115 @@
package llm
import (
"context"
"errors"
"fmt"
"testing"
anth "github.com/liushuangls/go-anthropic/v2"
"github.com/openai/openai-go"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
)
func TestClassify(t *testing.T) {
tests := []struct {
name string
err error
want ErrKind
}{
// nil
{"nil", nil, ErrUnknown},
// openai-go status codes (transient)
{"openai 408", &openai.Error{StatusCode: 408}, ErrTransient},
{"openai 409", &openai.Error{StatusCode: 409}, ErrTransient},
{"openai 429", &openai.Error{StatusCode: 429}, ErrTransient},
{"openai 500", &openai.Error{StatusCode: 500}, ErrTransient},
{"openai 502", &openai.Error{StatusCode: 502}, ErrTransient},
{"openai 503", &openai.Error{StatusCode: 503}, ErrTransient},
{"openai 504", &openai.Error{StatusCode: 504}, ErrTransient},
// openai-go status codes (auth dead)
{"openai 401", &openai.Error{StatusCode: 401}, ErrAuthDead},
{"openai 403", &openai.Error{StatusCode: 403}, ErrAuthDead},
{"openai 404", &openai.Error{StatusCode: 404}, ErrAuthDead},
// openai-go status codes (request specific)
{"openai 400", &openai.Error{StatusCode: 400}, ErrRequestSpecific},
{"openai 413", &openai.Error{StatusCode: 413}, ErrRequestSpecific},
{"openai 422", &openai.Error{StatusCode: 422}, ErrRequestSpecific},
// openai unrecognized status -> unknown
{"openai 418", &openai.Error{StatusCode: 418}, ErrUnknown},
// wrapped openai error (providers wrap with %w)
{"wrapped openai 503", fmt.Errorf("openai completion error: %w", &openai.Error{StatusCode: 503}), ErrTransient},
// FeatureUnsupportedError -> request specific
{"feature unsupported", &openaicompat.FeatureUnsupportedError{Feature: "tools", Model: "m"}, ErrRequestSpecific},
{"wrapped feature unsupported", fmt.Errorf("x: %w", &openaicompat.FeatureUnsupportedError{Feature: "vision", Model: "m"}), ErrRequestSpecific},
// anthropic RequestError (status-code based)
{"anth req 503", &anth.RequestError{StatusCode: 503}, ErrTransient},
{"anth req 429", &anth.RequestError{StatusCode: 429}, ErrTransient},
{"anth req 401", &anth.RequestError{StatusCode: 401}, ErrAuthDead},
{"anth req 400", &anth.RequestError{StatusCode: 400}, ErrRequestSpecific},
{"wrapped anth req 502", fmt.Errorf("anthropic completion error: %w", &anth.RequestError{StatusCode: 502}), ErrTransient},
// anthropic APIError (helper based)
{"anth rate limit", &anth.APIError{Type: anth.ErrTypeRateLimit}, ErrTransient},
{"anth overloaded", &anth.APIError{Type: anth.ErrTypeOverloaded}, ErrTransient},
{"anth api", &anth.APIError{Type: anth.ErrTypeApi}, ErrTransient},
{"anth auth", &anth.APIError{Type: anth.ErrTypeAuthentication}, ErrAuthDead},
{"anth permission", &anth.APIError{Type: anth.ErrTypePermission}, ErrAuthDead},
{"anth not found", &anth.APIError{Type: anth.ErrTypeNotFound}, ErrAuthDead},
{"anth too large", &anth.APIError{Type: anth.ErrTypeTooLarge}, ErrRequestSpecific},
{"anth invalid request", &anth.APIError{Type: anth.ErrTypeInvalidRequest}, ErrRequestSpecific},
{"wrapped anth api error", fmt.Errorf("error, status code: 529, message: %w", &anth.APIError{Type: anth.ErrTypeOverloaded}), ErrTransient},
// context errors
{"context canceled", context.Canceled, ErrUnknown},
{"context deadline", context.DeadlineExceeded, ErrTransient},
{"wrapped deadline", fmt.Errorf("call failed: %w", context.DeadlineExceeded), ErrTransient},
// ollama string-based
{"ollama HTTP 503", errors.New("ollama: HTTP 503: service unavailable"), ErrTransient},
{"ollama HTTP 500", errors.New("ollama: HTTP 500: internal"), ErrTransient},
{"ollama HTTP 502", errors.New("ollama: HTTP 502: bad gateway"), ErrTransient},
{"ollama HTTP 504", errors.New("ollama: HTTP 504: timeout"), ErrTransient},
{"ollama HTTP 429", errors.New("ollama: HTTP 429: too many requests"), ErrTransient},
{"ollama HTTP 401", errors.New("ollama: HTTP 401: unauthorized"), ErrAuthDead},
{"ollama HTTP 404", errors.New("ollama: HTTP 404: not found"), ErrAuthDead},
{"ollama HTTP 400", errors.New("ollama: HTTP 400: bad request"), ErrRequestSpecific},
// unknown
{"random error", errors.New("something weird"), ErrUnknown},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Classify(tt.err)
if got != tt.want {
t.Errorf("Classify(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
func TestIsTransient(t *testing.T) {
// IsTransient treats both ErrTransient and ErrUnknown as "should retry"
// (conservative). Auth/request-specific are not transient.
if !IsTransient(&openai.Error{StatusCode: 503}) {
t.Error("503 should be transient")
}
if !IsTransient(errors.New("mystery")) {
t.Error("unknown should be treated as transient (conservative)")
}
if IsTransient(&openai.Error{StatusCode: 401}) {
t.Error("401 (auth dead) should NOT be transient")
}
if IsTransient(&openai.Error{StatusCode: 400}) {
t.Error("400 (request specific) should NOT be transient")
}
}
+588
View File
@@ -0,0 +1,588 @@
package llm
import (
"context"
"errors"
"fmt"
"log/slog"
"math/rand"
"strings"
"sync"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// ---------------------------------------------------------------------------
// Package-level defaults (mort configures these at boot via SetFailoverDefaults)
// ---------------------------------------------------------------------------
var (
// DefaultFailoverMaxRetries is the number of attempts per chain entry on
// transient errors before benching and moving to the next entry.
DefaultFailoverMaxRetries = 3
// DefaultFailoverCooldown is how long a model stays benched after a
// qualifying failure.
DefaultFailoverCooldown = 5 * time.Minute
// DefaultFailoverBackoff is the default exponential-with-jitter backoff.
DefaultFailoverBackoff = defaultBackoff
defaultsMu sync.Mutex
)
// defaultBackoff returns an exponential backoff with full jitter.
//
// Why: spreads retries to avoid thundering-herd against a recovering provider.
// What: base 200ms doubling per attempt, capped at 10s, with uniform jitter.
// Test: failover retry tests inject a fast backoff; this is the production default.
func defaultBackoff(attempt int) time.Duration {
if attempt < 1 {
attempt = 1
}
base := 200 * time.Millisecond
d := base << (attempt - 1)
if d > 10*time.Second {
d = 10 * time.Second
}
// Full jitter in [0, d].
return time.Duration(rand.Int63n(int64(d) + 1))
}
// SetFailoverDefaults overrides the package-level failover defaults used when
// no per-model options are supplied (e.g. comma-spec Parse).
//
// Why: mort wants to tune retries/cooldown once at boot without threading
// options through every Parse call.
// What: sets DefaultFailoverMaxRetries and DefaultFailoverCooldown under a lock.
// Test: set defaults, build a comma model, assert its cfg reflects them.
func SetFailoverDefaults(maxRetries int, cooldown time.Duration) {
defaultsMu.Lock()
defer defaultsMu.Unlock()
DefaultFailoverMaxRetries = maxRetries
DefaultFailoverCooldown = cooldown
}
// ---------------------------------------------------------------------------
// Global model health (process-wide bench registry)
// ---------------------------------------------------------------------------
// modelHealth tracks which concrete models are temporarily disabled (benched).
//
// Why: bench decisions must persist across requests and across all failover
// chains in the process, so a model that's down isn't retried by every chain.
// What: a mutex-guarded map keyed by specKey to its disabled state.
// Test: failover tests reset it via resetHealthForTest and assert via IsBenched.
type modelHealth struct {
mu sync.Mutex
disabled map[string]disabledState
}
type disabledState struct {
until time.Time
consecutiveFails int
manual bool
}
// globalHealth is the process-wide singleton shared by every failover chain.
var globalHealth = &modelHealth{disabled: map[string]disabledState{}}
// benchThreshold is the number of consecutive transient failures (each after
// exhausting retries) required before a model is benched. Auth-dead benches
// immediately regardless.
const benchThreshold = 1
// resetHealthForTest clears all bench state. Test-only.
func resetHealthForTest() {
globalHealth.mu.Lock()
defer globalHealth.mu.Unlock()
globalHealth.disabled = map[string]disabledState{}
}
// isBenched reports whether key is currently benched (and not expired).
func (h *modelHealth) isBenched(key string, now time.Time) bool {
h.mu.Lock()
defer h.mu.Unlock()
st, ok := h.disabled[key]
if !ok {
return false
}
if now.After(st.until) {
delete(h.disabled, key)
return false
}
return true
}
// recordSuccess clears any failure state for key.
func (h *modelHealth) recordSuccess(key string) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.disabled, key)
}
// recordTransientFailure increments the consecutive failure count and benches
// the model once the threshold is reached. Returns whether it is now benched
// and for how long.
func (h *modelHealth) recordTransientFailure(key string, cooldown time.Duration, now time.Time) (benched bool, until time.Time) {
h.mu.Lock()
defer h.mu.Unlock()
st := h.disabled[key]
st.consecutiveFails++
if st.consecutiveFails >= benchThreshold {
st.until = now.Add(cooldown)
st.manual = false
h.disabled[key] = st
return true, st.until
}
h.disabled[key] = st
return false, time.Time{}
}
// benchNow benches a model immediately (used for auth-dead errors).
func (h *modelHealth) benchNow(key string, cooldown time.Duration, now time.Time) time.Time {
h.mu.Lock()
defer h.mu.Unlock()
st := h.disabled[key]
st.consecutiveFails++
st.until = now.Add(cooldown)
st.manual = false
h.disabled[key] = st
return st.until
}
// benchManual benches a model until the given time, marking it manual.
func (h *modelHealth) benchManual(key string, until time.Time) {
h.mu.Lock()
defer h.mu.Unlock()
st := h.disabled[key]
st.until = until
st.manual = true
h.disabled[key] = st
}
// unbench removes a model's bench state, reporting whether it was benched.
func (h *modelHealth) unbench(key string, now time.Time) bool {
h.mu.Lock()
defer h.mu.Unlock()
st, ok := h.disabled[key]
if !ok || now.After(st.until) {
delete(h.disabled, key)
return false
}
delete(h.disabled, key)
return true
}
// list returns a snapshot of all currently-benched (non-expired) models.
func (h *modelHealth) list(now time.Time) []BenchedModel {
h.mu.Lock()
defer h.mu.Unlock()
var out []BenchedModel
for k, st := range h.disabled {
if now.After(st.until) {
delete(h.disabled, k)
continue
}
out = append(out, BenchedModel{
Model: k,
Until: st.until,
ConsecutiveFails: st.consecutiveFails,
Manual: st.manual,
})
}
return out
}
// ---------------------------------------------------------------------------
// Control API (admin commands / UI drive these)
// ---------------------------------------------------------------------------
// BenchedModel is a snapshot of a benched model's state.
type BenchedModel struct {
Model string
Until time.Time
ConsecutiveFails int
Manual bool
}
// ListBenched returns all currently-benched models across the process.
//
// Why: admin tooling needs to display which models are sidelined and why.
// What: snapshots the global health map, pruning expired entries.
// Test: BenchModel then ListBenched returns it with Manual=true.
func ListBenched() []BenchedModel { return globalHealth.list(time.Now()) }
// BenchModel manually benches a model until the given time.
//
// Why: operators sometimes need to force a model offline (incident, cost).
// What: records a manual bench in the global health registry.
// Test: BenchModel then IsBenched returns true and ListBenched shows Manual.
func BenchModel(spec string, until time.Time) { globalHealth.benchManual(spec, until) }
// UnbenchModel clears a model's bench state, returning whether it was benched.
//
// Why: operators need to bring a model back early after manual or auto bench.
// What: deletes the global health entry, reporting prior benched state.
// Test: bench then UnbenchModel returns true; a second call returns false.
func UnbenchModel(spec string) bool { return globalHealth.unbench(spec, time.Now()) }
// IsBenched reports whether a model is currently benched.
//
// Why: callers/tests want a quick health check for a concrete model.
// What: consults the global health registry (expired benches read as false).
// Test: BenchModel makes it true; an expired bench reads false.
func IsBenched(spec string) bool { return globalHealth.isBenched(spec, time.Now()) }
// ---------------------------------------------------------------------------
// Observer
// ---------------------------------------------------------------------------
// FailoverEvent describes a single failover decision for an observer.
type FailoverEvent struct {
Model string
Err error
Kind ErrKind
Attempt int
Benched bool
BenchedFor time.Duration
NextModel string
Request provider.Request
}
// FailoverObserver receives a FailoverEvent for each failover decision. mort
// uses this to persist the full prompt chain on failover.
type FailoverObserver func(ctx context.Context, ev FailoverEvent)
// ---------------------------------------------------------------------------
// Config + options
// ---------------------------------------------------------------------------
type failoverConfig struct {
maxRetries int
cooldown time.Duration
backoff func(attempt int) time.Duration
observer FailoverObserver
}
func defaultFailoverConfig() failoverConfig {
defaultsMu.Lock()
defer defaultsMu.Unlock()
return failoverConfig{
maxRetries: DefaultFailoverMaxRetries,
cooldown: DefaultFailoverCooldown,
backoff: DefaultFailoverBackoff,
}
}
// FailoverOption configures a failover model.
type FailoverOption func(*failoverConfig)
// WithFailoverMaxRetries sets attempts per entry on transient errors.
func WithFailoverMaxRetries(n int) FailoverOption {
return func(c *failoverConfig) {
if n < 1 {
n = 1
}
c.maxRetries = n
}
}
// WithFailoverCooldown sets how long a model stays benched after failure.
func WithFailoverCooldown(d time.Duration) FailoverOption {
return func(c *failoverConfig) { c.cooldown = d }
}
// WithFailoverBackoff sets the retry backoff function.
func WithFailoverBackoff(fn func(attempt int) time.Duration) FailoverOption {
return func(c *failoverConfig) {
if fn != nil {
c.backoff = fn
}
}
}
// WithFailoverObserver sets an observer notified on every failover decision.
func WithFailoverObserver(obs FailoverObserver) FailoverOption {
return func(c *failoverConfig) { c.observer = obs }
}
// ---------------------------------------------------------------------------
// Composite provider
// ---------------------------------------------------------------------------
type failoverEntry struct {
provider provider.Provider
model string // bare model name sent to the provider
specKey string // global health key (full concrete spec)
}
type failoverProvider struct {
entries []failoverEntry
cfg failoverConfig
health *modelHealth
}
// bareModel strips a leading "provider/" prefix, returning the model name the
// underlying provider expects. Specs without a slash are returned unchanged.
func bareModel(spec string) string {
if i := strings.Index(spec, "/"); i >= 0 {
return spec[i+1:]
}
return spec
}
// NewFailoverModel builds a composite *Model that tries each sub-model in order,
// retrying/benching per the configured policy and failing over on error.
//
// Why: callers hold *Model and the base Complete handler is hardwired to one
// provider, so failover (which must switch providers) is implemented as a
// composite provider wrapped back into a *Model.
// What: flattens any nested failover sub-models, derives a specKey per entry
// from its model string, and returns NewClient(fp).Model("failover").
// Test: failover_test.go exercises success, failover, bench, abort, and flatten.
func NewFailoverModel(models []*Model, opts ...FailoverOption) *Model {
cfg := defaultFailoverConfig()
for _, opt := range opts {
opt(&cfg)
}
var entries []failoverEntry
for _, m := range models {
if m == nil {
continue
}
// Flatten nested failover models so cooldowns/keys stay flat.
if fp, ok := m.provider.(*failoverProvider); ok {
entries = append(entries, fp.entries...)
continue
}
entries = append(entries, failoverEntry{
provider: m.provider,
model: bareModel(m.model),
specKey: m.model,
})
}
fp := &failoverProvider{
entries: entries,
cfg: cfg,
health: globalHealth,
}
return NewClient(fp).Model("failover")
}
// ParseChain parses each spec and combines them into one failover model.
//
// Why: lets callers build a failover chain from a slice of specs (full
// resolution per entry) without manually wiring providers.
// What: Parse each spec, preserve the original spec string as the bench key,
// flatten nested failover models, and return a composite *Model.
// Test: parse_test.go covers comma-spec parsing through the registry.
func ParseChain(specs []string, opts ...FailoverOption) (*Model, error) {
return DefaultRegistry.ParseChain(specs, opts...)
}
// ParseChain is the registry-scoped form of ParseChain.
func (r *Registry) ParseChain(specs []string, opts ...FailoverOption) (*Model, error) {
cfg := defaultFailoverConfig()
for _, opt := range opts {
opt(&cfg)
}
var entries []failoverEntry
for _, spec := range specs {
spec = strings.TrimSpace(spec)
if spec == "" {
continue
}
m, err := r.Parse(spec)
if err != nil {
return nil, fmt.Errorf("failover chain: parse %q: %w", spec, err)
}
if fp, ok := m.provider.(*failoverProvider); ok {
// A sub-spec was itself a comma/failover spec — splice its entries.
entries = append(entries, fp.entries...)
continue
}
entries = append(entries, failoverEntry{
provider: m.provider,
model: m.model,
specKey: spec,
})
}
if len(entries) == 0 {
return nil, fmt.Errorf("failover chain: no valid specs")
}
fp := &failoverProvider{entries: entries, cfg: cfg, health: globalHealth}
return NewClient(fp).Model("failover"), nil
}
// reqWithModel returns a shallow copy of req with its Model set to model.
//
// Why: each sub-provider must receive its own bare model name; the incoming
// req carries the placeholder "failover" model from the composite *Model.
// What: copies the struct (slices/pointers are shared, which is safe here since
// providers treat the request as read-only) and overrides Model.
// Test: TestFailover_PassesModelNameToProvider asserts the provider sees the bare name.
func reqWithModel(req provider.Request, model string) provider.Request {
req.Model = model
return req
}
// Complete implements provider.Provider with ordered failover.
func (f *failoverProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
now := time.Now()
// 1. Build the live set (not currently benched). Best-effort: if all are
// benched, ignore cooldowns rather than hard-fail.
var live []failoverEntry
for _, e := range f.entries {
if !f.health.isBenched(e.specKey, now) {
live = append(live, e)
}
}
if len(live) == 0 {
live = f.entries
}
var causes []error
for i, entry := range live {
nextModel := ""
if i+1 < len(live) {
nextModel = live[i+1].specKey
}
for attempt := 1; attempt <= f.cfg.maxRetries; attempt++ {
resp, err := entry.provider.Complete(ctx, reqWithModel(req, entry.model))
if err == nil {
f.health.recordSuccess(entry.specKey)
return resp, nil
}
// Caller aborted: stop everything, no failover, no bench.
if errors.Is(err, context.Canceled) {
return provider.Response{}, err
}
kind := Classify(err)
switch kind {
case ErrRequestSpecific:
f.emit(ctx, FailoverEvent{
Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt,
NextModel: nextModel, Request: req,
})
slog.Warn("failover: request-specific error, trying next model",
"model", entry.specKey, "kind", "request_specific",
"status", statusOf(err), "attempt", attempt, "next", nextModel)
causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
goto nextEntry
case ErrAuthDead:
until := f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now())
f.emit(ctx, FailoverEvent{
Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt,
Benched: true, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req,
})
slog.Warn("failover: auth/model-dead error, benching model",
"model", entry.specKey, "kind", "auth_dead", "status", statusOf(err),
"attempt", attempt, "benched", true, "cooldown", f.cfg.cooldown,
"until", until, "next", nextModel)
causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
goto nextEntry
default: // ErrTransient or ErrUnknown -> retry, then bench.
if attempt >= f.cfg.maxRetries {
benched, until := f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now())
f.emit(ctx, FailoverEvent{
Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt,
Benched: benched, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req,
})
slog.Warn("failover: transient error, retries exhausted",
"model", entry.specKey, "kind", kindString(kind), "status", statusOf(err),
"attempt", attempt, "benched", benched, "cooldown", f.cfg.cooldown,
"until", until, "next", nextModel)
causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
goto nextEntry
}
// Sleep before retrying (respect ctx).
select {
case <-ctx.Done():
return provider.Response{}, ctx.Err()
case <-time.After(f.cfg.backoff(attempt)):
}
}
}
nextEntry:
}
return provider.Response{}, fmt.Errorf("failover: all %d models in chain failed: %w",
len(live), errors.Join(causes...))
}
// Stream implements provider.Provider. It fails over only on the INITIAL Stream
// call error (before any event). Once a stream begins, mid-stream failures are
// surfaced as-is — failover does not replay a partially-consumed stream.
func (f *failoverProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
now := time.Now()
var live []failoverEntry
for _, e := range f.entries {
if !f.health.isBenched(e.specKey, now) {
live = append(live, e)
}
}
if len(live) == 0 {
live = f.entries
}
var causes []error
for i, entry := range live {
nextModel := ""
if i+1 < len(live) {
nextModel = live[i+1].specKey
}
err := entry.provider.Stream(ctx, reqWithModel(req, entry.model), events)
if err == nil {
f.health.recordSuccess(entry.specKey)
return nil
}
if errors.Is(err, context.Canceled) {
return err
}
kind := Classify(err)
switch kind {
case ErrAuthDead:
f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now())
case ErrTransient, ErrUnknown:
f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now())
}
slog.Warn("failover(stream): error, trying next model",
"model", entry.specKey, "kind", kindString(kind), "status", statusOf(err), "next", nextModel)
causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
}
return fmt.Errorf("failover(stream): all %d models in chain failed: %w", len(live), errors.Join(causes...))
}
func (f *failoverProvider) emit(ctx context.Context, ev FailoverEvent) {
if f.cfg.observer != nil {
f.cfg.observer(ctx, ev)
}
}
func kindString(k ErrKind) string {
switch k {
case ErrTransient:
return "transient"
case ErrAuthDead:
return "auth_dead"
case ErrRequestSpecific:
return "request_specific"
default:
return "unknown"
}
}
// statusOf best-effort extracts an HTTP status code for logging, or 0.
func statusOf(err error) int { return extractStatus(err) }
+302
View File
@@ -0,0 +1,302 @@
package llm
import (
"context"
"errors"
"strings"
"sync"
"testing"
"time"
"github.com/openai/openai-go"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// fastBackoff is a near-zero backoff so retry tests don't sleep.
func fastBackoff(int) time.Duration { return time.Microsecond }
func testFailoverOpts(extra ...FailoverOption) []FailoverOption {
base := []FailoverOption{
WithFailoverMaxRetries(2),
WithFailoverBackoff(fastBackoff),
WithFailoverCooldown(time.Minute),
}
return append(base, extra...)
}
// modelFor builds a *Model around a mock provider with a concrete model name,
// mimicking what Parse produces (so specKey resolution works).
func modelFor(p provider.Provider, name string) *Model {
return &Model{provider: p, model: name}
}
func TestFailover_FirstSucceeds(t *testing.T) {
resetHealthForTest()
a := newMockProvider(provider.Response{Text: "from-a"})
b := newMockProvider(provider.Response{Text: "from-b"})
fo := NewFailoverModel([]*Model{modelFor(a, "anthropic/a"), modelFor(b, "openai/b")}, testFailoverOpts()...)
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "from-a" {
t.Errorf("expected from-a, got %q", resp.Text)
}
// b must not have been called.
b.mu.Lock()
n := len(b.Requests)
b.mu.Unlock()
if n != 0 {
t.Errorf("expected b untouched, got %d calls", n)
}
}
func TestFailover_FailsOverToSecond(t *testing.T) {
resetHealthForTest()
// a always returns a request-specific error (400) -> fail over, no retry-bench loop noise.
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, &openai.Error{StatusCode: 400}
})
b := newMockProvider(provider.Response{Text: "from-b"})
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "from-b" {
t.Errorf("expected from-b, got %q", resp.Text)
}
// 400 is request-specific: a must NOT be benched.
if IsBenched("p/a") {
t.Error("p/a should not be benched on a 400")
}
}
func TestFailover_PassesModelNameToProvider(t *testing.T) {
resetHealthForTest()
a := newMockProvider(provider.Response{Text: "ok"})
fo := NewFailoverModel([]*Model{modelFor(a, "anthropic/claude-x")}, testFailoverOpts()...)
_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if err != nil {
t.Fatal(err)
}
if got := a.lastRequest().Model; got != "claude-x" {
t.Errorf("provider received model %q, want bare model name claude-x", got)
}
}
func TestFailover_AuthDeadBenchesImmediately(t *testing.T) {
resetHealthForTest()
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, &openai.Error{StatusCode: 401}
})
b := newMockProvider(provider.Response{Text: "from-b"})
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if err != nil {
t.Fatal(err)
}
if resp.Text != "from-b" {
t.Errorf("expected from-b, got %q", resp.Text)
}
if !IsBenched("p/a") {
t.Error("p/a should be benched after auth-dead error")
}
// a should have been called exactly once (no retries on auth-dead).
a.mu.Lock()
n := len(a.Requests)
a.mu.Unlock()
if n != 1 {
t.Errorf("auth-dead should not retry; a called %d times", n)
}
}
func TestFailover_TransientRetriesThenBenches(t *testing.T) {
resetHealthForTest()
var calls int
var mu sync.Mutex
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
mu.Lock()
calls++
mu.Unlock()
return provider.Response{}, &openai.Error{StatusCode: 503}
})
b := newMockProvider(provider.Response{Text: "from-b"})
// maxRetries=2 means 2 attempts total per entry.
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")},
WithFailoverMaxRetries(2), WithFailoverBackoff(fastBackoff), WithFailoverCooldown(time.Minute))
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if err != nil {
t.Fatal(err)
}
if resp.Text != "from-b" {
t.Errorf("expected from-b, got %q", resp.Text)
}
mu.Lock()
n := calls
mu.Unlock()
if n != 2 {
t.Errorf("expected 2 attempts on transient model, got %d", n)
}
if !IsBenched("p/a") {
t.Error("p/a should be benched after exhausting retries")
}
}
func TestFailover_AllFail(t *testing.T) {
resetHealthForTest()
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, &openai.Error{StatusCode: 400}
})
b := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, &openai.Error{StatusCode: 400}
})
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if err == nil {
t.Fatal("expected error when all models fail")
}
if !strings.Contains(err.Error(), "2") {
t.Errorf("error should mention all 2 models failed: %v", err)
}
}
func TestFailover_ContextCanceledAborts(t *testing.T) {
resetHealthForTest()
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, context.Canceled
})
b := newMockProvider(provider.Response{Text: "from-b"})
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if !errors.Is(err, context.Canceled) {
t.Errorf("expected context.Canceled to abort, got %v", err)
}
// b must not be tried.
b.mu.Lock()
n := len(b.Requests)
b.mu.Unlock()
if n != 0 {
t.Errorf("canceled should not fail over; b called %d times", n)
}
if IsBenched("p/a") {
t.Error("canceled should not bench")
}
}
func TestFailover_AllBenchedBestEffort(t *testing.T) {
resetHealthForTest()
// Manually bench both, then ensure Complete still tries (best-effort) and succeeds.
BenchModel("p/a", time.Now().Add(time.Hour))
BenchModel("p/b", time.Now().Add(time.Hour))
a := newMockProvider(provider.Response{Text: "from-a"})
b := newMockProvider(provider.Response{Text: "from-b"})
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if err != nil {
t.Fatalf("best-effort should still try benched models: %v", err)
}
if resp.Text != "from-a" {
t.Errorf("expected from-a, got %q", resp.Text)
}
}
func TestFailover_Observer(t *testing.T) {
resetHealthForTest()
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, &openai.Error{StatusCode: 401}
})
b := newMockProvider(provider.Response{Text: "from-b"})
var mu sync.Mutex
var events []FailoverEvent
obs := func(ctx context.Context, ev FailoverEvent) {
mu.Lock()
events = append(events, ev)
mu.Unlock()
}
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")},
append(testFailoverOpts(), WithFailoverObserver(obs))...)
_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if err != nil {
t.Fatal(err)
}
mu.Lock()
defer mu.Unlock()
if len(events) == 0 {
t.Fatal("expected observer to be called")
}
ev := events[0]
if ev.Model != "p/a" || ev.Kind != ErrAuthDead || !ev.Benched {
t.Errorf("unexpected event: %+v", ev)
}
if ev.NextModel != "p/b" {
t.Errorf("expected NextModel p/b, got %q", ev.NextModel)
}
if len(ev.Request.Messages) == 0 {
t.Error("observer event should carry the full request")
}
}
func TestFailover_ControlAPI(t *testing.T) {
resetHealthForTest()
if IsBenched("x/y") {
t.Error("should start unbenched")
}
until := time.Now().Add(time.Hour)
BenchModel("x/y", until)
if !IsBenched("x/y") {
t.Error("should be benched")
}
list := ListBenched()
if len(list) != 1 || list[0].Model != "x/y" || !list[0].Manual {
t.Errorf("unexpected list: %+v", list)
}
if !UnbenchModel("x/y") {
t.Error("UnbenchModel should report it was benched")
}
if IsBenched("x/y") {
t.Error("should be unbenched now")
}
if UnbenchModel("x/y") {
t.Error("UnbenchModel on non-benched should return false")
}
}
func TestFailover_ExpiredBenchIsLive(t *testing.T) {
resetHealthForTest()
// Bench in the past -> should be considered live again.
BenchModel("p/a", time.Now().Add(-time.Hour))
if IsBenched("p/a") {
t.Error("expired bench should not count as benched")
}
}
// TestParseChain exercises ParseChain via a registry-backed seam is covered in
// parse_test.go; here we verify NewFailoverModel flattens nested failover models.
func TestNewFailoverModel_Flattens(t *testing.T) {
resetHealthForTest()
a := newMockProvider(provider.Response{Text: "a"})
b := newMockProvider(provider.Response{Text: "b"})
c := newMockProvider(provider.Response{Text: "c"})
inner := NewFailoverModel([]*Model{modelFor(b, "p/b"), modelFor(c, "p/c")}, testFailoverOpts()...)
outer := NewFailoverModel([]*Model{modelFor(a, "p/a"), inner}, testFailoverOpts()...)
fp, ok := outer.provider.(*failoverProvider)
if !ok {
t.Fatalf("expected *failoverProvider, got %T", outer.provider)
}
if len(fp.entries) != 3 {
t.Errorf("expected flattened 3 entries, got %d", len(fp.entries))
}
keys := []string{fp.entries[0].specKey, fp.entries[1].specKey, fp.entries[2].specKey}
want := []string{"p/a", "p/b", "p/c"}
for i := range want {
if keys[i] != want[i] {
t.Errorf("entry %d specKey = %q, want %q", i, keys[i], want[i])
}
}
}
+40
View File
@@ -109,6 +109,25 @@ func Parse(spec string) (*Model, error) {
// What: walks the resolution chain and returns the final *Model with reasoning applied.
// Test: see parse_test.go for comprehensive table-driven tests.
func (r *Registry) Parse(spec string) (*Model, error) {
// Comma-separated specs become an ordered failover chain. A single part
// (after trimming/dropping empties) falls through to normal single-model
// parsing, preserving exact existing behavior for comma-free specs.
if strings.Contains(spec, ",") {
var parts []string
for _, p := range strings.Split(spec, ",") {
if p = strings.TrimSpace(p); p != "" {
parts = append(parts, p)
}
}
if len(parts) == 0 {
return nil, fmt.Errorf("%w: empty failover spec %q", ErrUnknownProvider, spec)
}
if len(parts) > 1 {
return r.ParseChain(parts)
}
spec = parts[0]
}
m, level, err := r.parse(spec, 0)
if err != nil {
return nil, err
@@ -134,6 +153,15 @@ func (r *Registry) parse(spec string, depth int) (*Model, ReasoningLevel, error)
target, isAlias := r.aliases[base]
r.mu.RUnlock()
if isAlias {
// An alias may expand to a comma-separated failover chain; route those
// through the comma-aware public Parse so the chain is built correctly.
if strings.Contains(target, ",") {
m, err := r.Parse(target)
if err != nil {
return nil, "", err
}
return m, userLevel, nil
}
m, aliasLevel, err := r.parse(target, depth+1)
if err != nil {
return nil, "", err
@@ -155,6 +183,18 @@ func (r *Registry) parse(spec string, depth int) (*Model, ReasoningLevel, error)
if resolved == "" {
return nil, "", fmt.Errorf("resolver returned empty spec for %q", base)
}
// A resolver may return a comma-separated failover chain.
if strings.Contains(resolved, ",") {
m, err := r.Parse(resolved)
if err != nil {
return nil, "", err
}
level := defaultLevel
if userLevel != "" {
level = userLevel
}
return m, level, nil
}
m, resolvedLevel, err := r.parse(resolved, depth+1)
if err != nil {
return nil, "", err
+78
View File
@@ -637,3 +637,81 @@ func TestNewRegistryIsolation(t *testing.T) {
t.Error("alias registered in r1 should not appear in r2")
}
}
// ---------------------------------------------------------------------------
// Comma-separated failover chains
// ---------------------------------------------------------------------------
func TestParse_CommaProducesFailover(t *testing.T) {
resetHealthForTest()
r, alpha, beta := testRegistry(func(string) string { return "" })
m, err := r.Parse("alpha/model-a,beta/model-b")
if err != nil {
t.Fatalf("Parse failover spec: %v", err)
}
fp, ok := m.provider.(*failoverProvider)
if !ok {
t.Fatalf("expected *failoverProvider, got %T", m.provider)
}
if len(fp.entries) != 2 {
t.Fatalf("expected 2 entries, got %d", len(fp.entries))
}
if fp.entries[0].specKey != "alpha/model-a" || fp.entries[1].specKey != "beta/model-b" {
t.Errorf("unexpected specKeys: %q, %q", fp.entries[0].specKey, fp.entries[1].specKey)
}
// Complete routes to the first provider and passes the bare model name.
_, err = m.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
if err != nil {
t.Fatal(err)
}
if alpha.lastModel != "model-a" {
t.Errorf("alpha got model %q, want model-a", alpha.lastModel)
}
if beta.lastModel != "" {
t.Errorf("beta should not have been called, got model %q", beta.lastModel)
}
}
func TestParse_NoCommaUnchanged(t *testing.T) {
resetHealthForTest()
r, _, _ := testRegistry(func(string) string { return "" })
m, err := r.Parse("alpha/model-a")
if err != nil {
t.Fatal(err)
}
if _, ok := m.provider.(*failoverProvider); ok {
t.Error("single (comma-free) spec must NOT produce a failover provider")
}
}
func TestParse_CommaSinglePartFallsThrough(t *testing.T) {
resetHealthForTest()
r, _, _ := testRegistry(func(string) string { return "" })
// Trailing comma / whitespace collapses to a single real part.
m, err := r.Parse("alpha/model-a, ")
if err != nil {
t.Fatal(err)
}
if _, ok := m.provider.(*failoverProvider); ok {
t.Error("a single effective part must not produce a failover provider")
}
}
func TestParse_CommaFlattensNested(t *testing.T) {
resetHealthForTest()
r, _, _ := testRegistry(func(string) string { return "" })
// Register an alias that is itself a comma chain.
r.RegisterAlias("pair", "alpha/model-a,beta/model-b")
m, err := r.Parse("pair,beta/model-b2")
if err != nil {
t.Fatal(err)
}
fp, ok := m.provider.(*failoverProvider)
if !ok {
t.Fatalf("expected *failoverProvider, got %T", m.provider)
}
if len(fp.entries) != 3 {
t.Errorf("expected flattened 3 entries, got %d", len(fp.entries))
}
}