feat(failover): model failover chains via comma-separated specs
Parse("a,b,c") now returns one composite *llm.Model that tries each model
in order, retrying transient failures, benching dead models, and failing
over to the next. Comma-free specs are completely unchanged.
- classify.go: Classify(err) ErrKind + IsTransient(err) error classifier
mapping anthropic (typed Is*Err helpers + RequestError status),
openai-go (*openai.Error status), openaicompat.FeatureUnsupportedError,
context errors, and ollama "HTTP <code>" strings to
transient/auth-dead/request-specific/unknown.
- failover.go: failoverProvider (satisfies provider.Provider) wrapped into a
*Model via NewClient. Process-wide mutex-guarded modelHealth bench
registry keyed by concrete spec, with cooldowns and a control API
(ListBenched/BenchModel/UnbenchModel/IsBenched). NewFailoverModel +
ParseChain constructors, FailoverOption config, FailoverObserver (carries
the full request), and configurable package-level defaults.
- parse.go: comma-aware Parse splits into a failover chain; alias/resolver
targets that expand to comma chains are routed through the comma-aware
path and flattened.
All access to global health is mutex-guarded; tests reset it via
resetHealthForTest and pass under go test -race.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+207
@@ -0,0 +1,207 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
anth "github.com/liushuangls/go-anthropic/v2"
|
||||
"github.com/openai/openai-go"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
||||
)
|
||||
|
||||
// ErrKind classifies a provider error for failover decision-making.
|
||||
//
|
||||
// Why: failover must decide, per error, whether to retry the same model
|
||||
// (transient), bench it as broken (auth/model dead), or fail over without
|
||||
// benching (this request's fault). Without a classifier every error looks
|
||||
// the same and we'd either thrash a dead model or bench a healthy one.
|
||||
// What: an enum of the four outcomes the failover algorithm distinguishes.
|
||||
// Test: see classify_test.go — every branch is table-tested with faked SDK errors.
|
||||
type ErrKind int
|
||||
|
||||
const (
|
||||
// ErrUnknown is an unrecognized error. Failover treats it as transient
|
||||
// (conservative — retry then fail over), EXCEPT context.Canceled which
|
||||
// the caller special-cases as an abort.
|
||||
ErrUnknown ErrKind = iota
|
||||
// ErrTransient is a temporary failure (429/5xx/timeout): retry, then
|
||||
// bench-and-fail-over if retries are exhausted.
|
||||
ErrTransient
|
||||
// ErrAuthDead is an auth failure or model-not-found (401/403/404): the
|
||||
// model is unusable; bench immediately and fail over.
|
||||
ErrAuthDead
|
||||
// ErrRequestSpecific is the caller's fault for THIS request (400/413/422,
|
||||
// unsupported feature): fail over to try a more capable model, but do NOT
|
||||
// bench — the model itself is healthy.
|
||||
ErrRequestSpecific
|
||||
)
|
||||
|
||||
// classifyStatus maps an HTTP status code to an ErrKind.
|
||||
//
|
||||
// Why: openai-go and anthropic RequestError both expose a numeric StatusCode;
|
||||
// centralizing the mapping keeps the per-SDK branches thin and consistent.
|
||||
// What: 408/409/429/5xx → transient, 401/403/404 → auth-dead, 400/413/422 →
|
||||
// request-specific, anything else → unknown.
|
||||
// Test: covered indirectly via Classify table tests for each SDK.
|
||||
func classifyStatus(code int) ErrKind {
|
||||
switch code {
|
||||
case 408, 409, 429, 500, 502, 503, 504:
|
||||
return ErrTransient
|
||||
case 401, 403, 404:
|
||||
return ErrAuthDead
|
||||
case 400, 413, 422:
|
||||
return ErrRequestSpecific
|
||||
default:
|
||||
return ErrUnknown
|
||||
}
|
||||
}
|
||||
|
||||
// Classify inspects a provider error and returns its ErrKind.
|
||||
//
|
||||
// Why: the failover composite needs typed, status-code-aware classification to
|
||||
// retry/bench/skip correctly across the anthropic, openai-compat, and ollama
|
||||
// providers, each of which surfaces errors differently.
|
||||
// What: prefers anthropic's typed Is*Err helpers, falls back to numeric status
|
||||
// codes (openai-go, anthropic RequestError), then the openaicompat
|
||||
// FeatureUnsupportedError, context errors, and finally an ollama HTTP-string
|
||||
// fallback; unrecognized errors are ErrUnknown.
|
||||
// Test: classify_test.go faked SDK errors exercise every branch.
|
||||
func Classify(err error) ErrKind {
|
||||
if err == nil {
|
||||
return ErrUnknown
|
||||
}
|
||||
|
||||
// context.Canceled is reported as ErrUnknown here; the failover algorithm
|
||||
// special-cases it as an abort before consulting the kind.
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return ErrUnknown
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return ErrTransient
|
||||
}
|
||||
|
||||
// FeatureUnsupportedError is a permanent, request-shaped failure.
|
||||
var featErr *openaicompat.FeatureUnsupportedError
|
||||
if errors.As(err, &featErr) {
|
||||
return ErrRequestSpecific
|
||||
}
|
||||
|
||||
// Anthropic APIError: prefer the typed helpers (no StatusCode available).
|
||||
var apiErr *anth.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
switch {
|
||||
case apiErr.IsRateLimitErr(), apiErr.IsOverloadedErr(), apiErr.IsApiErr():
|
||||
return ErrTransient
|
||||
case apiErr.IsAuthenticationErr(), apiErr.IsPermissionErr(), apiErr.IsNotFoundErr():
|
||||
return ErrAuthDead
|
||||
case apiErr.IsTooLargeErr(), apiErr.IsInvalidRequestErr():
|
||||
return ErrRequestSpecific
|
||||
default:
|
||||
return ErrUnknown
|
||||
}
|
||||
}
|
||||
|
||||
// Anthropic RequestError: status-code based.
|
||||
var anthReqErr *anth.RequestError
|
||||
if errors.As(err, &anthReqErr) {
|
||||
return classifyStatus(anthReqErr.StatusCode)
|
||||
}
|
||||
|
||||
// openai-go (openai/deepseek/moonshot/xai/groq): status-code based.
|
||||
var oaiErr *openai.Error
|
||||
if errors.As(err, &oaiErr) {
|
||||
return classifyStatus(oaiErr.StatusCode)
|
||||
}
|
||||
|
||||
// Ollama: no typed status — fall back to its "ollama: HTTP <code>:" string.
|
||||
if k := classifyOllamaString(err.Error()); k != ErrUnknown {
|
||||
return k
|
||||
}
|
||||
|
||||
return ErrUnknown
|
||||
}
|
||||
|
||||
// classifyOllamaString extracts an HTTP status from ollama's error string
|
||||
// format ("ollama: HTTP <code>: ...") and classifies it.
|
||||
//
|
||||
// Why: the ollama provider stringifies errors without a typed status code, so
|
||||
// failover can only classify by parsing the message.
|
||||
// What: looks for "HTTP <code>" in the message and maps the code; returns
|
||||
// ErrUnknown when no recognizable status is present.
|
||||
// Test: classify_test.go ollama cases cover 5xx/429/401/404/400.
|
||||
func classifyOllamaString(msg string) ErrKind {
|
||||
const marker = "HTTP "
|
||||
idx := strings.Index(msg, marker)
|
||||
if idx < 0 {
|
||||
return ErrUnknown
|
||||
}
|
||||
rest := msg[idx+len(marker):]
|
||||
// Read up to 3 leading digits.
|
||||
end := 0
|
||||
for end < len(rest) && end < 3 && rest[end] >= '0' && rest[end] <= '9' {
|
||||
end++
|
||||
}
|
||||
if end == 0 {
|
||||
return ErrUnknown
|
||||
}
|
||||
code := 0
|
||||
for i := 0; i < end; i++ {
|
||||
code = code*10 + int(rest[i]-'0')
|
||||
}
|
||||
return classifyStatus(code)
|
||||
}
|
||||
|
||||
// extractStatus best-effort pulls an HTTP status code out of a provider error
|
||||
// for structured logging. Returns 0 when none is available.
|
||||
//
|
||||
// Why: log lines benefit from the numeric status even though classification
|
||||
// may use typed helpers; this keeps that detail out of the hot path.
|
||||
// What: checks anthropic RequestError and openai-go Error StatusCode fields,
|
||||
// then parses ollama's "HTTP <code>" string; returns 0 otherwise.
|
||||
// Test: covered indirectly via failover log assertions / manual inspection.
|
||||
func extractStatus(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
var anthReqErr *anth.RequestError
|
||||
if errors.As(err, &anthReqErr) {
|
||||
return anthReqErr.StatusCode
|
||||
}
|
||||
var oaiErr *openai.Error
|
||||
if errors.As(err, &oaiErr) {
|
||||
return oaiErr.StatusCode
|
||||
}
|
||||
const marker = "HTTP "
|
||||
msg := err.Error()
|
||||
if idx := strings.Index(msg, marker); idx >= 0 {
|
||||
rest := msg[idx+len(marker):]
|
||||
code, n := 0, 0
|
||||
for n < len(rest) && n < 3 && rest[n] >= '0' && rest[n] <= '9' {
|
||||
code = code*10 + int(rest[n]-'0')
|
||||
n++
|
||||
}
|
||||
if n > 0 {
|
||||
return code
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsTransient reports whether an error should be retried/failed-over rather
|
||||
// than treated as a hard, model-specific failure.
|
||||
//
|
||||
// Why: callers (and failover) want a one-call "is this worth retrying?" check
|
||||
// that is conservative about unknown errors.
|
||||
// What: returns true for ErrTransient and ErrUnknown (conservative), false for
|
||||
// ErrAuthDead and ErrRequestSpecific.
|
||||
// Test: TestIsTransient asserts 503→true, unknown→true, 401→false, 400→false.
|
||||
func IsTransient(err error) bool {
|
||||
switch Classify(err) {
|
||||
case ErrTransient, ErrUnknown:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
anth "github.com/liushuangls/go-anthropic/v2"
|
||||
"github.com/openai/openai-go"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
||||
)
|
||||
|
||||
func TestClassify(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want ErrKind
|
||||
}{
|
||||
// nil
|
||||
{"nil", nil, ErrUnknown},
|
||||
|
||||
// openai-go status codes (transient)
|
||||
{"openai 408", &openai.Error{StatusCode: 408}, ErrTransient},
|
||||
{"openai 409", &openai.Error{StatusCode: 409}, ErrTransient},
|
||||
{"openai 429", &openai.Error{StatusCode: 429}, ErrTransient},
|
||||
{"openai 500", &openai.Error{StatusCode: 500}, ErrTransient},
|
||||
{"openai 502", &openai.Error{StatusCode: 502}, ErrTransient},
|
||||
{"openai 503", &openai.Error{StatusCode: 503}, ErrTransient},
|
||||
{"openai 504", &openai.Error{StatusCode: 504}, ErrTransient},
|
||||
|
||||
// openai-go status codes (auth dead)
|
||||
{"openai 401", &openai.Error{StatusCode: 401}, ErrAuthDead},
|
||||
{"openai 403", &openai.Error{StatusCode: 403}, ErrAuthDead},
|
||||
{"openai 404", &openai.Error{StatusCode: 404}, ErrAuthDead},
|
||||
|
||||
// openai-go status codes (request specific)
|
||||
{"openai 400", &openai.Error{StatusCode: 400}, ErrRequestSpecific},
|
||||
{"openai 413", &openai.Error{StatusCode: 413}, ErrRequestSpecific},
|
||||
{"openai 422", &openai.Error{StatusCode: 422}, ErrRequestSpecific},
|
||||
|
||||
// openai unrecognized status -> unknown
|
||||
{"openai 418", &openai.Error{StatusCode: 418}, ErrUnknown},
|
||||
|
||||
// wrapped openai error (providers wrap with %w)
|
||||
{"wrapped openai 503", fmt.Errorf("openai completion error: %w", &openai.Error{StatusCode: 503}), ErrTransient},
|
||||
|
||||
// FeatureUnsupportedError -> request specific
|
||||
{"feature unsupported", &openaicompat.FeatureUnsupportedError{Feature: "tools", Model: "m"}, ErrRequestSpecific},
|
||||
{"wrapped feature unsupported", fmt.Errorf("x: %w", &openaicompat.FeatureUnsupportedError{Feature: "vision", Model: "m"}), ErrRequestSpecific},
|
||||
|
||||
// anthropic RequestError (status-code based)
|
||||
{"anth req 503", &anth.RequestError{StatusCode: 503}, ErrTransient},
|
||||
{"anth req 429", &anth.RequestError{StatusCode: 429}, ErrTransient},
|
||||
{"anth req 401", &anth.RequestError{StatusCode: 401}, ErrAuthDead},
|
||||
{"anth req 400", &anth.RequestError{StatusCode: 400}, ErrRequestSpecific},
|
||||
{"wrapped anth req 502", fmt.Errorf("anthropic completion error: %w", &anth.RequestError{StatusCode: 502}), ErrTransient},
|
||||
|
||||
// anthropic APIError (helper based)
|
||||
{"anth rate limit", &anth.APIError{Type: anth.ErrTypeRateLimit}, ErrTransient},
|
||||
{"anth overloaded", &anth.APIError{Type: anth.ErrTypeOverloaded}, ErrTransient},
|
||||
{"anth api", &anth.APIError{Type: anth.ErrTypeApi}, ErrTransient},
|
||||
{"anth auth", &anth.APIError{Type: anth.ErrTypeAuthentication}, ErrAuthDead},
|
||||
{"anth permission", &anth.APIError{Type: anth.ErrTypePermission}, ErrAuthDead},
|
||||
{"anth not found", &anth.APIError{Type: anth.ErrTypeNotFound}, ErrAuthDead},
|
||||
{"anth too large", &anth.APIError{Type: anth.ErrTypeTooLarge}, ErrRequestSpecific},
|
||||
{"anth invalid request", &anth.APIError{Type: anth.ErrTypeInvalidRequest}, ErrRequestSpecific},
|
||||
{"wrapped anth api error", fmt.Errorf("error, status code: 529, message: %w", &anth.APIError{Type: anth.ErrTypeOverloaded}), ErrTransient},
|
||||
|
||||
// context errors
|
||||
{"context canceled", context.Canceled, ErrUnknown},
|
||||
{"context deadline", context.DeadlineExceeded, ErrTransient},
|
||||
{"wrapped deadline", fmt.Errorf("call failed: %w", context.DeadlineExceeded), ErrTransient},
|
||||
|
||||
// ollama string-based
|
||||
{"ollama HTTP 503", errors.New("ollama: HTTP 503: service unavailable"), ErrTransient},
|
||||
{"ollama HTTP 500", errors.New("ollama: HTTP 500: internal"), ErrTransient},
|
||||
{"ollama HTTP 502", errors.New("ollama: HTTP 502: bad gateway"), ErrTransient},
|
||||
{"ollama HTTP 504", errors.New("ollama: HTTP 504: timeout"), ErrTransient},
|
||||
{"ollama HTTP 429", errors.New("ollama: HTTP 429: too many requests"), ErrTransient},
|
||||
{"ollama HTTP 401", errors.New("ollama: HTTP 401: unauthorized"), ErrAuthDead},
|
||||
{"ollama HTTP 404", errors.New("ollama: HTTP 404: not found"), ErrAuthDead},
|
||||
{"ollama HTTP 400", errors.New("ollama: HTTP 400: bad request"), ErrRequestSpecific},
|
||||
|
||||
// unknown
|
||||
{"random error", errors.New("something weird"), ErrUnknown},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := Classify(tt.err)
|
||||
if got != tt.want {
|
||||
t.Errorf("Classify(%v) = %v, want %v", tt.err, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTransient(t *testing.T) {
|
||||
// IsTransient treats both ErrTransient and ErrUnknown as "should retry"
|
||||
// (conservative). Auth/request-specific are not transient.
|
||||
if !IsTransient(&openai.Error{StatusCode: 503}) {
|
||||
t.Error("503 should be transient")
|
||||
}
|
||||
if !IsTransient(errors.New("mystery")) {
|
||||
t.Error("unknown should be treated as transient (conservative)")
|
||||
}
|
||||
if IsTransient(&openai.Error{StatusCode: 401}) {
|
||||
t.Error("401 (auth dead) should NOT be transient")
|
||||
}
|
||||
if IsTransient(&openai.Error{StatusCode: 400}) {
|
||||
t.Error("400 (request specific) should NOT be transient")
|
||||
}
|
||||
}
|
||||
+588
@@ -0,0 +1,588 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Package-level defaults (mort configures these at boot via SetFailoverDefaults)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
// DefaultFailoverMaxRetries is the number of attempts per chain entry on
|
||||
// transient errors before benching and moving to the next entry.
|
||||
DefaultFailoverMaxRetries = 3
|
||||
// DefaultFailoverCooldown is how long a model stays benched after a
|
||||
// qualifying failure.
|
||||
DefaultFailoverCooldown = 5 * time.Minute
|
||||
// DefaultFailoverBackoff is the default exponential-with-jitter backoff.
|
||||
DefaultFailoverBackoff = defaultBackoff
|
||||
|
||||
defaultsMu sync.Mutex
|
||||
)
|
||||
|
||||
// defaultBackoff returns an exponential backoff with full jitter.
|
||||
//
|
||||
// Why: spreads retries to avoid thundering-herd against a recovering provider.
|
||||
// What: base 200ms doubling per attempt, capped at 10s, with uniform jitter.
|
||||
// Test: failover retry tests inject a fast backoff; this is the production default.
|
||||
func defaultBackoff(attempt int) time.Duration {
|
||||
if attempt < 1 {
|
||||
attempt = 1
|
||||
}
|
||||
base := 200 * time.Millisecond
|
||||
d := base << (attempt - 1)
|
||||
if d > 10*time.Second {
|
||||
d = 10 * time.Second
|
||||
}
|
||||
// Full jitter in [0, d].
|
||||
return time.Duration(rand.Int63n(int64(d) + 1))
|
||||
}
|
||||
|
||||
// SetFailoverDefaults overrides the package-level failover defaults used when
|
||||
// no per-model options are supplied (e.g. comma-spec Parse).
|
||||
//
|
||||
// Why: mort wants to tune retries/cooldown once at boot without threading
|
||||
// options through every Parse call.
|
||||
// What: sets DefaultFailoverMaxRetries and DefaultFailoverCooldown under a lock.
|
||||
// Test: set defaults, build a comma model, assert its cfg reflects them.
|
||||
func SetFailoverDefaults(maxRetries int, cooldown time.Duration) {
|
||||
defaultsMu.Lock()
|
||||
defer defaultsMu.Unlock()
|
||||
DefaultFailoverMaxRetries = maxRetries
|
||||
DefaultFailoverCooldown = cooldown
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Global model health (process-wide bench registry)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// modelHealth tracks which concrete models are temporarily disabled (benched).
|
||||
//
|
||||
// Why: bench decisions must persist across requests and across all failover
|
||||
// chains in the process, so a model that's down isn't retried by every chain.
|
||||
// What: a mutex-guarded map keyed by specKey to its disabled state.
|
||||
// Test: failover tests reset it via resetHealthForTest and assert via IsBenched.
|
||||
type modelHealth struct {
|
||||
mu sync.Mutex
|
||||
disabled map[string]disabledState
|
||||
}
|
||||
|
||||
type disabledState struct {
|
||||
until time.Time
|
||||
consecutiveFails int
|
||||
manual bool
|
||||
}
|
||||
|
||||
// globalHealth is the process-wide singleton shared by every failover chain.
|
||||
var globalHealth = &modelHealth{disabled: map[string]disabledState{}}
|
||||
|
||||
// benchThreshold is the number of consecutive transient failures (each after
|
||||
// exhausting retries) required before a model is benched. Auth-dead benches
|
||||
// immediately regardless.
|
||||
const benchThreshold = 1
|
||||
|
||||
// resetHealthForTest clears all bench state. Test-only.
|
||||
func resetHealthForTest() {
|
||||
globalHealth.mu.Lock()
|
||||
defer globalHealth.mu.Unlock()
|
||||
globalHealth.disabled = map[string]disabledState{}
|
||||
}
|
||||
|
||||
// isBenched reports whether key is currently benched (and not expired).
|
||||
func (h *modelHealth) isBenched(key string, now time.Time) bool {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
st, ok := h.disabled[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if now.After(st.until) {
|
||||
delete(h.disabled, key)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// recordSuccess clears any failure state for key.
|
||||
func (h *modelHealth) recordSuccess(key string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
delete(h.disabled, key)
|
||||
}
|
||||
|
||||
// recordTransientFailure increments the consecutive failure count and benches
|
||||
// the model once the threshold is reached. Returns whether it is now benched
|
||||
// and for how long.
|
||||
func (h *modelHealth) recordTransientFailure(key string, cooldown time.Duration, now time.Time) (benched bool, until time.Time) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
st := h.disabled[key]
|
||||
st.consecutiveFails++
|
||||
if st.consecutiveFails >= benchThreshold {
|
||||
st.until = now.Add(cooldown)
|
||||
st.manual = false
|
||||
h.disabled[key] = st
|
||||
return true, st.until
|
||||
}
|
||||
h.disabled[key] = st
|
||||
return false, time.Time{}
|
||||
}
|
||||
|
||||
// benchNow benches a model immediately (used for auth-dead errors).
|
||||
func (h *modelHealth) benchNow(key string, cooldown time.Duration, now time.Time) time.Time {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
st := h.disabled[key]
|
||||
st.consecutiveFails++
|
||||
st.until = now.Add(cooldown)
|
||||
st.manual = false
|
||||
h.disabled[key] = st
|
||||
return st.until
|
||||
}
|
||||
|
||||
// benchManual benches a model until the given time, marking it manual.
|
||||
func (h *modelHealth) benchManual(key string, until time.Time) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
st := h.disabled[key]
|
||||
st.until = until
|
||||
st.manual = true
|
||||
h.disabled[key] = st
|
||||
}
|
||||
|
||||
// unbench removes a model's bench state, reporting whether it was benched.
|
||||
func (h *modelHealth) unbench(key string, now time.Time) bool {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
st, ok := h.disabled[key]
|
||||
if !ok || now.After(st.until) {
|
||||
delete(h.disabled, key)
|
||||
return false
|
||||
}
|
||||
delete(h.disabled, key)
|
||||
return true
|
||||
}
|
||||
|
||||
// list returns a snapshot of all currently-benched (non-expired) models.
|
||||
func (h *modelHealth) list(now time.Time) []BenchedModel {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
var out []BenchedModel
|
||||
for k, st := range h.disabled {
|
||||
if now.After(st.until) {
|
||||
delete(h.disabled, k)
|
||||
continue
|
||||
}
|
||||
out = append(out, BenchedModel{
|
||||
Model: k,
|
||||
Until: st.until,
|
||||
ConsecutiveFails: st.consecutiveFails,
|
||||
Manual: st.manual,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Control API (admin commands / UI drive these)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// BenchedModel is a snapshot of a benched model's state.
|
||||
type BenchedModel struct {
|
||||
Model string
|
||||
Until time.Time
|
||||
ConsecutiveFails int
|
||||
Manual bool
|
||||
}
|
||||
|
||||
// ListBenched returns all currently-benched models across the process.
|
||||
//
|
||||
// Why: admin tooling needs to display which models are sidelined and why.
|
||||
// What: snapshots the global health map, pruning expired entries.
|
||||
// Test: BenchModel then ListBenched returns it with Manual=true.
|
||||
func ListBenched() []BenchedModel { return globalHealth.list(time.Now()) }
|
||||
|
||||
// BenchModel manually benches a model until the given time.
|
||||
//
|
||||
// Why: operators sometimes need to force a model offline (incident, cost).
|
||||
// What: records a manual bench in the global health registry.
|
||||
// Test: BenchModel then IsBenched returns true and ListBenched shows Manual.
|
||||
func BenchModel(spec string, until time.Time) { globalHealth.benchManual(spec, until) }
|
||||
|
||||
// UnbenchModel clears a model's bench state, returning whether it was benched.
|
||||
//
|
||||
// Why: operators need to bring a model back early after manual or auto bench.
|
||||
// What: deletes the global health entry, reporting prior benched state.
|
||||
// Test: bench then UnbenchModel returns true; a second call returns false.
|
||||
func UnbenchModel(spec string) bool { return globalHealth.unbench(spec, time.Now()) }
|
||||
|
||||
// IsBenched reports whether a model is currently benched.
|
||||
//
|
||||
// Why: callers/tests want a quick health check for a concrete model.
|
||||
// What: consults the global health registry (expired benches read as false).
|
||||
// Test: BenchModel makes it true; an expired bench reads false.
|
||||
func IsBenched(spec string) bool { return globalHealth.isBenched(spec, time.Now()) }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Observer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// FailoverEvent describes a single failover decision for an observer.
|
||||
type FailoverEvent struct {
|
||||
Model string
|
||||
Err error
|
||||
Kind ErrKind
|
||||
Attempt int
|
||||
Benched bool
|
||||
BenchedFor time.Duration
|
||||
NextModel string
|
||||
Request provider.Request
|
||||
}
|
||||
|
||||
// FailoverObserver receives a FailoverEvent for each failover decision. mort
|
||||
// uses this to persist the full prompt chain on failover.
|
||||
type FailoverObserver func(ctx context.Context, ev FailoverEvent)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Config + options
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type failoverConfig struct {
|
||||
maxRetries int
|
||||
cooldown time.Duration
|
||||
backoff func(attempt int) time.Duration
|
||||
observer FailoverObserver
|
||||
}
|
||||
|
||||
func defaultFailoverConfig() failoverConfig {
|
||||
defaultsMu.Lock()
|
||||
defer defaultsMu.Unlock()
|
||||
return failoverConfig{
|
||||
maxRetries: DefaultFailoverMaxRetries,
|
||||
cooldown: DefaultFailoverCooldown,
|
||||
backoff: DefaultFailoverBackoff,
|
||||
}
|
||||
}
|
||||
|
||||
// FailoverOption configures a failover model.
|
||||
type FailoverOption func(*failoverConfig)
|
||||
|
||||
// WithFailoverMaxRetries sets attempts per entry on transient errors.
|
||||
func WithFailoverMaxRetries(n int) FailoverOption {
|
||||
return func(c *failoverConfig) {
|
||||
if n < 1 {
|
||||
n = 1
|
||||
}
|
||||
c.maxRetries = n
|
||||
}
|
||||
}
|
||||
|
||||
// WithFailoverCooldown sets how long a model stays benched after failure.
|
||||
func WithFailoverCooldown(d time.Duration) FailoverOption {
|
||||
return func(c *failoverConfig) { c.cooldown = d }
|
||||
}
|
||||
|
||||
// WithFailoverBackoff sets the retry backoff function.
|
||||
func WithFailoverBackoff(fn func(attempt int) time.Duration) FailoverOption {
|
||||
return func(c *failoverConfig) {
|
||||
if fn != nil {
|
||||
c.backoff = fn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithFailoverObserver sets an observer notified on every failover decision.
|
||||
func WithFailoverObserver(obs FailoverObserver) FailoverOption {
|
||||
return func(c *failoverConfig) { c.observer = obs }
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Composite provider
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type failoverEntry struct {
|
||||
provider provider.Provider
|
||||
model string // bare model name sent to the provider
|
||||
specKey string // global health key (full concrete spec)
|
||||
}
|
||||
|
||||
type failoverProvider struct {
|
||||
entries []failoverEntry
|
||||
cfg failoverConfig
|
||||
health *modelHealth
|
||||
}
|
||||
|
||||
// bareModel strips a leading "provider/" prefix, returning the model name the
|
||||
// underlying provider expects. Specs without a slash are returned unchanged.
|
||||
func bareModel(spec string) string {
|
||||
if i := strings.Index(spec, "/"); i >= 0 {
|
||||
return spec[i+1:]
|
||||
}
|
||||
return spec
|
||||
}
|
||||
|
||||
// NewFailoverModel builds a composite *Model that tries each sub-model in order,
|
||||
// retrying/benching per the configured policy and failing over on error.
|
||||
//
|
||||
// Why: callers hold *Model and the base Complete handler is hardwired to one
|
||||
// provider, so failover (which must switch providers) is implemented as a
|
||||
// composite provider wrapped back into a *Model.
|
||||
// What: flattens any nested failover sub-models, derives a specKey per entry
|
||||
// from its model string, and returns NewClient(fp).Model("failover").
|
||||
// Test: failover_test.go exercises success, failover, bench, abort, and flatten.
|
||||
func NewFailoverModel(models []*Model, opts ...FailoverOption) *Model {
|
||||
cfg := defaultFailoverConfig()
|
||||
for _, opt := range opts {
|
||||
opt(&cfg)
|
||||
}
|
||||
|
||||
var entries []failoverEntry
|
||||
for _, m := range models {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
// Flatten nested failover models so cooldowns/keys stay flat.
|
||||
if fp, ok := m.provider.(*failoverProvider); ok {
|
||||
entries = append(entries, fp.entries...)
|
||||
continue
|
||||
}
|
||||
entries = append(entries, failoverEntry{
|
||||
provider: m.provider,
|
||||
model: bareModel(m.model),
|
||||
specKey: m.model,
|
||||
})
|
||||
}
|
||||
|
||||
fp := &failoverProvider{
|
||||
entries: entries,
|
||||
cfg: cfg,
|
||||
health: globalHealth,
|
||||
}
|
||||
return NewClient(fp).Model("failover")
|
||||
}
|
||||
|
||||
// ParseChain parses each spec and combines them into one failover model.
|
||||
//
|
||||
// Why: lets callers build a failover chain from a slice of specs (full
|
||||
// resolution per entry) without manually wiring providers.
|
||||
// What: Parse each spec, preserve the original spec string as the bench key,
|
||||
// flatten nested failover models, and return a composite *Model.
|
||||
// Test: parse_test.go covers comma-spec parsing through the registry.
|
||||
func ParseChain(specs []string, opts ...FailoverOption) (*Model, error) {
|
||||
return DefaultRegistry.ParseChain(specs, opts...)
|
||||
}
|
||||
|
||||
// ParseChain is the registry-scoped form of ParseChain.
|
||||
func (r *Registry) ParseChain(specs []string, opts ...FailoverOption) (*Model, error) {
|
||||
cfg := defaultFailoverConfig()
|
||||
for _, opt := range opts {
|
||||
opt(&cfg)
|
||||
}
|
||||
|
||||
var entries []failoverEntry
|
||||
for _, spec := range specs {
|
||||
spec = strings.TrimSpace(spec)
|
||||
if spec == "" {
|
||||
continue
|
||||
}
|
||||
m, err := r.Parse(spec)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failover chain: parse %q: %w", spec, err)
|
||||
}
|
||||
if fp, ok := m.provider.(*failoverProvider); ok {
|
||||
// A sub-spec was itself a comma/failover spec — splice its entries.
|
||||
entries = append(entries, fp.entries...)
|
||||
continue
|
||||
}
|
||||
entries = append(entries, failoverEntry{
|
||||
provider: m.provider,
|
||||
model: m.model,
|
||||
specKey: spec,
|
||||
})
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return nil, fmt.Errorf("failover chain: no valid specs")
|
||||
}
|
||||
|
||||
fp := &failoverProvider{entries: entries, cfg: cfg, health: globalHealth}
|
||||
return NewClient(fp).Model("failover"), nil
|
||||
}
|
||||
|
||||
// reqWithModel returns a shallow copy of req with its Model set to model.
|
||||
//
|
||||
// Why: each sub-provider must receive its own bare model name; the incoming
|
||||
// req carries the placeholder "failover" model from the composite *Model.
|
||||
// What: copies the struct (slices/pointers are shared, which is safe here since
|
||||
// providers treat the request as read-only) and overrides Model.
|
||||
// Test: TestFailover_PassesModelNameToProvider asserts the provider sees the bare name.
|
||||
func reqWithModel(req provider.Request, model string) provider.Request {
|
||||
req.Model = model
|
||||
return req
|
||||
}
|
||||
|
||||
// Complete implements provider.Provider with ordered failover.
|
||||
func (f *failoverProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
now := time.Now()
|
||||
|
||||
// 1. Build the live set (not currently benched). Best-effort: if all are
|
||||
// benched, ignore cooldowns rather than hard-fail.
|
||||
var live []failoverEntry
|
||||
for _, e := range f.entries {
|
||||
if !f.health.isBenched(e.specKey, now) {
|
||||
live = append(live, e)
|
||||
}
|
||||
}
|
||||
if len(live) == 0 {
|
||||
live = f.entries
|
||||
}
|
||||
|
||||
var causes []error
|
||||
|
||||
for i, entry := range live {
|
||||
nextModel := ""
|
||||
if i+1 < len(live) {
|
||||
nextModel = live[i+1].specKey
|
||||
}
|
||||
|
||||
for attempt := 1; attempt <= f.cfg.maxRetries; attempt++ {
|
||||
resp, err := entry.provider.Complete(ctx, reqWithModel(req, entry.model))
|
||||
if err == nil {
|
||||
f.health.recordSuccess(entry.specKey)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Caller aborted: stop everything, no failover, no bench.
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return provider.Response{}, err
|
||||
}
|
||||
|
||||
kind := Classify(err)
|
||||
|
||||
switch kind {
|
||||
case ErrRequestSpecific:
|
||||
f.emit(ctx, FailoverEvent{
|
||||
Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt,
|
||||
NextModel: nextModel, Request: req,
|
||||
})
|
||||
slog.Warn("failover: request-specific error, trying next model",
|
||||
"model", entry.specKey, "kind", "request_specific",
|
||||
"status", statusOf(err), "attempt", attempt, "next", nextModel)
|
||||
causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
|
||||
goto nextEntry
|
||||
|
||||
case ErrAuthDead:
|
||||
until := f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now())
|
||||
f.emit(ctx, FailoverEvent{
|
||||
Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt,
|
||||
Benched: true, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req,
|
||||
})
|
||||
slog.Warn("failover: auth/model-dead error, benching model",
|
||||
"model", entry.specKey, "kind", "auth_dead", "status", statusOf(err),
|
||||
"attempt", attempt, "benched", true, "cooldown", f.cfg.cooldown,
|
||||
"until", until, "next", nextModel)
|
||||
causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
|
||||
goto nextEntry
|
||||
|
||||
default: // ErrTransient or ErrUnknown -> retry, then bench.
|
||||
if attempt >= f.cfg.maxRetries {
|
||||
benched, until := f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now())
|
||||
f.emit(ctx, FailoverEvent{
|
||||
Model: entry.specKey, Err: err, Kind: kind, Attempt: attempt,
|
||||
Benched: benched, BenchedFor: f.cfg.cooldown, NextModel: nextModel, Request: req,
|
||||
})
|
||||
slog.Warn("failover: transient error, retries exhausted",
|
||||
"model", entry.specKey, "kind", kindString(kind), "status", statusOf(err),
|
||||
"attempt", attempt, "benched", benched, "cooldown", f.cfg.cooldown,
|
||||
"until", until, "next", nextModel)
|
||||
causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
|
||||
goto nextEntry
|
||||
}
|
||||
// Sleep before retrying (respect ctx).
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return provider.Response{}, ctx.Err()
|
||||
case <-time.After(f.cfg.backoff(attempt)):
|
||||
}
|
||||
}
|
||||
}
|
||||
nextEntry:
|
||||
}
|
||||
|
||||
return provider.Response{}, fmt.Errorf("failover: all %d models in chain failed: %w",
|
||||
len(live), errors.Join(causes...))
|
||||
}
|
||||
|
||||
// Stream implements provider.Provider. It fails over only on the INITIAL Stream
|
||||
// call error (before any event). Once a stream begins, mid-stream failures are
|
||||
// surfaced as-is — failover does not replay a partially-consumed stream.
|
||||
func (f *failoverProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
||||
now := time.Now()
|
||||
var live []failoverEntry
|
||||
for _, e := range f.entries {
|
||||
if !f.health.isBenched(e.specKey, now) {
|
||||
live = append(live, e)
|
||||
}
|
||||
}
|
||||
if len(live) == 0 {
|
||||
live = f.entries
|
||||
}
|
||||
|
||||
var causes []error
|
||||
for i, entry := range live {
|
||||
nextModel := ""
|
||||
if i+1 < len(live) {
|
||||
nextModel = live[i+1].specKey
|
||||
}
|
||||
err := entry.provider.Stream(ctx, reqWithModel(req, entry.model), events)
|
||||
if err == nil {
|
||||
f.health.recordSuccess(entry.specKey)
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return err
|
||||
}
|
||||
kind := Classify(err)
|
||||
switch kind {
|
||||
case ErrAuthDead:
|
||||
f.health.benchNow(entry.specKey, f.cfg.cooldown, time.Now())
|
||||
case ErrTransient, ErrUnknown:
|
||||
f.health.recordTransientFailure(entry.specKey, f.cfg.cooldown, time.Now())
|
||||
}
|
||||
slog.Warn("failover(stream): error, trying next model",
|
||||
"model", entry.specKey, "kind", kindString(kind), "status", statusOf(err), "next", nextModel)
|
||||
causes = append(causes, fmt.Errorf("%s: %w", entry.specKey, err))
|
||||
}
|
||||
return fmt.Errorf("failover(stream): all %d models in chain failed: %w", len(live), errors.Join(causes...))
|
||||
}
|
||||
|
||||
func (f *failoverProvider) emit(ctx context.Context, ev FailoverEvent) {
|
||||
if f.cfg.observer != nil {
|
||||
f.cfg.observer(ctx, ev)
|
||||
}
|
||||
}
|
||||
|
||||
func kindString(k ErrKind) string {
|
||||
switch k {
|
||||
case ErrTransient:
|
||||
return "transient"
|
||||
case ErrAuthDead:
|
||||
return "auth_dead"
|
||||
case ErrRequestSpecific:
|
||||
return "request_specific"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// statusOf best-effort extracts an HTTP status code for logging, or 0.
|
||||
func statusOf(err error) int { return extractStatus(err) }
|
||||
@@ -0,0 +1,302 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openai/openai-go"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
// fastBackoff is a near-zero backoff so retry tests don't sleep.
|
||||
func fastBackoff(int) time.Duration { return time.Microsecond }
|
||||
|
||||
func testFailoverOpts(extra ...FailoverOption) []FailoverOption {
|
||||
base := []FailoverOption{
|
||||
WithFailoverMaxRetries(2),
|
||||
WithFailoverBackoff(fastBackoff),
|
||||
WithFailoverCooldown(time.Minute),
|
||||
}
|
||||
return append(base, extra...)
|
||||
}
|
||||
|
||||
// modelFor builds a *Model around a mock provider with a concrete model name,
|
||||
// mimicking what Parse produces (so specKey resolution works).
|
||||
func modelFor(p provider.Provider, name string) *Model {
|
||||
return &Model{provider: p, model: name}
|
||||
}
|
||||
|
||||
func TestFailover_FirstSucceeds(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
a := newMockProvider(provider.Response{Text: "from-a"})
|
||||
b := newMockProvider(provider.Response{Text: "from-b"})
|
||||
|
||||
fo := NewFailoverModel([]*Model{modelFor(a, "anthropic/a"), modelFor(b, "openai/b")}, testFailoverOpts()...)
|
||||
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "from-a" {
|
||||
t.Errorf("expected from-a, got %q", resp.Text)
|
||||
}
|
||||
// b must not have been called.
|
||||
b.mu.Lock()
|
||||
n := len(b.Requests)
|
||||
b.mu.Unlock()
|
||||
if n != 0 {
|
||||
t.Errorf("expected b untouched, got %d calls", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_FailsOverToSecond(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
// a always returns a request-specific error (400) -> fail over, no retry-bench loop noise.
|
||||
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, &openai.Error{StatusCode: 400}
|
||||
})
|
||||
b := newMockProvider(provider.Response{Text: "from-b"})
|
||||
|
||||
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
|
||||
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "from-b" {
|
||||
t.Errorf("expected from-b, got %q", resp.Text)
|
||||
}
|
||||
// 400 is request-specific: a must NOT be benched.
|
||||
if IsBenched("p/a") {
|
||||
t.Error("p/a should not be benched on a 400")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_PassesModelNameToProvider(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
a := newMockProvider(provider.Response{Text: "ok"})
|
||||
fo := NewFailoverModel([]*Model{modelFor(a, "anthropic/claude-x")}, testFailoverOpts()...)
|
||||
_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := a.lastRequest().Model; got != "claude-x" {
|
||||
t.Errorf("provider received model %q, want bare model name claude-x", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_AuthDeadBenchesImmediately(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, &openai.Error{StatusCode: 401}
|
||||
})
|
||||
b := newMockProvider(provider.Response{Text: "from-b"})
|
||||
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
|
||||
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Text != "from-b" {
|
||||
t.Errorf("expected from-b, got %q", resp.Text)
|
||||
}
|
||||
if !IsBenched("p/a") {
|
||||
t.Error("p/a should be benched after auth-dead error")
|
||||
}
|
||||
// a should have been called exactly once (no retries on auth-dead).
|
||||
a.mu.Lock()
|
||||
n := len(a.Requests)
|
||||
a.mu.Unlock()
|
||||
if n != 1 {
|
||||
t.Errorf("auth-dead should not retry; a called %d times", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_TransientRetriesThenBenches(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
var calls int
|
||||
var mu sync.Mutex
|
||||
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
mu.Lock()
|
||||
calls++
|
||||
mu.Unlock()
|
||||
return provider.Response{}, &openai.Error{StatusCode: 503}
|
||||
})
|
||||
b := newMockProvider(provider.Response{Text: "from-b"})
|
||||
// maxRetries=2 means 2 attempts total per entry.
|
||||
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")},
|
||||
WithFailoverMaxRetries(2), WithFailoverBackoff(fastBackoff), WithFailoverCooldown(time.Minute))
|
||||
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Text != "from-b" {
|
||||
t.Errorf("expected from-b, got %q", resp.Text)
|
||||
}
|
||||
mu.Lock()
|
||||
n := calls
|
||||
mu.Unlock()
|
||||
if n != 2 {
|
||||
t.Errorf("expected 2 attempts on transient model, got %d", n)
|
||||
}
|
||||
if !IsBenched("p/a") {
|
||||
t.Error("p/a should be benched after exhausting retries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_AllFail(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, &openai.Error{StatusCode: 400}
|
||||
})
|
||||
b := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, &openai.Error{StatusCode: 400}
|
||||
})
|
||||
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
|
||||
_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all models fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "2") {
|
||||
t.Errorf("error should mention all 2 models failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_ContextCanceledAborts(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, context.Canceled
|
||||
})
|
||||
b := newMockProvider(provider.Response{Text: "from-b"})
|
||||
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
|
||||
_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("expected context.Canceled to abort, got %v", err)
|
||||
}
|
||||
// b must not be tried.
|
||||
b.mu.Lock()
|
||||
n := len(b.Requests)
|
||||
b.mu.Unlock()
|
||||
if n != 0 {
|
||||
t.Errorf("canceled should not fail over; b called %d times", n)
|
||||
}
|
||||
if IsBenched("p/a") {
|
||||
t.Error("canceled should not bench")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_AllBenchedBestEffort(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
// Manually bench both, then ensure Complete still tries (best-effort) and succeeds.
|
||||
BenchModel("p/a", time.Now().Add(time.Hour))
|
||||
BenchModel("p/b", time.Now().Add(time.Hour))
|
||||
a := newMockProvider(provider.Response{Text: "from-a"})
|
||||
b := newMockProvider(provider.Response{Text: "from-b"})
|
||||
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")}, testFailoverOpts()...)
|
||||
resp, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if err != nil {
|
||||
t.Fatalf("best-effort should still try benched models: %v", err)
|
||||
}
|
||||
if resp.Text != "from-a" {
|
||||
t.Errorf("expected from-a, got %q", resp.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_Observer(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
a := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, &openai.Error{StatusCode: 401}
|
||||
})
|
||||
b := newMockProvider(provider.Response{Text: "from-b"})
|
||||
|
||||
var mu sync.Mutex
|
||||
var events []FailoverEvent
|
||||
obs := func(ctx context.Context, ev FailoverEvent) {
|
||||
mu.Lock()
|
||||
events = append(events, ev)
|
||||
mu.Unlock()
|
||||
}
|
||||
fo := NewFailoverModel([]*Model{modelFor(a, "p/a"), modelFor(b, "p/b")},
|
||||
append(testFailoverOpts(), WithFailoverObserver(obs))...)
|
||||
_, err := fo.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(events) == 0 {
|
||||
t.Fatal("expected observer to be called")
|
||||
}
|
||||
ev := events[0]
|
||||
if ev.Model != "p/a" || ev.Kind != ErrAuthDead || !ev.Benched {
|
||||
t.Errorf("unexpected event: %+v", ev)
|
||||
}
|
||||
if ev.NextModel != "p/b" {
|
||||
t.Errorf("expected NextModel p/b, got %q", ev.NextModel)
|
||||
}
|
||||
if len(ev.Request.Messages) == 0 {
|
||||
t.Error("observer event should carry the full request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_ControlAPI(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
if IsBenched("x/y") {
|
||||
t.Error("should start unbenched")
|
||||
}
|
||||
until := time.Now().Add(time.Hour)
|
||||
BenchModel("x/y", until)
|
||||
if !IsBenched("x/y") {
|
||||
t.Error("should be benched")
|
||||
}
|
||||
list := ListBenched()
|
||||
if len(list) != 1 || list[0].Model != "x/y" || !list[0].Manual {
|
||||
t.Errorf("unexpected list: %+v", list)
|
||||
}
|
||||
if !UnbenchModel("x/y") {
|
||||
t.Error("UnbenchModel should report it was benched")
|
||||
}
|
||||
if IsBenched("x/y") {
|
||||
t.Error("should be unbenched now")
|
||||
}
|
||||
if UnbenchModel("x/y") {
|
||||
t.Error("UnbenchModel on non-benched should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailover_ExpiredBenchIsLive(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
// Bench in the past -> should be considered live again.
|
||||
BenchModel("p/a", time.Now().Add(-time.Hour))
|
||||
if IsBenched("p/a") {
|
||||
t.Error("expired bench should not count as benched")
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseChain exercises ParseChain via a registry-backed seam is covered in
|
||||
// parse_test.go; here we verify NewFailoverModel flattens nested failover models.
|
||||
func TestNewFailoverModel_Flattens(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
a := newMockProvider(provider.Response{Text: "a"})
|
||||
b := newMockProvider(provider.Response{Text: "b"})
|
||||
c := newMockProvider(provider.Response{Text: "c"})
|
||||
inner := NewFailoverModel([]*Model{modelFor(b, "p/b"), modelFor(c, "p/c")}, testFailoverOpts()...)
|
||||
outer := NewFailoverModel([]*Model{modelFor(a, "p/a"), inner}, testFailoverOpts()...)
|
||||
|
||||
fp, ok := outer.provider.(*failoverProvider)
|
||||
if !ok {
|
||||
t.Fatalf("expected *failoverProvider, got %T", outer.provider)
|
||||
}
|
||||
if len(fp.entries) != 3 {
|
||||
t.Errorf("expected flattened 3 entries, got %d", len(fp.entries))
|
||||
}
|
||||
keys := []string{fp.entries[0].specKey, fp.entries[1].specKey, fp.entries[2].specKey}
|
||||
want := []string{"p/a", "p/b", "p/c"}
|
||||
for i := range want {
|
||||
if keys[i] != want[i] {
|
||||
t.Errorf("entry %d specKey = %q, want %q", i, keys[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
+40
@@ -109,6 +109,25 @@ func Parse(spec string) (*Model, error) {
|
||||
// What: walks the resolution chain and returns the final *Model with reasoning applied.
|
||||
// Test: see parse_test.go for comprehensive table-driven tests.
|
||||
func (r *Registry) Parse(spec string) (*Model, error) {
|
||||
// Comma-separated specs become an ordered failover chain. A single part
|
||||
// (after trimming/dropping empties) falls through to normal single-model
|
||||
// parsing, preserving exact existing behavior for comma-free specs.
|
||||
if strings.Contains(spec, ",") {
|
||||
var parts []string
|
||||
for _, p := range strings.Split(spec, ",") {
|
||||
if p = strings.TrimSpace(p); p != "" {
|
||||
parts = append(parts, p)
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("%w: empty failover spec %q", ErrUnknownProvider, spec)
|
||||
}
|
||||
if len(parts) > 1 {
|
||||
return r.ParseChain(parts)
|
||||
}
|
||||
spec = parts[0]
|
||||
}
|
||||
|
||||
m, level, err := r.parse(spec, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -134,6 +153,15 @@ func (r *Registry) parse(spec string, depth int) (*Model, ReasoningLevel, error)
|
||||
target, isAlias := r.aliases[base]
|
||||
r.mu.RUnlock()
|
||||
if isAlias {
|
||||
// An alias may expand to a comma-separated failover chain; route those
|
||||
// through the comma-aware public Parse so the chain is built correctly.
|
||||
if strings.Contains(target, ",") {
|
||||
m, err := r.Parse(target)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return m, userLevel, nil
|
||||
}
|
||||
m, aliasLevel, err := r.parse(target, depth+1)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -155,6 +183,18 @@ func (r *Registry) parse(spec string, depth int) (*Model, ReasoningLevel, error)
|
||||
if resolved == "" {
|
||||
return nil, "", fmt.Errorf("resolver returned empty spec for %q", base)
|
||||
}
|
||||
// A resolver may return a comma-separated failover chain.
|
||||
if strings.Contains(resolved, ",") {
|
||||
m, err := r.Parse(resolved)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
level := defaultLevel
|
||||
if userLevel != "" {
|
||||
level = userLevel
|
||||
}
|
||||
return m, level, nil
|
||||
}
|
||||
m, resolvedLevel, err := r.parse(resolved, depth+1)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
|
||||
+83
-5
@@ -72,12 +72,12 @@ func TestSplitReasoning(t *testing.T) {
|
||||
{"gpt-4o:high", "gpt-4o", ReasoningHigh},
|
||||
{"gpt-4o:low", "gpt-4o", ReasoningLow},
|
||||
{"gpt-4o:medium", "gpt-4o", ReasoningMedium},
|
||||
{"qwen3:30b", "qwen3:30b", ""}, // Ollama tag, not a level
|
||||
{"qwen3:30b", "qwen3:30b", ""}, // Ollama tag, not a level
|
||||
{"qwen3:30b:high", "qwen3:30b", ReasoningHigh}, // tag + reasoning
|
||||
{"model:", "model:", ""}, // trailing colon, empty suffix
|
||||
{"", "", ""}, // empty string
|
||||
{"a:b:c:high", "a:b:c", ReasoningHigh}, // multiple colons
|
||||
{"a:b:c:30b", "a:b:c:30b", ""}, // multiple colons, non-level
|
||||
{"model:", "model:", ""}, // trailing colon, empty suffix
|
||||
{"", "", ""}, // empty string
|
||||
{"a:b:c:high", "a:b:c", ReasoningHigh}, // multiple colons
|
||||
{"a:b:c:30b", "a:b:c:30b", ""}, // multiple colons, non-level
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
@@ -637,3 +637,81 @@ func TestNewRegistryIsolation(t *testing.T) {
|
||||
t.Error("alias registered in r1 should not appear in r2")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Comma-separated failover chains
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestParse_CommaProducesFailover(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
r, alpha, beta := testRegistry(func(string) string { return "" })
|
||||
|
||||
m, err := r.Parse("alpha/model-a,beta/model-b")
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failover spec: %v", err)
|
||||
}
|
||||
fp, ok := m.provider.(*failoverProvider)
|
||||
if !ok {
|
||||
t.Fatalf("expected *failoverProvider, got %T", m.provider)
|
||||
}
|
||||
if len(fp.entries) != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", len(fp.entries))
|
||||
}
|
||||
if fp.entries[0].specKey != "alpha/model-a" || fp.entries[1].specKey != "beta/model-b" {
|
||||
t.Errorf("unexpected specKeys: %q, %q", fp.entries[0].specKey, fp.entries[1].specKey)
|
||||
}
|
||||
// Complete routes to the first provider and passes the bare model name.
|
||||
_, err = m.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if alpha.lastModel != "model-a" {
|
||||
t.Errorf("alpha got model %q, want model-a", alpha.lastModel)
|
||||
}
|
||||
if beta.lastModel != "" {
|
||||
t.Errorf("beta should not have been called, got model %q", beta.lastModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_NoCommaUnchanged(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
r, _, _ := testRegistry(func(string) string { return "" })
|
||||
m, err := r.Parse("alpha/model-a")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := m.provider.(*failoverProvider); ok {
|
||||
t.Error("single (comma-free) spec must NOT produce a failover provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_CommaSinglePartFallsThrough(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
r, _, _ := testRegistry(func(string) string { return "" })
|
||||
// Trailing comma / whitespace collapses to a single real part.
|
||||
m, err := r.Parse("alpha/model-a, ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := m.provider.(*failoverProvider); ok {
|
||||
t.Error("a single effective part must not produce a failover provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_CommaFlattensNested(t *testing.T) {
|
||||
resetHealthForTest()
|
||||
r, _, _ := testRegistry(func(string) string { return "" })
|
||||
// Register an alias that is itself a comma chain.
|
||||
r.RegisterAlias("pair", "alpha/model-a,beta/model-b")
|
||||
m, err := r.Parse("pair,beta/model-b2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fp, ok := m.provider.(*failoverProvider)
|
||||
if !ok {
|
||||
t.Fatalf("expected *failoverProvider, got %T", m.provider)
|
||||
}
|
||||
if len(fp.entries) != 3 {
|
||||
t.Errorf("expected flattened 3 entries, got %d", len(fp.entries))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user