1206261e6a
CI / Build, Test & Lint (push) Successful in 10m43s
The transparent comma-Parse path builds failover chains via NewFailoverModel with no options, so defaultFailoverConfig() left the observer nil and observers only fired when a caller passed WithFailoverObserver explicitly. Add a package-level default observer (SetFailoverObserver / DefaultFailoverObserver), guarded by the existing defaultsMu, and seed it in defaultFailoverConfig() so chains built transparently still notify it. An explicit WithFailoverObserver still overrides the default per-chain. mort sets this at boot to persist failover events. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
724 lines
20 KiB
Go
724 lines
20 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
|
)
|
|
|
|
// recordingProvider is a network-free provider stub. It remembers the model
// name carried by the most recent Complete request so tests can confirm which
// model a Parse call resolved to.
type recordingProvider struct {
	// lastModel is the Model field of the latest request seen by Complete.
	lastModel string

	// err, if set, is returned by Complete in place of a successful response;
	// failover tests use it to push a comma-Parse'd chain through a failover
	// decision. The zero value (nil) means Complete succeeds.
	err error
}
|
|
|
|
func (p *recordingProvider) Complete(_ context.Context, req provider.Request) (provider.Response, error) {
|
|
p.lastModel = req.Model
|
|
if p.err != nil {
|
|
return provider.Response{}, p.err
|
|
}
|
|
return provider.Response{Text: "ok"}, nil
|
|
}
|
|
|
|
func (p *recordingProvider) Stream(_ context.Context, _ provider.Request, events chan<- provider.StreamEvent) error {
|
|
close(events)
|
|
return nil
|
|
}
|
|
|
|
// testRegistry builds a Registry with two mock providers ("alpha" and "beta")
|
|
// and an injectable envLookup. No real API keys or network access required.
|
|
func testRegistry(envFn func(string) string) (*Registry, *recordingProvider, *recordingProvider) {
|
|
alpha := &recordingProvider{}
|
|
beta := &recordingProvider{}
|
|
|
|
r := &Registry{
|
|
providers: map[string]ProviderInfo{
|
|
"alpha": {
|
|
Name: "alpha",
|
|
DisplayName: "Alpha",
|
|
EnvKey: "ALPHA_API_KEY",
|
|
Models: []string{"model-a"},
|
|
New: func(apiKey string, opts ...ClientOption) *Client {
|
|
return NewClient(alpha)
|
|
},
|
|
},
|
|
"beta": {
|
|
Name: "beta",
|
|
DisplayName: "Beta",
|
|
EnvKey: "",
|
|
Models: []string{"model-b"},
|
|
New: func(_ string, opts ...ClientOption) *Client {
|
|
return NewClient(beta)
|
|
},
|
|
},
|
|
},
|
|
order: []string{"alpha", "beta"},
|
|
aliases: make(map[string]string),
|
|
envLookup: envFn,
|
|
}
|
|
return r, alpha, beta
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// splitReasoning
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestSplitReasoning(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
wantBase string
|
|
wantLevel ReasoningLevel
|
|
}{
|
|
{"gpt-4o", "gpt-4o", ""},
|
|
{"gpt-4o:high", "gpt-4o", ReasoningHigh},
|
|
{"gpt-4o:low", "gpt-4o", ReasoningLow},
|
|
{"gpt-4o:medium", "gpt-4o", ReasoningMedium},
|
|
{"qwen3:30b", "qwen3:30b", ""}, // Ollama tag, not a level
|
|
{"qwen3:30b:high", "qwen3:30b", ReasoningHigh}, // tag + reasoning
|
|
{"model:", "model:", ""}, // trailing colon, empty suffix
|
|
{"", "", ""}, // empty string
|
|
{"a:b:c:high", "a:b:c", ReasoningHigh}, // multiple colons
|
|
{"a:b:c:30b", "a:b:c:30b", ""}, // multiple colons, non-level
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
base, level := splitReasoning(tt.input)
|
|
if base != tt.wantBase {
|
|
t.Errorf("splitReasoning(%q) base = %q, want %q", tt.input, base, tt.wantBase)
|
|
}
|
|
if level != tt.wantLevel {
|
|
t.Errorf("splitReasoning(%q) level = %q, want %q", tt.input, level, tt.wantLevel)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// ParseDSN
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestParseDSN(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
raw string
|
|
want DSN
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "full DSN with token",
|
|
raw: "foreman://test-token@foreman-m5.orgrimmar.dudenhoeffer.casa",
|
|
want: DSN{Scheme: "foreman", Token: "test-token", Host: "foreman-m5.orgrimmar.dudenhoeffer.casa"},
|
|
},
|
|
{
|
|
name: "no token",
|
|
raw: "ollama://localhost:11434",
|
|
want: DSN{Scheme: "ollama", Token: "", Host: "localhost:11434"},
|
|
},
|
|
{
|
|
name: "trailing slash stripped",
|
|
raw: "foreman://tok@host.com/",
|
|
want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com"},
|
|
},
|
|
{
|
|
name: "with path",
|
|
raw: "foreman://tok@host.com/v1/api",
|
|
want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com/v1/api"},
|
|
},
|
|
{
|
|
name: "missing scheme",
|
|
raw: "no-scheme-here",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "missing host",
|
|
raw: "foreman://token@",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "empty string",
|
|
raw: "",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "scheme only",
|
|
raw: "foreman://",
|
|
wantErr: true,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := ParseDSN(tt.raw)
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Fatalf("ParseDSN(%q) expected error, got nil", tt.raw)
|
|
}
|
|
if !errors.Is(err, ErrInvalidDSN) {
|
|
t.Errorf("ParseDSN(%q) error = %v, want ErrInvalidDSN", tt.raw, err)
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("ParseDSN(%q) unexpected error: %v", tt.raw, err)
|
|
}
|
|
if got != tt.want {
|
|
t.Errorf("ParseDSN(%q) = %+v, want %+v", tt.raw, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Registry.Parse
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRegistryParse(t *testing.T) {
|
|
t.Run("basic provider/model", func(t *testing.T) {
|
|
r, alpha, _ := testRegistry(nil)
|
|
m, err := r.Parse("alpha/model-a")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if m == nil {
|
|
t.Fatal("expected non-nil model")
|
|
}
|
|
// Verify the model name by exercising it.
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if alpha.lastModel != "model-a" {
|
|
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
|
}
|
|
})
|
|
|
|
t.Run("provider/model with reasoning suffix", func(t *testing.T) {
|
|
r, alpha, _ := testRegistry(nil)
|
|
m, err := r.Parse("alpha/model-a:high")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if m.defaultReasoning != ReasoningHigh {
|
|
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if alpha.lastModel != "model-a" {
|
|
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
|
}
|
|
})
|
|
|
|
t.Run("ollama-style tag preserved", func(t *testing.T) {
|
|
r, _, beta := testRegistry(nil)
|
|
m, err := r.Parse("beta/qwen3:30b")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if m.defaultReasoning != "" {
|
|
t.Errorf("reasoning = %q, want empty (30b is not a level)", m.defaultReasoning)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if beta.lastModel != "qwen3:30b" {
|
|
t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b")
|
|
}
|
|
})
|
|
|
|
t.Run("ollama-style tag with reasoning", func(t *testing.T) {
|
|
r, _, beta := testRegistry(nil)
|
|
m, err := r.Parse("beta/qwen3:30b:high")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if m.defaultReasoning != ReasoningHigh {
|
|
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if beta.lastModel != "qwen3:30b" {
|
|
t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b")
|
|
}
|
|
})
|
|
|
|
t.Run("static alias", func(t *testing.T) {
|
|
r, alpha, _ := testRegistry(nil)
|
|
r.RegisterAlias("fast", "alpha/model-a")
|
|
m, err := r.Parse("fast")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if alpha.lastModel != "model-a" {
|
|
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
|
}
|
|
})
|
|
|
|
t.Run("alias with embedded reasoning", func(t *testing.T) {
|
|
r, alpha, _ := testRegistry(nil)
|
|
r.RegisterAlias("thinking", "alpha/model-a:high")
|
|
m, err := r.Parse("thinking")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if m.defaultReasoning != ReasoningHigh {
|
|
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if alpha.lastModel != "model-a" {
|
|
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
|
}
|
|
})
|
|
|
|
t.Run("user reasoning overrides alias default", func(t *testing.T) {
|
|
r, _, _ := testRegistry(nil)
|
|
r.RegisterAlias("thinking", "alpha/model-a:high")
|
|
m, err := r.Parse("thinking:low")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if m.defaultReasoning != ReasoningLow {
|
|
t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow)
|
|
}
|
|
})
|
|
|
|
t.Run("dynamic resolver", func(t *testing.T) {
|
|
r, alpha, _ := testRegistry(nil)
|
|
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
|
|
if name == "custom" {
|
|
return "alpha/resolved-model", "", true
|
|
}
|
|
return "", "", false
|
|
}))
|
|
m, err := r.Parse("custom")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if alpha.lastModel != "resolved-model" {
|
|
t.Errorf("model = %q, want %q", alpha.lastModel, "resolved-model")
|
|
}
|
|
})
|
|
|
|
t.Run("resolver with default reasoning", func(t *testing.T) {
|
|
r, _, _ := testRegistry(nil)
|
|
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
|
|
if name == "smart" {
|
|
return "alpha/model-a", ReasoningHigh, true
|
|
}
|
|
return "", "", false
|
|
}))
|
|
m, err := r.Parse("smart")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if m.defaultReasoning != ReasoningHigh {
|
|
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
|
|
}
|
|
})
|
|
|
|
t.Run("resolver default reasoning overridden by user", func(t *testing.T) {
|
|
r, _, _ := testRegistry(nil)
|
|
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
|
|
if name == "smart" {
|
|
return "alpha/model-a", ReasoningHigh, true
|
|
}
|
|
return "", "", false
|
|
}))
|
|
m, err := r.Parse("smart:low")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if m.defaultReasoning != ReasoningLow {
|
|
t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow)
|
|
}
|
|
})
|
|
|
|
t.Run("resolver returns empty spec", func(t *testing.T) {
|
|
r, _, _ := testRegistry(nil)
|
|
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
|
|
if name == "bad" {
|
|
return "", "", true
|
|
}
|
|
return "", "", false
|
|
}))
|
|
_, err := r.Parse("bad")
|
|
if err == nil {
|
|
t.Fatal("expected error for empty resolver spec")
|
|
}
|
|
})
|
|
|
|
t.Run("LLM_X env var with token", func(t *testing.T) {
|
|
envFn := func(key string) string {
|
|
if key == "LLM_M5" {
|
|
return "alpha://mytoken@myhost.com"
|
|
}
|
|
return ""
|
|
}
|
|
r, alpha, _ := testRegistry(envFn)
|
|
m, err := r.Parse("m5/qwen3:30b")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if alpha.lastModel != "qwen3:30b" {
|
|
t.Errorf("model = %q, want %q", alpha.lastModel, "qwen3:30b")
|
|
}
|
|
})
|
|
|
|
t.Run("LLM_X env var without token", func(t *testing.T) {
|
|
envFn := func(key string) string {
|
|
if key == "LLM_LOCAL" {
|
|
return "beta://localhost:11434"
|
|
}
|
|
return ""
|
|
}
|
|
r, _, beta := testRegistry(envFn)
|
|
m, err := r.Parse("local/llama3.2")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if beta.lastModel != "llama3.2" {
|
|
t.Errorf("model = %q, want %q", beta.lastModel, "llama3.2")
|
|
}
|
|
})
|
|
|
|
t.Run("LLM_X with hyphenated name", func(t *testing.T) {
|
|
envFn := func(key string) string {
|
|
if key == "LLM_MY_BOX" {
|
|
return "alpha://tok@host.com"
|
|
}
|
|
return ""
|
|
}
|
|
r, alpha, _ := testRegistry(envFn)
|
|
m, err := r.Parse("my-box/some-model")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if alpha.lastModel != "some-model" {
|
|
t.Errorf("model = %q, want %q", alpha.lastModel, "some-model")
|
|
}
|
|
})
|
|
|
|
t.Run("unknown provider no env var", func(t *testing.T) {
|
|
r, _, _ := testRegistry(func(string) string { return "" })
|
|
_, err := r.Parse("unknown/model")
|
|
if err == nil {
|
|
t.Fatal("expected error for unknown provider")
|
|
}
|
|
if !errors.Is(err, ErrUnknownProvider) {
|
|
t.Errorf("error = %v, want ErrUnknownProvider", err)
|
|
}
|
|
})
|
|
|
|
t.Run("bare unknown name no slash", func(t *testing.T) {
|
|
r, _, _ := testRegistry(nil)
|
|
_, err := r.Parse("nonexistent")
|
|
if err == nil {
|
|
t.Fatal("expected error for bare unknown name")
|
|
}
|
|
if !errors.Is(err, ErrUnknownProvider) {
|
|
t.Errorf("error = %v, want ErrUnknownProvider", err)
|
|
}
|
|
})
|
|
|
|
t.Run("alias loop", func(t *testing.T) {
|
|
r, _, _ := testRegistry(nil)
|
|
r.RegisterAlias("a", "b")
|
|
r.RegisterAlias("b", "a")
|
|
_, err := r.Parse("a")
|
|
if err == nil {
|
|
t.Fatal("expected error for alias loop")
|
|
}
|
|
if !errors.Is(err, ErrAliasLoop) {
|
|
t.Errorf("error = %v, want ErrAliasLoop", err)
|
|
}
|
|
})
|
|
|
|
t.Run("deep alias chain within limit", func(t *testing.T) {
|
|
r, alpha, _ := testRegistry(nil)
|
|
// chain: a0 → a1 → a2 → ... → a9 → alpha/model-a (depth 10 is fine, >10 is not)
|
|
r.RegisterAlias("a9", "alpha/model-a")
|
|
for i := 8; i >= 0; i-- {
|
|
r.RegisterAlias(
|
|
"a"+string(rune('0'+i)),
|
|
"a"+string(rune('0'+i+1)),
|
|
)
|
|
}
|
|
m, err := r.Parse("a0")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error for deep but valid chain: %v", err)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if alpha.lastModel != "model-a" {
|
|
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
|
}
|
|
})
|
|
|
|
t.Run("RegisterProvider replaces existing", func(t *testing.T) {
|
|
replaced := &recordingProvider{}
|
|
r, _, _ := testRegistry(nil)
|
|
r.RegisterProvider(ProviderInfo{
|
|
Name: "alpha",
|
|
DisplayName: "Alpha Replaced",
|
|
New: func(_ string, _ ...ClientOption) *Client {
|
|
return NewClient(replaced)
|
|
},
|
|
})
|
|
m, err := r.Parse("alpha/new-model")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if replaced.lastModel != "new-model" {
|
|
t.Errorf("model = %q, want %q", replaced.lastModel, "new-model")
|
|
}
|
|
// Verify order is preserved (alpha is still at index 0).
|
|
providers := r.Providers()
|
|
if providers[0].Name != "alpha" {
|
|
t.Errorf("first provider = %q, want %q (order preserved)", providers[0].Name, "alpha")
|
|
}
|
|
if providers[0].DisplayName != "Alpha Replaced" {
|
|
t.Errorf("display name = %q, want %q", providers[0].DisplayName, "Alpha Replaced")
|
|
}
|
|
})
|
|
|
|
t.Run("RegisterProvider adds new provider", func(t *testing.T) {
|
|
gamma := &recordingProvider{}
|
|
r, _, _ := testRegistry(nil)
|
|
r.RegisterProvider(ProviderInfo{
|
|
Name: "gamma",
|
|
DisplayName: "Gamma",
|
|
New: func(_ string, _ ...ClientOption) *Client {
|
|
return NewClient(gamma)
|
|
},
|
|
})
|
|
m, err := r.Parse("gamma/g-model")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
_, _ = m.Complete(context.Background(), nil)
|
|
if gamma.lastModel != "g-model" {
|
|
t.Errorf("model = %q, want %q", gamma.lastModel, "g-model")
|
|
}
|
|
// New provider should be last.
|
|
providers := r.Providers()
|
|
last := providers[len(providers)-1]
|
|
if last.Name != "gamma" {
|
|
t.Errorf("last provider = %q, want %q", last.Name, "gamma")
|
|
}
|
|
})
|
|
|
|
t.Run("DSN with unknown scheme", func(t *testing.T) {
|
|
envFn := func(key string) string {
|
|
if key == "LLM_X" {
|
|
return "nope://tok@host.com"
|
|
}
|
|
return ""
|
|
}
|
|
r, _, _ := testRegistry(envFn)
|
|
_, err := r.Parse("x/model")
|
|
if err == nil {
|
|
t.Fatal("expected error for unknown DSN scheme")
|
|
}
|
|
if !errors.Is(err, ErrUnknownProvider) {
|
|
t.Errorf("error = %v, want ErrUnknownProvider", err)
|
|
}
|
|
})
|
|
|
|
t.Run("DSN with invalid format", func(t *testing.T) {
|
|
envFn := func(key string) string {
|
|
if key == "LLM_BAD" {
|
|
return "no-scheme"
|
|
}
|
|
return ""
|
|
}
|
|
r, _, _ := testRegistry(envFn)
|
|
_, err := r.Parse("bad/model")
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid DSN")
|
|
}
|
|
if !errors.Is(err, ErrInvalidDSN) {
|
|
t.Errorf("error = %v, want ErrInvalidDSN", err)
|
|
}
|
|
})
|
|
|
|
t.Run("createClient reads env key", func(t *testing.T) {
|
|
var capturedKey string
|
|
r, _, _ := testRegistry(func(key string) string {
|
|
if key == "MY_KEY" {
|
|
return "secret-123"
|
|
}
|
|
return ""
|
|
})
|
|
rec := &recordingProvider{}
|
|
r.RegisterProvider(ProviderInfo{
|
|
Name: "keyed",
|
|
EnvKey: "MY_KEY",
|
|
New: func(apiKey string, opts ...ClientOption) *Client {
|
|
capturedKey = apiKey
|
|
return NewClient(rec)
|
|
},
|
|
})
|
|
_, err := r.Parse("keyed/m")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if capturedKey != "secret-123" {
|
|
t.Errorf("apiKey = %q, want %q", capturedKey, "secret-123")
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Backward compatibility: package-level Providers() and ProviderByName()
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestBackwardCompat(t *testing.T) {
|
|
t.Run("Providers includes foreman", func(t *testing.T) {
|
|
providers := Providers()
|
|
found := false
|
|
for _, p := range providers {
|
|
if p.Name == "foreman" {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Error("Providers() does not include foreman")
|
|
}
|
|
})
|
|
|
|
t.Run("ProviderByName openai", func(t *testing.T) {
|
|
p := ProviderByName("openai")
|
|
if p == nil {
|
|
t.Fatal("ProviderByName(\"openai\") returned nil")
|
|
}
|
|
if p.DisplayName != "OpenAI" {
|
|
t.Errorf("DisplayName = %q, want %q", p.DisplayName, "OpenAI")
|
|
}
|
|
})
|
|
|
|
t.Run("ProviderByName foreman", func(t *testing.T) {
|
|
p := ProviderByName("foreman")
|
|
if p == nil {
|
|
t.Fatal("ProviderByName(\"foreman\") returned nil")
|
|
}
|
|
if p.DisplayName != "Foreman" {
|
|
t.Errorf("DisplayName = %q, want %q", p.DisplayName, "Foreman")
|
|
}
|
|
})
|
|
|
|
t.Run("ProviderByName unknown", func(t *testing.T) {
|
|
p := ProviderByName("does-not-exist")
|
|
if p != nil {
|
|
t.Errorf("expected nil for unknown provider, got %+v", p)
|
|
}
|
|
})
|
|
|
|
t.Run("Providers count", func(t *testing.T) {
|
|
providers := Providers()
|
|
// providerRegistry has 9 entries (openai, anthropic, google, deepseek,
|
|
// moonshot, xai, groq, ollama, ollama-cloud) + foreman = 10
|
|
if len(providers) < 10 {
|
|
t.Errorf("len(Providers()) = %d, want >= 10", len(providers))
|
|
}
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// NewRegistry isolation
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestNewRegistryIsolation(t *testing.T) {
|
|
r1 := NewRegistry()
|
|
r2 := NewRegistry()
|
|
|
|
r1.RegisterAlias("only-in-r1", "openai/gpt-4o")
|
|
|
|
// r2 should not see r1's alias
|
|
r2.mu.RLock()
|
|
_, found := r2.aliases["only-in-r1"]
|
|
r2.mu.RUnlock()
|
|
if found {
|
|
t.Error("alias registered in r1 should not appear in r2")
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Comma-separated failover chains
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestParse_CommaProducesFailover(t *testing.T) {
|
|
resetHealthForTest()
|
|
r, alpha, beta := testRegistry(func(string) string { return "" })
|
|
|
|
m, err := r.Parse("alpha/model-a,beta/model-b")
|
|
if err != nil {
|
|
t.Fatalf("Parse failover spec: %v", err)
|
|
}
|
|
fp, ok := m.provider.(*failoverProvider)
|
|
if !ok {
|
|
t.Fatalf("expected *failoverProvider, got %T", m.provider)
|
|
}
|
|
if len(fp.entries) != 2 {
|
|
t.Fatalf("expected 2 entries, got %d", len(fp.entries))
|
|
}
|
|
if fp.entries[0].specKey != "alpha/model-a" || fp.entries[1].specKey != "beta/model-b" {
|
|
t.Errorf("unexpected specKeys: %q, %q", fp.entries[0].specKey, fp.entries[1].specKey)
|
|
}
|
|
// Complete routes to the first provider and passes the bare model name.
|
|
_, err = m.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if alpha.lastModel != "model-a" {
|
|
t.Errorf("alpha got model %q, want model-a", alpha.lastModel)
|
|
}
|
|
if beta.lastModel != "" {
|
|
t.Errorf("beta should not have been called, got model %q", beta.lastModel)
|
|
}
|
|
}
|
|
|
|
func TestParse_NoCommaUnchanged(t *testing.T) {
|
|
resetHealthForTest()
|
|
r, _, _ := testRegistry(func(string) string { return "" })
|
|
m, err := r.Parse("alpha/model-a")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, ok := m.provider.(*failoverProvider); ok {
|
|
t.Error("single (comma-free) spec must NOT produce a failover provider")
|
|
}
|
|
}
|
|
|
|
func TestParse_CommaSinglePartFallsThrough(t *testing.T) {
|
|
resetHealthForTest()
|
|
r, _, _ := testRegistry(func(string) string { return "" })
|
|
// Trailing comma / whitespace collapses to a single real part.
|
|
m, err := r.Parse("alpha/model-a, ")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, ok := m.provider.(*failoverProvider); ok {
|
|
t.Error("a single effective part must not produce a failover provider")
|
|
}
|
|
}
|
|
|
|
func TestParse_CommaFlattensNested(t *testing.T) {
|
|
resetHealthForTest()
|
|
r, _, _ := testRegistry(func(string) string { return "" })
|
|
// Register an alias that is itself a comma chain.
|
|
r.RegisterAlias("pair", "alpha/model-a,beta/model-b")
|
|
m, err := r.Parse("pair,beta/model-b2")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
fp, ok := m.provider.(*failoverProvider)
|
|
if !ok {
|
|
t.Fatalf("expected *failoverProvider, got %T", m.provider)
|
|
}
|
|
if len(fp.entries) != 3 {
|
|
t.Errorf("expected flattened 3 entries, got %d", len(fp.entries))
|
|
}
|
|
}
|