Files
go-llm/v2/parse_test.go
T
steve ae8e194fad feat(failover): model failover chains via comma-separated specs
Parse("a,b,c") now returns one composite *llm.Model that tries each model
in order, retrying transient failures, benching dead models, and failing
over to the next. Comma-free specs are completely unchanged.

- classify.go: Classify(err) ErrKind + IsTransient(err) error classifier
  mapping anthropic (typed Is*Err helpers + RequestError status),
  openai-go (*openai.Error status), openaicompat.FeatureUnsupportedError,
  context errors, and ollama "HTTP <code>" strings to
  transient/auth-dead/request-specific/unknown.
- failover.go: failoverProvider (satisfies provider.Provider) wrapped into a
  *Model via NewClient. Process-wide mutex-guarded modelHealth bench
  registry keyed by concrete spec, with cooldowns and a control API
  (ListBenched/BenchModel/UnbenchModel/IsBenched). NewFailoverModel +
  ParseChain constructors, FailoverOption config, FailoverObserver (carries
  the full request), and configurable package-level defaults.
- parse.go: comma-aware Parse splits into a failover chain; alias/resolver
  targets that expand to comma chains are routed through the comma-aware
  path and flattened.

All access to global health is mutex-guarded; tests reset it via
resetHealthForTest and pass under go test -race.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-01 00:30:08 +02:00

718 lines
20 KiB
Go

package llm
import (
"context"
"errors"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// recordingProvider is an in-memory provider.Provider used by the Parse tests.
// Complete records the model name it was asked for and returns a canned "ok"
// response, so a test can verify which concrete model a parsed *Model routes
// to without any network calls.
//
// lastModel is written without synchronization; the tests shown here call
// Complete from the test goroutine only.
type recordingProvider struct {
	// lastModel is the Model field of the most recent Complete request; it
	// stays "" until Complete has been called.
	lastModel string
}

// Compile-time check that the fake keeps satisfying provider.Provider: if the
// interface changes, the build fails here, next to the fake, instead of at
// every NewClient call site in the tests.
var _ provider.Provider = (*recordingProvider)(nil)

// Complete records req.Model and always succeeds with a fixed "ok" response.
func (p *recordingProvider) Complete(_ context.Context, req provider.Request) (provider.Response, error) {
	p.lastModel = req.Model
	return provider.Response{Text: "ok"}, nil
}

// Stream emits no events: it closes events immediately and returns nil.
// NOTE(review): closing the channel here assumes the Stream contract makes the
// provider responsible for closing events — confirm against package provider.
func (p *recordingProvider) Stream(_ context.Context, _ provider.Request, events chan<- provider.StreamEvent) error {
	close(events)
	return nil
}
// testRegistry returns an isolated Registry wired to two in-memory providers:
// "alpha" (whose API key is looked up under ALPHA_API_KEY) and "beta" (no key).
// It also returns the recordingProvider behind each, so callers can inspect
// which model name a parsed *Model forwarded. envFn is installed verbatim as
// the registry's envLookup. The Registry struct is populated directly, so no
// real API keys or network access are involved.
func testRegistry(envFn func(string) string) (*Registry, *recordingProvider, *recordingProvider) {
	alphaRec, betaRec := &recordingProvider{}, &recordingProvider{}

	alphaInfo := ProviderInfo{
		Name:        "alpha",
		DisplayName: "Alpha",
		EnvKey:      "ALPHA_API_KEY",
		Models:      []string{"model-a"},
		New: func(_ string, _ ...ClientOption) *Client {
			return NewClient(alphaRec)
		},
	}
	betaInfo := ProviderInfo{
		Name:        "beta",
		DisplayName: "Beta",
		EnvKey:      "",
		Models:      []string{"model-b"},
		New: func(_ string, _ ...ClientOption) *Client {
			return NewClient(betaRec)
		},
	}

	reg := &Registry{
		providers: map[string]ProviderInfo{
			alphaInfo.Name: alphaInfo,
			betaInfo.Name:  betaInfo,
		},
		order:     []string{alphaInfo.Name, betaInfo.Name},
		aliases:   map[string]string{},
		envLookup: envFn,
	}
	return reg, alphaRec, betaRec
}
// ---------------------------------------------------------------------------
// splitReasoning
// ---------------------------------------------------------------------------

// TestSplitReasoning checks that splitReasoning strips a trailing
// ":low" / ":medium" / ":high" reasoning suffix into a ReasoningLevel. Other
// suffixes (Ollama size tags such as ":30b", an empty suffix, arbitrary
// colon segments) stay part of the base name.
func TestSplitReasoning(t *testing.T) {
	type splitCase struct {
		input     string
		wantBase  string
		wantLevel ReasoningLevel
	}
	cases := []splitCase{
		{input: "gpt-4o", wantBase: "gpt-4o", wantLevel: ""},
		{input: "gpt-4o:high", wantBase: "gpt-4o", wantLevel: ReasoningHigh},
		{input: "gpt-4o:low", wantBase: "gpt-4o", wantLevel: ReasoningLow},
		{input: "gpt-4o:medium", wantBase: "gpt-4o", wantLevel: ReasoningMedium},
		{input: "qwen3:30b", wantBase: "qwen3:30b", wantLevel: ""},                // Ollama tag, not a level
		{input: "qwen3:30b:high", wantBase: "qwen3:30b", wantLevel: ReasoningHigh}, // tag + reasoning
		{input: "model:", wantBase: "model:", wantLevel: ""},                      // trailing colon, empty suffix
		{input: "", wantBase: "", wantLevel: ""},                                  // empty string
		{input: "a:b:c:high", wantBase: "a:b:c", wantLevel: ReasoningHigh},        // multiple colons
		{input: "a:b:c:30b", wantBase: "a:b:c:30b", wantLevel: ""},                // multiple colons, non-level
	}
	for i := range cases {
		tc := cases[i]
		t.Run(tc.input, func(t *testing.T) {
			gotBase, gotLevel := splitReasoning(tc.input)
			if gotBase != tc.wantBase {
				t.Errorf("splitReasoning(%q) base = %q, want %q", tc.input, gotBase, tc.wantBase)
			}
			if gotLevel != tc.wantLevel {
				t.Errorf("splitReasoning(%q) level = %q, want %q", tc.input, gotLevel, tc.wantLevel)
			}
		})
	}
}
// ---------------------------------------------------------------------------
// ParseDSN
// ---------------------------------------------------------------------------

// TestParseDSN runs ParseDSN over well-formed and malformed inputs. Valid
// inputs must decode into the exact scheme/token/host triple (with a trailing
// slash dropped and any path kept in Host). Every malformed input must yield
// an error matching ErrInvalidDSN.
func TestParseDSN(t *testing.T) {
	type dsnCase struct {
		name    string
		raw     string
		want    DSN
		wantErr bool
	}
	cases := []dsnCase{
		{
			name: "full DSN with token",
			raw:  "foreman://test-token@foreman-m5.orgrimmar.dudenhoeffer.casa",
			want: DSN{Scheme: "foreman", Token: "test-token", Host: "foreman-m5.orgrimmar.dudenhoeffer.casa"},
		},
		{
			name: "no token",
			raw:  "ollama://localhost:11434",
			want: DSN{Scheme: "ollama", Token: "", Host: "localhost:11434"},
		},
		{
			name: "trailing slash stripped",
			raw:  "foreman://tok@host.com/",
			want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com"},
		},
		{
			name: "with path",
			raw:  "foreman://tok@host.com/v1/api",
			want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com/v1/api"},
		},
		{name: "missing scheme", raw: "no-scheme-here", wantErr: true},
		{name: "missing host", raw: "foreman://token@", wantErr: true},
		{name: "empty string", raw: "", wantErr: true},
		{name: "scheme only", raw: "foreman://", wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := ParseDSN(tc.raw)
			switch {
			case tc.wantErr && err == nil:
				t.Fatalf("ParseDSN(%q) expected error, got nil", tc.raw)
			case tc.wantErr:
				if !errors.Is(err, ErrInvalidDSN) {
					t.Errorf("ParseDSN(%q) error = %v, want ErrInvalidDSN", tc.raw, err)
				}
			case err != nil:
				t.Fatalf("ParseDSN(%q) unexpected error: %v", tc.raw, err)
			case got != tc.want:
				t.Errorf("ParseDSN(%q) = %+v, want %+v", tc.raw, got, tc.want)
			}
		})
	}
}
// ---------------------------------------------------------------------------
// Registry.Parse
// ---------------------------------------------------------------------------
// TestRegistryParse exercises Registry.Parse against the two-provider fixture
// from testRegistry. It covers:
//   - provider/model specs, reasoning suffixes and Ollama-style tags;
//   - static aliases and dynamic resolvers;
//   - LLM_<NAME> DSN environment variables;
//   - alias loops, RegisterProvider replacement/addition, and API-key lookup.
//
// Many subtests call m.Complete(context.Background(), nil) and discard the
// result on purpose. The call exists only so the recordingProvider captures
// the model name it was routed with; the lastModel assertion after it is the
// real check.
//
// NOTE(review): several subtests pass a nil envLookup to testRegistry. This
// presumably makes Registry fall back to a default lookup (e.g. the process
// environment) — confirm, since that would make those subtests sensitive to
// the real environment.
func TestRegistryParse(t *testing.T) {
// A plain provider/model spec routes to that provider with the bare model name.
t.Run("basic provider/model", func(t *testing.T) {
r, alpha, _ := testRegistry(nil)
m, err := r.Parse("alpha/model-a")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if m == nil {
t.Fatal("expected non-nil model")
}
// Verify the model name by exercising it; the Complete result is ignored
// because only the model name recorded by the fake matters here.
_, _ = m.Complete(context.Background(), nil)
if alpha.lastModel != "model-a" {
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
}
})
// A reasoning suffix becomes the Model's default reasoning level and is
// stripped from the model name sent to the provider.
t.Run("provider/model with reasoning suffix", func(t *testing.T) {
r, alpha, _ := testRegistry(nil)
m, err := r.Parse("alpha/model-a:high")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if m.defaultReasoning != ReasoningHigh {
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
}
_, _ = m.Complete(context.Background(), nil)
if alpha.lastModel != "model-a" {
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
}
})
// A non-level suffix (an Ollama size tag) stays part of the model name and
// sets no reasoning level.
t.Run("ollama-style tag preserved", func(t *testing.T) {
r, _, beta := testRegistry(nil)
m, err := r.Parse("beta/qwen3:30b")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if m.defaultReasoning != "" {
t.Errorf("reasoning = %q, want empty (30b is not a level)", m.defaultReasoning)
}
_, _ = m.Complete(context.Background(), nil)
if beta.lastModel != "qwen3:30b" {
t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b")
}
})
// With both a tag and a reasoning suffix, only the final level suffix is
// stripped; the tag remains in the model name.
t.Run("ollama-style tag with reasoning", func(t *testing.T) {
r, _, beta := testRegistry(nil)
m, err := r.Parse("beta/qwen3:30b:high")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if m.defaultReasoning != ReasoningHigh {
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
}
_, _ = m.Complete(context.Background(), nil)
if beta.lastModel != "qwen3:30b" {
t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b")
}
})
// A registered alias expands to its target provider/model spec.
t.Run("static alias", func(t *testing.T) {
r, alpha, _ := testRegistry(nil)
r.RegisterAlias("fast", "alpha/model-a")
m, err := r.Parse("fast")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, _ = m.Complete(context.Background(), nil)
if alpha.lastModel != "model-a" {
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
}
})
// A reasoning suffix inside the alias target supplies the default level.
t.Run("alias with embedded reasoning", func(t *testing.T) {
r, alpha, _ := testRegistry(nil)
r.RegisterAlias("thinking", "alpha/model-a:high")
m, err := r.Parse("thinking")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if m.defaultReasoning != ReasoningHigh {
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
}
_, _ = m.Complete(context.Background(), nil)
if alpha.lastModel != "model-a" {
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
}
})
// A reasoning suffix on the requested alias name overrides the level embedded
// in the alias target.
t.Run("user reasoning overrides alias default", func(t *testing.T) {
r, _, _ := testRegistry(nil)
r.RegisterAlias("thinking", "alpha/model-a:high")
m, err := r.Parse("thinking:low")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if m.defaultReasoning != ReasoningLow {
t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow)
}
})
// A resolver claims a name by returning (spec, level, true); the returned
// spec is then parsed like any other provider/model spec.
t.Run("dynamic resolver", func(t *testing.T) {
r, alpha, _ := testRegistry(nil)
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
if name == "custom" {
return "alpha/resolved-model", "", true
}
return "", "", false
}))
m, err := r.Parse("custom")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, _ = m.Complete(context.Background(), nil)
if alpha.lastModel != "resolved-model" {
t.Errorf("model = %q, want %q", alpha.lastModel, "resolved-model")
}
})
// The level returned by a resolver becomes the Model's default reasoning.
t.Run("resolver with default reasoning", func(t *testing.T) {
r, _, _ := testRegistry(nil)
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
if name == "smart" {
return "alpha/model-a", ReasoningHigh, true
}
return "", "", false
}))
m, err := r.Parse("smart")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if m.defaultReasoning != ReasoningHigh {
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
}
})
// A reasoning suffix on the requested name beats the resolver's level.
t.Run("resolver default reasoning overridden by user", func(t *testing.T) {
r, _, _ := testRegistry(nil)
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
if name == "smart" {
return "alpha/model-a", ReasoningHigh, true
}
return "", "", false
}))
m, err := r.Parse("smart:low")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if m.defaultReasoning != ReasoningLow {
t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow)
}
})
// A resolver that claims a name but returns an empty spec must make Parse
// fail (the error kind is not asserted).
t.Run("resolver returns empty spec", func(t *testing.T) {
r, _, _ := testRegistry(nil)
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
if name == "bad" {
return "", "", true
}
return "", "", false
}))
_, err := r.Parse("bad")
if err == nil {
t.Fatal("expected error for empty resolver spec")
}
})
// "m5" is not a registered provider, so Parse consults LLM_M5. The DSN's
// scheme ("alpha") selects the provider, and the text after the slash in the
// spec is passed through as the model name.
t.Run("LLM_X env var with token", func(t *testing.T) {
envFn := func(key string) string {
if key == "LLM_M5" {
return "alpha://mytoken@myhost.com"
}
return ""
}
r, alpha, _ := testRegistry(envFn)
m, err := r.Parse("m5/qwen3:30b")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, _ = m.Complete(context.Background(), nil)
if alpha.lastModel != "qwen3:30b" {
t.Errorf("model = %q, want %q", alpha.lastModel, "qwen3:30b")
}
})
// Same as above, but the DSN carries no token.
t.Run("LLM_X env var without token", func(t *testing.T) {
envFn := func(key string) string {
if key == "LLM_LOCAL" {
return "beta://localhost:11434"
}
return ""
}
r, _, beta := testRegistry(envFn)
m, err := r.Parse("local/llama3.2")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, _ = m.Complete(context.Background(), nil)
if beta.lastModel != "llama3.2" {
t.Errorf("model = %q, want %q", beta.lastModel, "llama3.2")
}
})
// The env var name is derived by upper-casing the name and mapping '-' to
// '_' (my-box → LLM_MY_BOX).
t.Run("LLM_X with hyphenated name", func(t *testing.T) {
envFn := func(key string) string {
if key == "LLM_MY_BOX" {
return "alpha://tok@host.com"
}
return ""
}
r, alpha, _ := testRegistry(envFn)
m, err := r.Parse("my-box/some-model")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, _ = m.Complete(context.Background(), nil)
if alpha.lastModel != "some-model" {
t.Errorf("model = %q, want %q", alpha.lastModel, "some-model")
}
})
// An unregistered provider with no LLM_<NAME> variable is ErrUnknownProvider.
t.Run("unknown provider no env var", func(t *testing.T) {
r, _, _ := testRegistry(func(string) string { return "" })
_, err := r.Parse("unknown/model")
if err == nil {
t.Fatal("expected error for unknown provider")
}
if !errors.Is(err, ErrUnknownProvider) {
t.Errorf("error = %v, want ErrUnknownProvider", err)
}
})
// A slash-free name that matches no alias or resolver is ErrUnknownProvider.
t.Run("bare unknown name no slash", func(t *testing.T) {
r, _, _ := testRegistry(nil)
_, err := r.Parse("nonexistent")
if err == nil {
t.Fatal("expected error for bare unknown name")
}
if !errors.Is(err, ErrUnknownProvider) {
t.Errorf("error = %v, want ErrUnknownProvider", err)
}
})
// Mutually referencing aliases must be reported as ErrAliasLoop.
t.Run("alias loop", func(t *testing.T) {
r, _, _ := testRegistry(nil)
r.RegisterAlias("a", "b")
r.RegisterAlias("b", "a")
_, err := r.Parse("a")
if err == nil {
t.Fatal("expected error for alias loop")
}
if !errors.Is(err, ErrAliasLoop) {
t.Errorf("error = %v, want ErrAliasLoop", err)
}
})
// A long but acyclic alias chain must still resolve to the final spec.
t.Run("deep alias chain within limit", func(t *testing.T) {
r, alpha, _ := testRegistry(nil)
// chain: a0 → a1 → a2 → ... → a9 → alpha/model-a (depth 10 is fine, >10 is not)
r.RegisterAlias("a9", "alpha/model-a")
// Register links a8→a9 down to a0→a1; single-digit i keeps '0'+i a digit rune.
for i := 8; i >= 0; i-- {
r.RegisterAlias(
"a"+string(rune('0'+i)),
"a"+string(rune('0'+i+1)),
)
}
m, err := r.Parse("a0")
if err != nil {
t.Fatalf("unexpected error for deep but valid chain: %v", err)
}
_, _ = m.Complete(context.Background(), nil)
if alpha.lastModel != "model-a" {
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
}
})
// Re-registering an existing name swaps in the new ProviderInfo. Parse uses
// the new constructor, and Providers() keeps alpha at index 0 with the new
// DisplayName.
t.Run("RegisterProvider replaces existing", func(t *testing.T) {
replaced := &recordingProvider{}
r, _, _ := testRegistry(nil)
r.RegisterProvider(ProviderInfo{
Name: "alpha",
DisplayName: "Alpha Replaced",
New: func(_ string, _ ...ClientOption) *Client {
return NewClient(replaced)
},
})
m, err := r.Parse("alpha/new-model")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, _ = m.Complete(context.Background(), nil)
if replaced.lastModel != "new-model" {
t.Errorf("model = %q, want %q", replaced.lastModel, "new-model")
}
// Verify order is preserved (alpha is still at index 0).
providers := r.Providers()
if providers[0].Name != "alpha" {
t.Errorf("first provider = %q, want %q (order preserved)", providers[0].Name, "alpha")
}
if providers[0].DisplayName != "Alpha Replaced" {
t.Errorf("display name = %q, want %q", providers[0].DisplayName, "Alpha Replaced")
}
})
// Registering a new name makes it parseable and appends it to Providers().
t.Run("RegisterProvider adds new provider", func(t *testing.T) {
gamma := &recordingProvider{}
r, _, _ := testRegistry(nil)
r.RegisterProvider(ProviderInfo{
Name: "gamma",
DisplayName: "Gamma",
New: func(_ string, _ ...ClientOption) *Client {
return NewClient(gamma)
},
})
m, err := r.Parse("gamma/g-model")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, _ = m.Complete(context.Background(), nil)
if gamma.lastModel != "g-model" {
t.Errorf("model = %q, want %q", gamma.lastModel, "g-model")
}
// New provider should be last.
providers := r.Providers()
last := providers[len(providers)-1]
if last.Name != "gamma" {
t.Errorf("last provider = %q, want %q", last.Name, "gamma")
}
})
// An LLM_<NAME> DSN whose scheme names no registered provider is
// ErrUnknownProvider.
t.Run("DSN with unknown scheme", func(t *testing.T) {
envFn := func(key string) string {
if key == "LLM_X" {
return "nope://tok@host.com"
}
return ""
}
r, _, _ := testRegistry(envFn)
_, err := r.Parse("x/model")
if err == nil {
t.Fatal("expected error for unknown DSN scheme")
}
if !errors.Is(err, ErrUnknownProvider) {
t.Errorf("error = %v, want ErrUnknownProvider", err)
}
})
// A malformed LLM_<NAME> value surfaces from Parse as ErrInvalidDSN.
t.Run("DSN with invalid format", func(t *testing.T) {
envFn := func(key string) string {
if key == "LLM_BAD" {
return "no-scheme"
}
return ""
}
r, _, _ := testRegistry(envFn)
_, err := r.Parse("bad/model")
if err == nil {
t.Fatal("expected error for invalid DSN")
}
if !errors.Is(err, ErrInvalidDSN) {
t.Errorf("error = %v, want ErrInvalidDSN", err)
}
})
// The provider's EnvKey is read through envLookup and passed to New as apiKey.
t.Run("createClient reads env key", func(t *testing.T) {
var capturedKey string
r, _, _ := testRegistry(func(key string) string {
if key == "MY_KEY" {
return "secret-123"
}
return ""
})
rec := &recordingProvider{}
r.RegisterProvider(ProviderInfo{
Name: "keyed",
EnvKey: "MY_KEY",
New: func(apiKey string, opts ...ClientOption) *Client {
capturedKey = apiKey
return NewClient(rec)
},
})
_, err := r.Parse("keyed/m")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if capturedKey != "secret-123" {
t.Errorf("apiKey = %q, want %q", capturedKey, "secret-123")
}
})
}
// ---------------------------------------------------------------------------
// Backward compatibility: package-level Providers() and ProviderByName()
// ---------------------------------------------------------------------------

// TestBackwardCompat exercises the package-level Providers and ProviderByName
// helpers:
//   - foreman must be listed;
//   - openai and foreman must resolve with their expected display names;
//   - an unknown name must yield nil;
//   - at least ten providers must be reported.
func TestBackwardCompat(t *testing.T) {
	t.Run("Providers includes foreman", func(t *testing.T) {
		found := false
		for i, all := 0, Providers(); i < len(all) && !found; i++ {
			found = all[i].Name == "foreman"
		}
		if !found {
			t.Error("Providers() does not include foreman")
		}
	})

	// Well-known providers must be resolvable by name with their display names.
	known := []struct{ name, displayName string }{
		{name: "openai", displayName: "OpenAI"},
		{name: "foreman", displayName: "Foreman"},
	}
	for _, k := range known {
		t.Run("ProviderByName "+k.name, func(t *testing.T) {
			info := ProviderByName(k.name)
			if info == nil {
				t.Fatalf("ProviderByName(%q) returned nil", k.name)
			}
			if info.DisplayName != k.displayName {
				t.Errorf("DisplayName = %q, want %q", info.DisplayName, k.displayName)
			}
		})
	}

	t.Run("ProviderByName unknown", func(t *testing.T) {
		if info := ProviderByName("does-not-exist"); info != nil {
			t.Errorf("expected nil for unknown provider, got %+v", info)
		}
	})

	t.Run("Providers count", func(t *testing.T) {
		// providerRegistry has 9 entries (openai, anthropic, google, deepseek,
		// moonshot, xai, groq, ollama, ollama-cloud) + foreman = 10
		const minProviders = 10
		if n := len(Providers()); n < minProviders {
			t.Errorf("len(Providers()) = %d, want >= 10", n)
		}
	})
}
// ---------------------------------------------------------------------------
// NewRegistry isolation
// ---------------------------------------------------------------------------

// TestNewRegistryIsolation checks that registries returned by NewRegistry do
// not share alias state: an alias registered on one must not be visible on
// another. The alias is first confirmed on r1. Without that positive control,
// the negative check on r2 would also pass if RegisterAlias stored nothing.
func TestNewRegistryIsolation(t *testing.T) {
	r1 := NewRegistry()
	r2 := NewRegistry()
	r1.RegisterAlias("only-in-r1", "openai/gpt-4o")

	// Positive control: the alias must actually have landed in r1.
	r1.mu.RLock()
	_, inR1 := r1.aliases["only-in-r1"]
	r1.mu.RUnlock()
	if !inR1 {
		t.Fatal("alias registered in r1 is missing from r1")
	}

	// r2 should not see r1's alias.
	r2.mu.RLock()
	_, inR2 := r2.aliases["only-in-r1"]
	r2.mu.RUnlock()
	if inR2 {
		t.Error("alias registered in r1 should not appear in r2")
	}
}
// ---------------------------------------------------------------------------
// Comma-separated failover chains
// ---------------------------------------------------------------------------

// TestParse_CommaProducesFailover checks that a two-element comma spec parses
// into a *failoverProvider. The entries must carry the concrete specKeys in
// order, and a successful Complete must be served by the first entry only.
//
// The process-wide model-health registry is reset both before and after the
// test. Parse/Complete on a failover chain may record health state, and later
// tests that do not reset it themselves must not inherit it.
func TestParse_CommaProducesFailover(t *testing.T) {
	resetHealthForTest()
	t.Cleanup(func() { resetHealthForTest() })

	r, alpha, beta := testRegistry(func(string) string { return "" })
	m, err := r.Parse("alpha/model-a,beta/model-b")
	if err != nil {
		t.Fatalf("Parse failover spec: %v", err)
	}
	fp, ok := m.provider.(*failoverProvider)
	if !ok {
		t.Fatalf("expected *failoverProvider, got %T", m.provider)
	}
	if len(fp.entries) != 2 {
		t.Fatalf("expected 2 entries, got %d", len(fp.entries))
	}
	if fp.entries[0].specKey != "alpha/model-a" || fp.entries[1].specKey != "beta/model-b" {
		t.Errorf("unexpected specKeys: %q, %q", fp.entries[0].specKey, fp.entries[1].specKey)
	}

	// Complete routes to the first provider and passes the bare model name.
	_, err = m.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
	if err != nil {
		t.Fatalf("Complete on failover model: %v", err)
	}
	if alpha.lastModel != "model-a" {
		t.Errorf("alpha got model %q, want model-a", alpha.lastModel)
	}
	// A healthy first entry means the chain never fails over.
	if beta.lastModel != "" {
		t.Errorf("beta should not have been called, got model %q", beta.lastModel)
	}
}
// TestParse_NoCommaUnchanged guards the backward-compatible path: a spec
// containing no comma must come back as an ordinary single-provider Model,
// never wrapped in a *failoverProvider.
func TestParse_NoCommaUnchanged(t *testing.T) {
	resetHealthForTest()

	noEnv := func(string) string { return "" }
	reg, _, _ := testRegistry(noEnv)

	model, err := reg.Parse("alpha/model-a")
	if err != nil {
		t.Fatal(err)
	}

	_, isFailover := model.provider.(*failoverProvider)
	if isFailover {
		t.Error("single (comma-free) spec must NOT produce a failover provider")
	}
}
// TestParse_CommaSinglePartFallsThrough checks that a spec which contains a
// comma but only one non-blank element ("alpha/model-a, ") is not treated as a
// failover chain.
func TestParse_CommaSinglePartFallsThrough(t *testing.T) {
	resetHealthForTest()

	noEnv := func(string) string { return "" }
	reg, _, _ := testRegistry(noEnv)

	// Trailing comma / whitespace collapses to a single real part.
	const spec = "alpha/model-a, "
	model, err := reg.Parse(spec)
	if err != nil {
		t.Fatal(err)
	}

	_, isFailover := model.provider.(*failoverProvider)
	if isFailover {
		t.Error("a single effective part must not produce a failover provider")
	}
}
// TestParse_CommaFlattensNested checks that a comma chain containing an alias
// that itself expands to a comma chain yields a single, flat list of failover
// entries. The alias's members come first, in order, followed by the rest of
// the outer chain. Asserting the exact specKey sequence (not only the count)
// catches wrong ordering or wrong members.
//
// Global model health is reset before and after so no state leaks between
// tests.
func TestParse_CommaFlattensNested(t *testing.T) {
	resetHealthForTest()
	t.Cleanup(func() { resetHealthForTest() })

	r, _, _ := testRegistry(func(string) string { return "" })
	// Register an alias that is itself a comma chain.
	r.RegisterAlias("pair", "alpha/model-a,beta/model-b")
	m, err := r.Parse("pair,beta/model-b2")
	if err != nil {
		t.Fatal(err)
	}
	fp, ok := m.provider.(*failoverProvider)
	if !ok {
		t.Fatalf("expected *failoverProvider, got %T", m.provider)
	}

	want := []string{"alpha/model-a", "beta/model-b", "beta/model-b2"}
	// Fatal rather than Error: the loop below indexes entries by position.
	if len(fp.entries) != len(want) {
		t.Fatalf("expected flattened %d entries, got %d", len(want), len(fp.entries))
	}
	for i, w := range want {
		if got := fp.entries[i].specKey; got != w {
			t.Errorf("entries[%d].specKey = %q, want %q", i, got, w)
		}
	}
}