From ae8e194fad6e261750dfc9b3a8e76072e0df2849 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Mon, 1 Jun 2026 00:30:08 +0200 Subject: [PATCH] feat(failover): model failover chains via comma-separated specs Parse("a,b,c") now returns one composite *llm.Model that tries each model in order, retrying transient failures, benching dead models, and failing over to the next. Comma-free specs are completely unchanged. - classify.go: Classify(err) ErrKind + IsTransient(err) error classifier mapping anthropic (typed Is*Err helpers + RequestError status), openai-go (*openai.Error status), openaicompat.FeatureUnsupportedError, context errors, and ollama "HTTP " strings to transient/auth-dead/request-specific/unknown. - failover.go: failoverProvider (satisfies provider.Provider) wrapped into a *Model via NewClient. Process-wide mutex-guarded modelHealth bench registry keyed by concrete spec, with cooldowns and a control API (ListBenched/BenchModel/UnbenchModel/IsBenched). NewFailoverModel + ParseChain constructors, FailoverOption config, FailoverObserver (carries the full request), and configurable package-level defaults. - parse.go: comma-aware Parse splits into a failover chain; alias/resolver targets that expand to comma chains are routed through the comma-aware path and flattened. All access to global health is mutex-guarded; tests reset it via resetHealthForTest and pass under go test -race. Co-Authored-By: Claude Opus 4.8 --- v2/classify.go | 207 ++++++++++++++++ v2/classify_test.go | 115 +++++++++ v2/failover.go | 588 ++++++++++++++++++++++++++++++++++++++++++++ v2/failover_test.go | 302 +++++++++++++++++++++++ v2/parse.go | 40 +++ v2/parse_test.go | 88 ++++++- 6 files changed, 1335 insertions(+), 5 deletions(-) create mode 100644 v2/classify.go create mode 100644 v2/classify_test.go create mode 100644 v2/failover.go create mode 100644 v2/failover_test.go diff --git a/v2/classify.go b/v2/classify.go new file mode 100644 index 0000000..684d25a --- /dev/null +++ b/v2/classify.go @@ -0,0 +1,207 @@ +package llm + +import ( + "context" + "errors" + "strings" + + anth "github.com/liushuangls/go-anthropic/v2" + "github.com/openai/openai-go" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +// ErrKind classifies a provider error for failover decision-making. +// +// Why: failover must decide, per error, whether to retry the same model +// (transient), bench it as broken (auth/model dead), or fail over without +// benching (this request's fault). Without a classifier every error looks +// the same and we'd either thrash a dead model or bench a healthy one. +// What: an enum of the four outcomes the failover algorithm distinguishes. +// Test: see classify_test.go — every branch is table-tested with faked SDK errors. +type ErrKind int + +const ( + // ErrUnknown is an unrecognized error. Failover treats it as transient + // (conservative — retry then fail over), EXCEPT context.Canceled which + // the caller special-cases as an abort. + ErrUnknown ErrKind = iota + // ErrTransient is a temporary failure (429/5xx/timeout): retry, then + // bench-and-fail-over if retries are exhausted. + ErrTransient + // ErrAuthDead is an auth failure or model-not-found (401/403/404): the + // model is unusable; bench immediately and fail over. + ErrAuthDead + // ErrRequestSpecific is the caller's fault for THIS request (400/413/422, + // unsupported feature): fail over to try a more capable model, but do NOT + // bench — the model itself is healthy. + ErrRequestSpecific +) + +// classifyStatus maps an HTTP status code to an ErrKind. +// +// Why: openai-go and anthropic RequestError both expose a numeric StatusCode; +// centralizing the mapping keeps the per-SDK branches thin and consistent. +// What: 408/409/429/5xx → transient, 401/403/404 → auth-dead, 400/413/422 → +// request-specific, anything else → unknown. +// Test: covered indirectly via Classify table tests for each SDK. +func classifyStatus(code int) ErrKind { + switch code { + case 408, 409, 429, 500, 502, 503, 504: + return ErrTransient + case 401, 403, 404: + return ErrAuthDead + case 400, 413, 422: + return ErrRequestSpecific + default: + return ErrUnknown + } +} + +// Classify inspects a provider error and returns its ErrKind. +// +// Why: the failover composite needs typed, status-code-aware classification to +// retry/bench/skip correctly across the anthropic, openai-compat, and ollama +// providers, each of which surfaces errors differently. +// What: prefers anthropic's typed Is*Err helpers, falls back to numeric status +// codes (openai-go, anthropic RequestError), then the openaicompat +// FeatureUnsupportedError, context errors, and finally an ollama HTTP-string +// fallback; unrecognized errors are ErrUnknown. +// Test: classify_test.go faked SDK errors exercise every branch. +func Classify(err error) ErrKind { + if err == nil { + return ErrUnknown + } + + // context.Canceled is reported as ErrUnknown here; the failover algorithm + // special-cases it as an abort before consulting the kind. + if errors.Is(err, context.Canceled) { + return ErrUnknown + } + if errors.Is(err, context.DeadlineExceeded) { + return ErrTransient + } + + // FeatureUnsupportedError is a permanent, request-shaped failure. + var featErr *openaicompat.FeatureUnsupportedError + if errors.As(err, &featErr) { + return ErrRequestSpecific + } + + // Anthropic APIError: prefer the typed helpers (no StatusCode available). + var apiErr *anth.APIError + if errors.As(err, &apiErr) { + switch { + case apiErr.IsRateLimitErr(), apiErr.IsOverloadedErr(), apiErr.IsApiErr(): + return ErrTransient + case apiErr.IsAuthenticationErr(), apiErr.IsPermissionErr(), apiErr.IsNotFoundErr(): + return ErrAuthDead + case apiErr.IsTooLargeErr(), apiErr.IsInvalidRequestErr(): + return ErrRequestSpecific + default: + return ErrUnknown + } + } + + // Anthropic RequestError: status-code based. + var anthReqErr *anth.RequestError + if errors.As(err, &anthReqErr) { + return classifyStatus(anthReqErr.StatusCode) + } + + // openai-go (openai/deepseek/moonshot/xai/groq): status-code based. + var oaiErr *openai.Error + if errors.As(err, &oaiErr) { + return classifyStatus(oaiErr.StatusCode) + } + + // Ollama: no typed status — fall back to its "ollama: HTTP :" string. + if k := classifyOllamaString(err.Error()); k != ErrUnknown { + return k + } + + return ErrUnknown +} + +// classifyOllamaString extracts an HTTP status from ollama's error string +// format ("ollama: HTTP : ...") and classifies it. +// +// Why: the ollama provider stringifies errors without a typed status code, so +// failover can only classify by parsing the message. +// What: looks for "HTTP " in the message and maps the code; returns +// ErrUnknown when no recognizable status is present. +// Test: classify_test.go ollama cases cover 5xx/429/401/404/400. +func classifyOllamaString(msg string) ErrKind { + const marker = "HTTP " + idx := strings.Index(msg, marker) + if idx < 0 { + return ErrUnknown + } + rest := msg[idx+len(marker):] + // Read up to 3 leading digits. + end := 0 + for end < len(rest) && end < 3 && rest[end] >= '0' && rest[end] <= '9' { + end++ + } + if end == 0 { + return ErrUnknown + } + code := 0 + for i := 0; i < end; i++ { + code = code*10 + int(rest[i]-'0') + } + return classifyStatus(code) +} + +// extractStatus best-effort pulls an HTTP status code out of a provider error +// for structured logging. Returns 0 when none is available. +// +// Why: log lines benefit from the numeric status even though classification +// may use typed helpers; this keeps that detail out of the hot path. +// What: checks anthropic RequestError and openai-go Error StatusCode fields, +// then parses ollama's "HTTP " string; returns 0 otherwise. +// Test: covered indirectly via failover log assertions / manual inspection. +func extractStatus(err error) int { + if err == nil { + return 0 + } + var anthReqErr *anth.RequestError + if errors.As(err, &anthReqErr) { + return anthReqErr.StatusCode + } + var oaiErr *openai.Error + if errors.As(err, &oaiErr) { + return oaiErr.StatusCode + } + const marker = "HTTP " + msg := err.Error() + if idx := strings.Index(msg, marker); idx >= 0 { + rest := msg[idx+len(marker):] + code, n := 0, 0 + for n < len(rest) && n < 3 && rest[n] >= '0' && rest[n] <= '9' { + code = code*10 + int(rest[n]-'0') + n++ + } + if n > 0 { + return code + } + } + return 0 +} + +// IsTransient reports whether an error should be retried/failed-over rather +// than treated as a hard, model-specific failure. +// +// Why: callers (and failover) want a one-call "is this worth retrying?" check +// that is conservative about unknown errors. +// What: returns true for ErrTransient and ErrUnknown (conservative), false for +// ErrAuthDead and ErrRequestSpecific. +// Test: TestIsTransient asserts 503→true, unknown→true, 401→false, 400→false. +func IsTransient(err error) bool { + switch Classify(err) { + case ErrTransient, ErrUnknown: + return true + default: + return false + } +} diff --git a/v2/classify_test.go b/v2/classify_test.go new file mode 100644 index 0000000..fa9ed79 --- /dev/null +++ b/v2/classify_test.go @@ -0,0 +1,115 @@ +package llm + +import ( + "context" + "errors" + "fmt" + "testing" + + anth "github.com/liushuangls/go-anthropic/v2" + "github.com/openai/openai-go" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat" +) + +func TestClassify(t *testing.T) { + tests := []struct { + name string + err error + want ErrKind + }{ + // nil + {"nil", nil, ErrUnknown}, + + // openai-go status codes (transient) + {"openai 408", &openai.Error{StatusCode: 408}, ErrTransient}, + {"openai 409", &openai.Error{StatusCode: 409}, ErrTransient}, + {"openai 429", &openai.Error{StatusCode: 429}, ErrTransient}, + {"openai 500", &openai.Error{StatusCode: 500}, ErrTransient}, + {"openai 502", &openai.Error{StatusCode: 502}, ErrTransient}, + {"openai 503", &openai.Error{StatusCode: 503}, ErrTransient}, + {"openai 504", &openai.Error{StatusCode: 504}, ErrTransient}, + + // openai-go status codes (auth dead) + {"openai 401", &openai.Error{StatusCode: 401}, ErrAuthDead}, + {"openai 403", &openai.Error{StatusCode: 403}, ErrAuthDead}, + {"openai 404", &openai.Error{StatusCode: 404}, ErrAuthDead}, + + // openai-go status codes (request specific) + {"openai 400", &openai.Error{StatusCode: 400}, ErrRequestSpecific}, + {"openai 413", &openai.Error{StatusCode: 413}, ErrRequestSpecific}, + {"openai 422", &openai.Error{StatusCode: 422}, ErrRequestSpecific}, + + // openai unrecognized status -> unknown + {"openai 418", &openai.Error{StatusCode: 418}, ErrUnknown}, + + // wrapped openai error (providers wrap with %w) + {"wrapped openai 503", fmt.Errorf("openai completion error: %w", &openai.Error{StatusCode: 503}), ErrTransient}, + + // FeatureUnsupportedError -> request specific + {"feature unsupported", &openaicompat.FeatureUnsupportedError{Feature: "tools", Model: "m"}, ErrRequestSpecific}, + {"wrapped feature unsupported", fmt.Errorf("x: %w", &openaicompat.FeatureUnsupportedError{Feature: "vision", Model: "m"}), ErrRequestSpecific}, + + // anthropic RequestError (status-code based) + {"anth req 503", &anth.RequestError{StatusCode: 503}, ErrTransient}, + {"anth req 429", &anth.RequestError{StatusCode: 429}, ErrTransient}, + {"anth req 401", &anth.RequestError{StatusCode: 401}, ErrAuthDead}, + {"anth req 400", &anth.RequestError{StatusCode: 400}, ErrRequestSpecific}, + {"wrapped anth req 502", fmt.Errorf("anthropic completion error: %w", &anth.RequestError{StatusCode: 502}), ErrTransient}, + + // anthropic APIError (helper based) + {"anth rate limit", &anth.APIError{Type: anth.ErrTypeRateLimit}, ErrTransient}, + {"anth overloaded", &anth.APIError{Type: anth.ErrTypeOverloaded}, ErrTransient}, + {"anth api", &anth.APIError{Type: anth.ErrTypeApi}, ErrTransient}, + {"anth auth", &anth.APIError{Type: anth.ErrTypeAuthentication}, ErrAuthDead}, + {"anth permission", &anth.APIError{Type: anth.ErrTypePermission}, ErrAuthDead}, + {"anth not found", &anth.APIError{Type: anth.ErrTypeNotFound}, ErrAuthDead}, + {"anth too large", &anth.APIError{Type: anth.ErrTypeTooLarge}, ErrRequestSpecific}, + {"anth invalid request", &anth.APIError{Type: anth.ErrTypeInvalidRequest}, ErrRequestSpecific}, + {"wrapped anth api error", fmt.Errorf("error, status code: 529, message: %w", &anth.APIError{Type: anth.ErrTypeOverloaded}), ErrTransient}, + + // context errors + {"context canceled", context.Canceled, ErrUnknown}, + {"context deadline", context.DeadlineExceeded, ErrTransient}, + {"wrapped deadline", fmt.Errorf("call failed: %w", context.DeadlineExceeded), ErrTransient}, + + // ollama string-based + {"ollama HTTP 503", errors.New("ollama: HTTP 503: service unavailable"), ErrTransient}, + {"ollama HTTP 500", errors.New("ollama: HTTP 500: internal"), ErrTransient}, + {"ollama HTTP 502", errors.New("ollama: HTTP 502: bad gateway"), ErrTransient}, + {"ollama HTTP 504", errors.New("ollama: HTTP 504: timeout"), ErrTransient}, + {"ollama HTTP 429", errors.New("ollama: HTTP 429: too many requests"), ErrTransient}, + {"ollama HTTP 401", errors.New("ollama: HTTP 401: unauthorized"), ErrAuthDead}, + {"ollama HTTP 404", errors.New("ollama: HTTP 404: not found"), ErrAuthDead}, + {"ollama HTTP 400", errors.New("ollama: HTTP 400: bad request"), ErrRequestSpecific}, + + // unknown + {"random error", errors.New("something weird"), ErrUnknown}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Classify(tt.err) + if got != tt.want { + t.Errorf("Classify(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestIsTransient(t *testing.T) { + // IsTransient treats both ErrTransient and ErrUnknown as "should retry" + // (conservative). Auth/request-specific are not transient. + if !IsTransient(&openai.Error{StatusCode: 503}) { + t.Error("503 should be transient") + } + if !IsTransient(errors.New("mystery")) { + t.Error("unknown should be treated as transient (conservative)") + } + if IsTransient(&openai.Error{StatusCode: 401}) { + t.Error("401 (auth dead) should NOT be transient") + } + if IsTransient(&openai.Error{StatusCode: 400}) { + t.Error("400 (request specific) should NOT be transient") + } +} diff --git a/v2/failover.go b/v2/failover.go new file mode 100644 index 0000000..f3fcef1 --- /dev/null +++ b/v2/failover.go @@ -0,0 +1,588 @@ +package llm + +import ( + "context" + "errors" + "fmt" + "log/slog" + "math/rand" + "strings" + "sync" + "time" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// --------------------------------------------------------------------------- +// Package-level defaults (mort configures these at boot via SetFailoverDefaults) +// --------------------------------------------------------------------------- + +var ( + // DefaultFailoverMaxRetries is the number of attempts per chain entry on + // transient errors before benching and moving to the next entry. + DefaultFailoverMaxRetries = 3 + // DefaultFailoverCooldown is how long a model stays benched after a + // qualifying failure. + DefaultFailoverCooldown = 5 * time.Minute + // DefaultFailoverBackoff is the default exponential-with-jitter backoff. + DefaultFailoverBackoff = defaultBackoff + + defaultsMu sync.Mutex +) + +// defaultBackoff returns an exponential backoff with full jitter. +// +// Why: spreads retries to avoid thundering-herd against a recovering provider. +// What: base 200ms doubling per attempt, capped at 10s, with uniform jitter. +// Test: failover retry tests inject a fast backoff; this is the production default. +func defaultBackoff(attempt int) time.Duration { + if attempt < 1 { + attempt = 1 + } + base := 200 * time.Millisecond + d := base << (attempt - 1) + if d > 10*time.Second { + d = 10 * time.Second + } + // Full jitter in [0, d]. + return time.Duration(rand.Int63n(int64(d) + 1)) +} + +// SetFailoverDefaults overrides the package-level failover defaults used when +// no per-model options are supplied (e.g. comma-spec Parse). +// +// Why: mort wants to tune retries/cooldown once at boot without threading +// options through every Parse call. +// What: sets DefaultFailoverMaxRetries and DefaultFailoverCooldown under a lock. +// Test: set defaults, build a comma model, assert its cfg reflects them. +func SetFailoverDefaults(maxRetries int, cooldown time.Duration) { + defaultsMu.Lock() + defer defaultsMu.Unlock() + DefaultFailoverMaxRetries = maxRetries + DefaultFailoverCooldown = cooldown +} + +// --------------------------------------------------------------------------- +// Global model health (process-wide bench registry) +// --------------------------------------------------------------------------- + +// modelHealth tracks which concrete models are temporarily disabled (benched). +// +// Why: bench decisions must persist across requests and across all failover +// chains in the process, so a model that's down isn't retried by every chain. +// What: a mutex-guarded map keyed by specKey to its disabled state. +// Test: failover tests reset it via resetHealthForTest and assert via IsBenched. +type modelHealth struct { + mu sync.Mutex + disabled map[string]disabledState +} + +type disabledState struct { + until time.Time + consecutiveFails int + manual bool +} + +// globalHealth is the process-wide singleton shared by every failover chain. +var globalHealth = &modelHealth{disabled: map[string]disabledState{}} + +// benchThreshold is the number of consecutive transient failures (each after +// exhausting retries) required before a model is benched. Auth-dead benches +// immediately regardless. +const benchThreshold = 1 + +// resetHealthForTest clears all bench state. Test-only. +func resetHealthForTest() { + globalHealth.mu.Lock() + defer globalHealth.mu.Unlock() + globalHealth.disabled = map[string]disabledState{} +} + +// isBenched reports whether key is currently benched (and not expired). +func (h *modelHealth) isBenched(key string, now time.Time) bool { + h.mu.Lock() + defer h.mu.Unlock() + st, ok := h.disabled[key] + if !ok { + return false + } + if now.After(st.until) { + delete(h.disabled, key) + return false + } + return true +} + +// recordSuccess clears any failure state for key. +func (h *modelHealth) recordSuccess(key string) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.disabled, key) +} + +// recordTransientFailure increments the consecutive failure count and benches +// the model once the threshold is reached. Returns whether it is now benched +// and for how long. +func (h *modelHealth) recordTransientFailure(key string, cooldown time.Duration, now time.Time) (benched bool, until time.Time) { + h.mu.Lock() + defer h.mu.Unlock() + st := h.disabled[key] + st.consecutiveFails++ + if st.consecutiveFails >= benchThreshold { + st.until = now.Add(cooldown) + st.manual = false + h.disabled[key] = st + return true, st.until + } + h.disabled[key] = st + return false, time.Time{} +} + +// benchNow benches a model immediately (used for auth-dead errors). +func (h *modelHealth) benchNow(key string, cooldown time.Duration, now time.Time) time.Time { + h.mu.Lock() + defer h.mu.Unlock() + st := h.disabled[key] + st.consecutiveFails++ + st.until = now.Add(cooldown) + st.manual = false + h.disabled[key] = st + return st.until +} + +// benchManual benches a model until the given time, marking it manual. +func (h *modelHealth) benchManual(key string, until time.Time) { + h.mu.Lock() + defer h.mu.Unlock() + st := h.disabled[key] + st.until = until + st.manual = true + h.disabled[key] = st +} + +// unbench removes a model's bench state, reporting whether it was benched. +func (h *modelHealth) unbench(key string, now time.Time) bool { + h.mu.Lock() + defer h.mu.Unlock() + st, ok := h.disabled[key] + if !ok || now.After(st.until) { + delete(h.disabled, key) + return false + } + delete(h.disabled, key) + return true +} + +// list returns a snapshot of all currently-benched (non-expired) models. +func (h *modelHealth) list(now time.Time) []BenchedModel { + h.mu.Lock() + defer h.mu.Unlock() + var out []BenchedModel + for k, st := range h.disabled { + if now.After(st.until) { + delete(h.disabled, k) + continue + } + out = append(out, BenchedModel{ + Model: k, + Until: st.until, + ConsecutiveFails: st.consecutiveFails, + Manual: st.manual, + }) + } + return out +} + +// --------------------------------------------------------------------------- +// Control API (admin commands / UI drive these) +// --------------------------------------------------------------------------- + +// BenchedModel is a snapshot of a benched model's state. +type BenchedModel struct { + Model string + Until time.Time + ConsecutiveFails int + Manual bool +} + +// ListBenched returns all currently-benched models across the process. +// +// Why: admin tooling needs to display which models are sidelined and why. +// What: snapshots the global health map, pruning expired entries. +// Test: BenchModel then ListBenched returns it with Manual=true. +func ListBenched() []BenchedModel { return globalHealth.list(time.Now()) } + +// BenchModel manually benches a model until the given time. +// +// Why: operators sometimes need to force a model offline (incident, cost). +// What: records a manual bench in the global health registry. +// Test: BenchModel then IsBenched returns true and ListBenched shows Manual. +func BenchModel(spec string, until time.Time) { globalHealth.benchManual(spec, until) } + +// UnbenchModel clears a model's bench state, returning whether it was benched. +// +// Why: operators need to bring a model back early after manual or auto bench. +// What: deletes the global health entry, reporting prior benched state. +// Test: bench then UnbenchModel returns true; a second call returns false. +func UnbenchModel(spec string) bool { return globalHealth.unbench(spec, time.Now()) } + +// IsBenched reports whether a model is currently benched. +// +// Why: callers/tests want a quick health check for a concrete model. +// What: consults the global health registry (expired benches read as false). +// Test: BenchModel makes it true; an expired bench reads false. +func IsBenched(spec string) bool { return globalHealth.isBenched(spec, time.Now()) } + +// --------------------------------------------------------------------------- +// Observer +// --------------------------------------------------------------------------- + +// FailoverEvent describes a single failover decision for an observer. +type FailoverEvent struct { + Model string + Err error + Kind ErrKind + Attempt int + Benched bool + BenchedFor time.Duration + NextModel string + Request provider.Request +} + +// FailoverObserver receives a FailoverEvent for each failover decision. mort +// uses this to persist the full prompt chain on failover. +type FailoverObserver func(ctx context.Context, ev FailoverEvent) + +// --------------------------------------------------------------------------- +// Config + options +// --------------------------------------------------------------------------- + +type failoverConfig struct { + maxRetries int + cooldown time.Duration + backoff func(attempt int) time.Duration + observer FailoverObserver +} + +func defaultFailoverConfig() failoverConfig { + defaultsMu.Lock() + defer defaultsMu.Unlock() + return failoverConfig{ + maxRetries: DefaultFailoverMaxRetries, + cooldown: DefaultFailoverCooldown, + backoff: DefaultFailoverBackoff, + } +} + +// FailoverOption configures a failover model. +type FailoverOption func(*failoverConfig) + +// WithFailoverMaxRetries sets attempts per entry on transient errors. +func WithFailoverMaxRetries(n int) FailoverOption { + return func(c *failoverConfig) { + if n < 1 { + n = 1 + } + c.maxRetries = n + } +} + +// WithFailoverCooldown sets how long a model stays benched after failure. +func WithFailoverCooldown(d time.Duration) FailoverOption { + return func(c *failoverConfig) { c.cooldown = d } +} + +// WithFailoverBackoff sets the retry backoff function. +func WithFailoverBackoff(fn func(attempt int) time.Duration) FailoverOption { + return func(c *failoverConfig) { + if fn != nil { + c.backoff = fn + } + } +} + +// WithFailoverObserver sets an observer notified on every failover decision. +func WithFailoverObserver(obs FailoverObserver) FailoverOption { + return func(c *failoverConfig) { c.observer = obs } +} + +// --------------------------------------------------------------------------- +// Composite provider +// --------------------------------------------------------------------------- + +type failoverEntry struct { + provider provider.Provider + model string // bare model name sent to the provider + specKey string // global health key (full concrete spec) +} + +type failoverProvider struct { + entries []failoverEntry + cfg failoverConfig + health *modelHealth +} + +// bareModel strips a leading "provider/" prefix, returning the model name the +// underlying provider expects. Specs without a slash are returned unchanged. +func bareModel(spec string) string { + if i := strings.Index(spec, "/"); i >= 0 { + return spec[i+1:] + } + return spec +} + +// NewFailoverModel builds a composite *Model that tries each sub-model in order, +// retrying/benching per the configured policy and failing over on error. +// +// Why: callers hold *Model and the base Complete handler is hardwired to one +// provider, so failover (which must switch providers) is implemented as a +// composite provider wrapped back into a *Model. +// What: flattens any nested failover sub-models, derives a specKey per entry +// from its model string, and returns NewClient(fp).Model("failover"). +// Test: failover_test.go exercises success, failover, bench, abort, and flatten. +func NewFailoverModel(models []*Model, opts ...FailoverOption) *Model { + cfg := defaultFailoverConfig() + for _, opt := range opts { + opt(&cfg) + } + + var entries []failoverEntry + for _, m := range models { + if m == nil { + continue + } + // Flatten nested failover models so cooldowns/keys stay flat. + if fp, ok := m.provider.(*failoverProvider); ok { + entries = append(entries, fp.entries...) + continue + } + entries = append(entries, failoverEntry{ + provider: m.provider, + model: bareModel(m.model), + specKey: m.model, + }) + } + + fp := &failoverProvider{ + entries: entries, + cfg: cfg, + health: globalHealth, + } + return NewClient(fp).Model("failover") +} + +// ParseChain parses each spec and combines them into one failover model. +// +// Why: lets callers build a failover chain from a slice of specs (full +// resolution per entry) without manually wiring providers. +// What: Parse each spec, preserve the original spec string as the bench key, +// flatten nested failover models, and return a composite *Model. +// Test: parse_test.go covers comma-spec parsing through the registry. +func ParseChain(specs []string, opts ...FailoverOption) (*Model, error) { + return DefaultRegistry.ParseChain(specs, opts...) +} + +// ParseChain is the registry-scoped form of ParseChain. +func (r *Registry) ParseChain(specs []string, opts ...FailoverOption) (*Model, error) { + cfg := defaultFailoverConfig() + for _, opt := range opts { + opt(&cfg) + } + + var entries []failoverEntry + for _, spec := range specs { + spec = strings.TrimSpace(spec) + if spec == "" { + continue + } + m, err := r.Parse(spec) + if err != nil { + return nil, fmt.Errorf("failover chain: parse %q: %w", spec, err) + } + if fp, ok := m.provider.(*failoverProvider); ok { + // A sub-spec was itself a comma/failover spec — splice its entries. + entries = append(entries, fp.entries...) + continue + } + entries = append(entries, failoverEntry{ + provider: m.provider, + model: m.model, + specKey: spec, + }) + } + if len(entries) == 0 { + return nil, fmt.Errorf("failover chain: no valid specs") + } + + fp := &failoverProvider{entries: entries, cfg: cfg, health: globalHealth} + return NewClient(fp).Model("failover"), nil +} + +// reqWithModel returns a shallow copy of req with its Model set to model. +// +// Why: each sub-provider must receive its own bare model name; the incoming +// req carries the placeholder "failover" model from the composite *Model. +// What: copies the struct (slices/pointers are shared, which is safe here since +// providers treat the request as read-only) and overrides Model. +// Test: TestFailover_PassesModelNameToProvider asserts the provider sees the bare name. +func reqWithModel(req provider.Request, model string) provider.Request { + req.Model = model + return req +} + +// Complete implements provider.Provider with ordered failover. +func (f *failoverProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { + now := time.Now() + + // 1. Build the live set (not currently benched). Best-effort: if all are + // benched, ignore cooldowns rather than hard-fail. + var live []failoverEntry + for _, e := range f.entries { + if !f.health.isBenched(e.specKey, now) { + live = append(live, e) + } + } + if len(live) == 0 { + live = f.entries + } + + var causes []error + + for i, entry := range live { + nextModel := "" + if i+1 < len(live) { + nextModel = live[i+1].specKey + } + + for attempt := 1; attempt <= f.cfg.maxRetries; attempt++ { + resp, err := entry.provider.Complete(ctx, reqWithModel(req, entry.model)) + if err == nil { + f.health.recordSuccess(entry.specKey) + return resp, nil + } + + // Caller aborted: stop everything, no failover, no bench. + if errors.Is(err, context.Canceled) { + return provider.Response{}, err + } + + kind := Classify(err) + + switch kind { + case ErrRequestSpecific: + f.emit(ctx, FailoverEvent{ + Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt, + NextModel: nextModel, Request: req, + }) + slog.Warn("failover: request-specific error, trying next model", + "model", entry.specKey, "kind", "request_specific", + "status", statusOf(err), "attempt", attempt, "next", nextModel) + causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err)) + goto nextEntry + + case ErrAuthDead: + until := f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now()) + f.emit(ctx, FailoverEvent{ + Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt, + Benched: true, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req, + }) + slog.Warn("failover: auth/model-dead error, benching model", + "model", entry.specKey, "kind", "auth_dead", "status", statusOf(err), + "attempt", attempt, "benched", true, "cooldown", f.cfg.cooldown, + "until", until, "next", nextModel) + causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err)) + goto nextEntry + + default: // ErrTransient or ErrUnknown -> retry, then bench. + if attempt >= f.cfg.maxRetries { + benched, until := f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now()) + f.emit(ctx, FailoverEvent{ + Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt, + Benched: benched, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req, + }) + slog.Warn("failover: transient error, retries exhausted", + "model", entry.specKey, "kind", kindString(kind), "status", statusOf(err), + "attempt", attempt, "benched", benched, "cooldown", f.cfg.cooldown, + "until", until, "next", nextModel) + causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err)) + goto nextEntry + } + // Sleep before retrying (respect ctx). + select { + case <-ctx.Done(): + return provider.Response{}, ctx.Err() + case <-time.After(f.cfg.backoff(attempt)): + } + } + } + nextEntry: + } + + return provider.Response{}, fmt.Errorf("failover: all %d models in chain failed: %w", + len(live), errors.Join(causes...)) +} + +// Stream implements provider.Provider. It fails over only on the INITIAL Stream +// call error (before any event). Once a stream begins, mid-stream failures are +// surfaced as-is — failover does not replay a partially-consumed stream. +func (f *failoverProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { + now := time.Now() + var live []failoverEntry + for _, e := range f.entries { + if !f.health.isBenched(e.specKey, now) { + live = append(live, e) + } + } + if len(live) == 0 { + live = f.entries + } + + var causes []error + for i, entry := range live { + nextModel := "" + if i+1 < len(live) { + nextModel = live[i+1].specKey + } + err := entry.provider.Stream(ctx, reqWithModel(req, entry.model), events) + if err == nil { + f.health.recordSuccess(entry.specKey) + return nil + } + if errors.Is(err, context.Canceled) { + return err + } + kind := Classify(err) + switch kind { + case ErrAuthDead: + f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now()) + case ErrTransient, ErrUnknown: + f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now()) + } + slog.Warn("failover(stream): error, trying next model", + "model", entry.specKey, "kind", kindString(kind), "status", statusOf(err), "next", nextModel) + causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err)) + } + return fmt.Errorf("failover(stream): all %d models in chain failed: %w", len(live), errors.Join(causes...)) +} + +func (f *failoverProvider) emit(ctx context.Context, ev FailoverEvent) { + if f.cfg.observer != nil { + f.cfg.observer(ctx, ev) + } +} + +func kindString(k ErrKind) string { + switch k { + case ErrTransient: + return "transient" + case ErrAuthDead: + return "auth_dead" + case ErrRequestSpecific: + return "request_specific" + default: + return "unknown" + } +} + +// statusOf best-effort extracts an HTTP status code for logging, or 0. +func statusOf(err error) int { return extractStatus(err) } diff --git a/v2/failover_test.go b/v2/failover_test.go new file mode 100644 index 0000000..7c644cd --- /dev/null +++ b/v2/failover_test.go @@ -0,0 +1,302 @@ +package llm + +import ( + "context" + "errors" + "strings" + "sync" + "testing" + "time" + + "github.com/openai/openai-go" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// fastBackoff is a near-zero backoff so retry tests don't sleep. +func fastBackoff(int) time.Duration { return time.Microsecond } + +func testFailoverOpts(extra ...FailoverOption) []FailoverOption { + base := []FailoverOption{ + WithFailoverMaxRetries(2), + WithFailoverBackoff(fastBackoff), + WithFailoverCooldown(time.Minute), + } + return append(base, extra...) +} + +// modelFor builds a *Model around a mock provider with a concrete model name, +// mimicking what Parse produces (so specKey resolution works). +func modelFor(p provider.Provider, name string) *Model { + return &Model{provider: p, model: name} +} + +func TestFailover_FirstSucceeds(t *testing.T) { + resetHealthForTest() + a := newMockProvider(provider.Response{Text: "from-a"}) + b := newMockProvider(provider.Response{Text: "from-b"}) + + fo := NewFailoverModel([]*Model{modelFor(a, "anthropic/a"), modelFor(b, "openai/b")}, testFailoverOpts()...) + resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "from-a" { + t.Errorf("expected from-a, got %q", resp.Text) + } + // b must not have been called. + b.mu.Lock() + n := len(b.Requests) + b.mu.Unlock() + if n != 0 { + t.Errorf("expected b untouched, got %d calls", n) + } +} + +func TestFailover_FailsOverToSecond(t *testing.T) { + resetHealthForTest() + // a always returns a request-specific error (400) -> fail over, no retry-bench loop noise. + a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, &openai.Error{StatusCode: 400} + }) + b := newMockProvider(provider.Response{Text: "from-b"}) + + fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...) + resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "from-b" { + t.Errorf("expected from-b, got %q", resp.Text) + } + // 400 is request-specific: a must NOT be benched. + if IsBenched("p/a") { + t.Error("p/a should not be benched on a 400") + } +} + +func TestFailover_PassesModelNameToProvider(t *testing.T) { + resetHealthForTest() + a := newMockProvider(provider.Response{Text: "ok"}) + fo := NewFailoverModel([]*Model{modelFor(a, "anthropic/claude-x")}, testFailoverOpts()...) + _, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatal(err) + } + if got := a.lastRequest().Model; got != "claude-x" { + t.Errorf("provider received model %q, want bare model name claude-x", got) + } +} + +func TestFailover_AuthDeadBenchesImmediately(t *testing.T) { + resetHealthForTest() + a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, &openai.Error{StatusCode: 401} + }) + b := newMockProvider(provider.Response{Text: "from-b"}) + fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...) + resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatal(err) + } + if resp.Text != "from-b" { + t.Errorf("expected from-b, got %q", resp.Text) + } + if !IsBenched("p/a") { + t.Error("p/a should be benched after auth-dead error") + } + // a should have been called exactly once (no retries on auth-dead). + a.mu.Lock() + n := len(a.Requests) + a.mu.Unlock() + if n != 1 { + t.Errorf("auth-dead should not retry; a called %d times", n) + } +} + +func TestFailover_TransientRetriesThenBenches(t *testing.T) { + resetHealthForTest() + var calls int + var mu sync.Mutex + a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + mu.Lock() + calls++ + mu.Unlock() + return provider.Response{}, &openai.Error{StatusCode: 503} + }) + b := newMockProvider(provider.Response{Text: "from-b"}) + // maxRetries=2 means 2 attempts total per entry. + fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, + WithFailoverMaxRetries(2), WithFailoverBackoff(fastBackoff), WithFailoverCooldown(time.Minute)) + resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatal(err) + } + if resp.Text != "from-b" { + t.Errorf("expected from-b, got %q", resp.Text) + } + mu.Lock() + n := calls + mu.Unlock() + if n != 2 { + t.Errorf("expected 2 attempts on transient model, got %d", n) + } + if !IsBenched("p/a") { + t.Error("p/a should be benched after exhausting retries") + } +} + +func TestFailover_AllFail(t *testing.T) { + resetHealthForTest() + a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, &openai.Error{StatusCode: 400} + }) + b := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, &openai.Error{StatusCode: 400} + }) + fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...) + _, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err == nil { + t.Fatal("expected error when all models fail") + } + if !strings.Contains(err.Error(), "2") { + t.Errorf("error should mention all 2 models failed: %v", err) + } +} + +func TestFailover_ContextCanceledAborts(t *testing.T) { + resetHealthForTest() + a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, context.Canceled + }) + b := newMockProvider(provider.Response{Text: "from-b"}) + fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...) + _, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if !errors.Is(err, context.Canceled) { + t.Errorf("expected context.Canceled to abort, got %v", err) + } + // b must not be tried. + b.mu.Lock() + n := len(b.Requests) + b.mu.Unlock() + if n != 0 { + t.Errorf("canceled should not fail over; b called %d times", n) + } + if IsBenched("p/a") { + t.Error("canceled should not bench") + } +} + +func TestFailover_AllBenchedBestEffort(t *testing.T) { + resetHealthForTest() + // Manually bench both, then ensure Complete still tries (best-effort) and succeeds. + BenchModel("p/a", time.Now().Add(time.Hour)) + BenchModel("p/b", time.Now().Add(time.Hour)) + a := newMockProvider(provider.Response{Text: "from-a"}) + b := newMockProvider(provider.Response{Text: "from-b"}) + fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...) + resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatalf("best-effort should still try benched models: %v", err) + } + if resp.Text != "from-a" { + t.Errorf("expected from-a, got %q", resp.Text) + } +} + +func TestFailover_Observer(t *testing.T) { + resetHealthForTest() + a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, &openai.Error{StatusCode: 401} + }) + b := newMockProvider(provider.Response{Text: "from-b"}) + + var mu sync.Mutex + var events []FailoverEvent + obs := func(ctx context.Context, ev FailoverEvent) { + mu.Lock() + events = append(events, ev) + mu.Unlock() + } + fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, + append(testFailoverOpts(), WithFailoverObserver(obs))...) + _, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatal(err) + } + mu.Lock() + defer mu.Unlock() + if len(events) == 0 { + t.Fatal("expected observer to be called") + } + ev := events[0] + if ev.Model != "p/a" || ev.Kind != ErrAuthDead || !ev.Benched { + t.Errorf("unexpected event: %+v", ev) + } + if ev.NextModel != "p/b" { + t.Errorf("expected NextModel p/b, got %q", ev.NextModel) + } + if len(ev.Request.Messages) == 0 { + t.Error("observer event should carry the full request") + } +} + +func TestFailover_ControlAPI(t *testing.T) { + resetHealthForTest() + if IsBenched("x/y") { + t.Error("should start unbenched") + } + until := time.Now().Add(time.Hour) + BenchModel("x/y", until) + if !IsBenched("x/y") { + t.Error("should be benched") + } + list := ListBenched() + if len(list) != 1 || list[0].Model != "x/y" || !list[0].Manual { + t.Errorf("unexpected list: %+v", list) + } + if !UnbenchModel("x/y") { + t.Error("UnbenchModel should report it was benched") + } + if IsBenched("x/y") { + t.Error("should be unbenched now") + } + if UnbenchModel("x/y") { + t.Error("UnbenchModel on non-benched should return false") + } +} + +func TestFailover_ExpiredBenchIsLive(t *testing.T) { + resetHealthForTest() + // Bench in the past -> should be considered live again. + BenchModel("p/a", time.Now().Add(-time.Hour)) + if IsBenched("p/a") { + t.Error("expired bench should not count as benched") + } +} + +// TestParseChain exercises ParseChain via a registry-backed seam is covered in +// parse_test.go; here we verify NewFailoverModel flattens nested failover models. +func TestNewFailoverModel_Flattens(t *testing.T) { + resetHealthForTest() + a := newMockProvider(provider.Response{Text: "a"}) + b := newMockProvider(provider.Response{Text: "b"}) + c := newMockProvider(provider.Response{Text: "c"}) + inner := NewFailoverModel([]*Model{modelFor(b, "p/b"), modelFor(c, "p/c")}, testFailoverOpts()...) + outer := NewFailoverModel([]*Model{modelFor(a, "p/a"), inner}, testFailoverOpts()...) + + fp, ok := outer.provider.(*failoverProvider) + if !ok { + t.Fatalf("expected *failoverProvider, got %T", outer.provider) + } + if len(fp.entries) != 3 { + t.Errorf("expected flattened 3 entries, got %d", len(fp.entries)) + } + keys := []string{fp.entries[0].specKey, fp.entries[1].specKey, fp.entries[2].specKey} + want := []string{"p/a", "p/b", "p/c"} + for i := range want { + if keys[i] != want[i] { + t.Errorf("entry %d specKey = %q, want %q", i, keys[i], want[i]) + } + } +} diff --git a/v2/parse.go b/v2/parse.go index 3b2b5d3..f5b0515 100644 --- a/v2/parse.go +++ b/v2/parse.go @@ -109,6 +109,25 @@ func Parse(spec string) (*Model, error) { // What: walks the resolution chain and returns the final *Model with reasoning applied. // Test: see parse_test.go for comprehensive table-driven tests. func (r *Registry) Parse(spec string) (*Model, error) { + // Comma-separated specs become an ordered failover chain. A single part + // (after trimming/dropping empties) falls through to normal single-model + // parsing, preserving exact existing behavior for comma-free specs. + if strings.Contains(spec, ",") { + var parts []string + for _, p := range strings.Split(spec, ",") { + if p = strings.TrimSpace(p); p != "" { + parts = append(parts, p) + } + } + if len(parts) == 0 { + return nil, fmt.Errorf("%w: empty failover spec %q", ErrUnknownProvider, spec) + } + if len(parts) > 1 { + return r.ParseChain(parts) + } + spec = parts[0] + } + m, level, err := r.parse(spec, 0) if err != nil { return nil, err @@ -134,6 +153,15 @@ func (r *Registry) parse(spec string, depth int) (*Model, ReasoningLevel, error) target, isAlias := r.aliases[base] r.mu.RUnlock() if isAlias { + // An alias may expand to a comma-separated failover chain; route those + // through the comma-aware public Parse so the chain is built correctly. + if strings.Contains(target, ",") { + m, err := r.Parse(target) + if err != nil { + return nil, "", err + } + return m, userLevel, nil + } m, aliasLevel, err := r.parse(target, depth+1) if err != nil { return nil, "", err @@ -155,6 +183,18 @@ func (r *Registry) parse(spec string, depth int) (*Model, ReasoningLevel, error) if resolved == "" { return nil, "", fmt.Errorf("resolver returned empty spec for %q", base) } + // A resolver may return a comma-separated failover chain. + if strings.Contains(resolved, ",") { + m, err := r.Parse(resolved) + if err != nil { + return nil, "", err + } + level := defaultLevel + if userLevel != "" { + level = userLevel + } + return m, level, nil + } m, resolvedLevel, err := r.parse(resolved, depth+1) if err != nil { return nil, "", err diff --git a/v2/parse_test.go b/v2/parse_test.go index 73c15b4..b03b07d 100644 --- a/v2/parse_test.go +++ b/v2/parse_test.go @@ -72,12 +72,12 @@ func TestSplitReasoning(t *testing.T) { {"gpt-4o:high", "gpt-4o", ReasoningHigh}, {"gpt-4o:low", "gpt-4o", ReasoningLow}, {"gpt-4o:medium", "gpt-4o", ReasoningMedium}, - {"qwen3:30b", "qwen3:30b", ""}, // Ollama tag, not a level + {"qwen3:30b", "qwen3:30b", ""}, // Ollama tag, not a level {"qwen3:30b:high", "qwen3:30b", ReasoningHigh}, // tag + reasoning - {"model:", "model:", ""}, // trailing colon, empty suffix - {"", "", ""}, // empty string - {"a:b:c:high", "a:b:c", ReasoningHigh}, // multiple colons - {"a:b:c:30b", "a:b:c:30b", ""}, // multiple colons, non-level + {"model:", "model:", ""}, // trailing colon, empty suffix + {"", "", ""}, // empty string + {"a:b:c:high", "a:b:c", ReasoningHigh}, // multiple colons + {"a:b:c:30b", "a:b:c:30b", ""}, // multiple colons, non-level } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { @@ -637,3 +637,81 @@ func TestNewRegistryIsolation(t *testing.T) { t.Error("alias registered in r1 should not appear in r2") } } + +// --------------------------------------------------------------------------- +// Comma-separated failover chains +// --------------------------------------------------------------------------- + +func TestParse_CommaProducesFailover(t *testing.T) { + resetHealthForTest() + r, alpha, beta := testRegistry(func(string) string { return "" }) + + m, err := r.Parse("alpha/model-a,beta/model-b") + if err != nil { + t.Fatalf("Parse failover spec: %v", err) + } + fp, ok := m.provider.(*failoverProvider) + if !ok { + t.Fatalf("expected *failoverProvider, got %T", m.provider) + } + if len(fp.entries) != 2 { + t.Fatalf("expected 2 entries, got %d", len(fp.entries)) + } + if fp.entries[0].specKey != "alpha/model-a" || fp.entries[1].specKey != "beta/model-b" { + t.Errorf("unexpected specKeys: %q, %q", fp.entries[0].specKey, fp.entries[1].specKey) + } + // Complete routes to the first provider and passes the bare model name. + _, err = m.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) + if err != nil { + t.Fatal(err) + } + if alpha.lastModel != "model-a" { + t.Errorf("alpha got model %q, want model-a", alpha.lastModel) + } + if beta.lastModel != "" { + t.Errorf("beta should not have been called, got model %q", beta.lastModel) + } +} + +func TestParse_NoCommaUnchanged(t *testing.T) { + resetHealthForTest() + r, _, _ := testRegistry(func(string) string { return "" }) + m, err := r.Parse("alpha/model-a") + if err != nil { + t.Fatal(err) + } + if _, ok := m.provider.(*failoverProvider); ok { + t.Error("single (comma-free) spec must NOT produce a failover provider") + } +} + +func TestParse_CommaSinglePartFallsThrough(t *testing.T) { + resetHealthForTest() + r, _, _ := testRegistry(func(string) string { return "" }) + // Trailing comma / whitespace collapses to a single real part. + m, err := r.Parse("alpha/model-a, ") + if err != nil { + t.Fatal(err) + } + if _, ok := m.provider.(*failoverProvider); ok { + t.Error("a single effective part must not produce a failover provider") + } +} + +func TestParse_CommaFlattensNested(t *testing.T) { + resetHealthForTest() + r, _, _ := testRegistry(func(string) string { return "" }) + // Register an alias that is itself a comma chain. + r.RegisterAlias("pair", "alpha/model-a,beta/model-b") + m, err := r.Parse("pair,beta/model-b2") + if err != nil { + t.Fatal(err) + } + fp, ok := m.provider.(*failoverProvider) + if !ok { + t.Fatalf("expected *failoverProvider, got %T", m.provider) + } + if len(fp.entries) != 3 { + t.Errorf("expected flattened 3 entries, got %d", len(fp.entries)) + } +}