feat(v2): add Parse() function and extensible Registry for model string resolution
Introduces llm.Parse(spec) backed by an extensible Registry that resolves model strings like "openai/gpt-4o", aliases like "fast", and named targets like "m5/qwen3:30b" (via LLM_M5 env var DSNs) into ready-to-use *Model objects. Extension points: RegisterProvider, RegisterAlias, RegisterResolver. Adds Foreman constructor and sentinel errors ErrAliasLoop, ErrUnknownProvider, ErrInvalidDSN. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,639 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
// recordingProvider captures the model name passed to Complete so tests can
|
||||
// verify that Parse resolved to the correct model without network calls.
|
||||
type recordingProvider struct {
|
||||
lastModel string
|
||||
}
|
||||
|
||||
func (p *recordingProvider) Complete(_ context.Context, req provider.Request) (provider.Response, error) {
|
||||
p.lastModel = req.Model
|
||||
return provider.Response{Text: "ok"}, nil
|
||||
}
|
||||
|
||||
func (p *recordingProvider) Stream(_ context.Context, _ provider.Request, events chan<- provider.StreamEvent) error {
|
||||
close(events)
|
||||
return nil
|
||||
}
|
||||
|
||||
// testRegistry builds a Registry with two mock providers ("alpha" and "beta")
|
||||
// and an injectable envLookup. No real API keys or network access required.
|
||||
func testRegistry(envFn func(string) string) (*Registry, *recordingProvider, *recordingProvider) {
|
||||
alpha := &recordingProvider{}
|
||||
beta := &recordingProvider{}
|
||||
|
||||
r := &Registry{
|
||||
providers: map[string]ProviderInfo{
|
||||
"alpha": {
|
||||
Name: "alpha",
|
||||
DisplayName: "Alpha",
|
||||
EnvKey: "ALPHA_API_KEY",
|
||||
Models: []string{"model-a"},
|
||||
New: func(apiKey string, opts ...ClientOption) *Client {
|
||||
return NewClient(alpha)
|
||||
},
|
||||
},
|
||||
"beta": {
|
||||
Name: "beta",
|
||||
DisplayName: "Beta",
|
||||
EnvKey: "",
|
||||
Models: []string{"model-b"},
|
||||
New: func(_ string, opts ...ClientOption) *Client {
|
||||
return NewClient(beta)
|
||||
},
|
||||
},
|
||||
},
|
||||
order: []string{"alpha", "beta"},
|
||||
aliases: make(map[string]string),
|
||||
envLookup: envFn,
|
||||
}
|
||||
return r, alpha, beta
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// splitReasoning
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSplitReasoning(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantBase string
|
||||
wantLevel ReasoningLevel
|
||||
}{
|
||||
{"gpt-4o", "gpt-4o", ""},
|
||||
{"gpt-4o:high", "gpt-4o", ReasoningHigh},
|
||||
{"gpt-4o:low", "gpt-4o", ReasoningLow},
|
||||
{"gpt-4o:medium", "gpt-4o", ReasoningMedium},
|
||||
{"qwen3:30b", "qwen3:30b", ""}, // Ollama tag, not a level
|
||||
{"qwen3:30b:high", "qwen3:30b", ReasoningHigh}, // tag + reasoning
|
||||
{"model:", "model:", ""}, // trailing colon, empty suffix
|
||||
{"", "", ""}, // empty string
|
||||
{"a:b:c:high", "a:b:c", ReasoningHigh}, // multiple colons
|
||||
{"a:b:c:30b", "a:b:c:30b", ""}, // multiple colons, non-level
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
base, level := splitReasoning(tt.input)
|
||||
if base != tt.wantBase {
|
||||
t.Errorf("splitReasoning(%q) base = %q, want %q", tt.input, base, tt.wantBase)
|
||||
}
|
||||
if level != tt.wantLevel {
|
||||
t.Errorf("splitReasoning(%q) level = %q, want %q", tt.input, level, tt.wantLevel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ParseDSN
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestParseDSN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
want DSN
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "full DSN with token",
|
||||
raw: "foreman://test-token@foreman-m5.orgrimmar.dudenhoeffer.casa",
|
||||
want: DSN{Scheme: "foreman", Token: "test-token", Host: "foreman-m5.orgrimmar.dudenhoeffer.casa"},
|
||||
},
|
||||
{
|
||||
name: "no token",
|
||||
raw: "ollama://localhost:11434",
|
||||
want: DSN{Scheme: "ollama", Token: "", Host: "localhost:11434"},
|
||||
},
|
||||
{
|
||||
name: "trailing slash stripped",
|
||||
raw: "foreman://tok@host.com/",
|
||||
want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com"},
|
||||
},
|
||||
{
|
||||
name: "with path",
|
||||
raw: "foreman://tok@host.com/v1/api",
|
||||
want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com/v1/api"},
|
||||
},
|
||||
{
|
||||
name: "missing scheme",
|
||||
raw: "no-scheme-here",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
raw: "foreman://token@",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
raw: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "scheme only",
|
||||
raw: "foreman://",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseDSN(tt.raw)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("ParseDSN(%q) expected error, got nil", tt.raw)
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidDSN) {
|
||||
t.Errorf("ParseDSN(%q) error = %v, want ErrInvalidDSN", tt.raw, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("ParseDSN(%q) unexpected error: %v", tt.raw, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseDSN(%q) = %+v, want %+v", tt.raw, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Registry.Parse
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRegistryParse(t *testing.T) {
|
||||
t.Run("basic provider/model", func(t *testing.T) {
|
||||
r, alpha, _ := testRegistry(nil)
|
||||
m, err := r.Parse("alpha/model-a")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil model")
|
||||
}
|
||||
// Verify the model name by exercising it.
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if alpha.lastModel != "model-a" {
|
||||
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("provider/model with reasoning suffix", func(t *testing.T) {
|
||||
r, alpha, _ := testRegistry(nil)
|
||||
m, err := r.Parse("alpha/model-a:high")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if m.defaultReasoning != ReasoningHigh {
|
||||
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if alpha.lastModel != "model-a" {
|
||||
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ollama-style tag preserved", func(t *testing.T) {
|
||||
r, _, beta := testRegistry(nil)
|
||||
m, err := r.Parse("beta/qwen3:30b")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if m.defaultReasoning != "" {
|
||||
t.Errorf("reasoning = %q, want empty (30b is not a level)", m.defaultReasoning)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if beta.lastModel != "qwen3:30b" {
|
||||
t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ollama-style tag with reasoning", func(t *testing.T) {
|
||||
r, _, beta := testRegistry(nil)
|
||||
m, err := r.Parse("beta/qwen3:30b:high")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if m.defaultReasoning != ReasoningHigh {
|
||||
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if beta.lastModel != "qwen3:30b" {
|
||||
t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("static alias", func(t *testing.T) {
|
||||
r, alpha, _ := testRegistry(nil)
|
||||
r.RegisterAlias("fast", "alpha/model-a")
|
||||
m, err := r.Parse("fast")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if alpha.lastModel != "model-a" {
|
||||
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias with embedded reasoning", func(t *testing.T) {
|
||||
r, alpha, _ := testRegistry(nil)
|
||||
r.RegisterAlias("thinking", "alpha/model-a:high")
|
||||
m, err := r.Parse("thinking")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if m.defaultReasoning != ReasoningHigh {
|
||||
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if alpha.lastModel != "model-a" {
|
||||
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("user reasoning overrides alias default", func(t *testing.T) {
|
||||
r, _, _ := testRegistry(nil)
|
||||
r.RegisterAlias("thinking", "alpha/model-a:high")
|
||||
m, err := r.Parse("thinking:low")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if m.defaultReasoning != ReasoningLow {
|
||||
t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("dynamic resolver", func(t *testing.T) {
|
||||
r, alpha, _ := testRegistry(nil)
|
||||
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
|
||||
if name == "custom" {
|
||||
return "alpha/resolved-model", "", true
|
||||
}
|
||||
return "", "", false
|
||||
}))
|
||||
m, err := r.Parse("custom")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if alpha.lastModel != "resolved-model" {
|
||||
t.Errorf("model = %q, want %q", alpha.lastModel, "resolved-model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("resolver with default reasoning", func(t *testing.T) {
|
||||
r, _, _ := testRegistry(nil)
|
||||
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
|
||||
if name == "smart" {
|
||||
return "alpha/model-a", ReasoningHigh, true
|
||||
}
|
||||
return "", "", false
|
||||
}))
|
||||
m, err := r.Parse("smart")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if m.defaultReasoning != ReasoningHigh {
|
||||
t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("resolver default reasoning overridden by user", func(t *testing.T) {
|
||||
r, _, _ := testRegistry(nil)
|
||||
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
|
||||
if name == "smart" {
|
||||
return "alpha/model-a", ReasoningHigh, true
|
||||
}
|
||||
return "", "", false
|
||||
}))
|
||||
m, err := r.Parse("smart:low")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if m.defaultReasoning != ReasoningLow {
|
||||
t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("resolver returns empty spec", func(t *testing.T) {
|
||||
r, _, _ := testRegistry(nil)
|
||||
r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) {
|
||||
if name == "bad" {
|
||||
return "", "", true
|
||||
}
|
||||
return "", "", false
|
||||
}))
|
||||
_, err := r.Parse("bad")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty resolver spec")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LLM_X env var with token", func(t *testing.T) {
|
||||
envFn := func(key string) string {
|
||||
if key == "LLM_M5" {
|
||||
return "alpha://mytoken@myhost.com"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
r, alpha, _ := testRegistry(envFn)
|
||||
m, err := r.Parse("m5/qwen3:30b")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if alpha.lastModel != "qwen3:30b" {
|
||||
t.Errorf("model = %q, want %q", alpha.lastModel, "qwen3:30b")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LLM_X env var without token", func(t *testing.T) {
|
||||
envFn := func(key string) string {
|
||||
if key == "LLM_LOCAL" {
|
||||
return "beta://localhost:11434"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
r, _, beta := testRegistry(envFn)
|
||||
m, err := r.Parse("local/llama3.2")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if beta.lastModel != "llama3.2" {
|
||||
t.Errorf("model = %q, want %q", beta.lastModel, "llama3.2")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LLM_X with hyphenated name", func(t *testing.T) {
|
||||
envFn := func(key string) string {
|
||||
if key == "LLM_MY_BOX" {
|
||||
return "alpha://tok@host.com"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
r, alpha, _ := testRegistry(envFn)
|
||||
m, err := r.Parse("my-box/some-model")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if alpha.lastModel != "some-model" {
|
||||
t.Errorf("model = %q, want %q", alpha.lastModel, "some-model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown provider no env var", func(t *testing.T) {
|
||||
r, _, _ := testRegistry(func(string) string { return "" })
|
||||
_, err := r.Parse("unknown/model")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown provider")
|
||||
}
|
||||
if !errors.Is(err, ErrUnknownProvider) {
|
||||
t.Errorf("error = %v, want ErrUnknownProvider", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bare unknown name no slash", func(t *testing.T) {
|
||||
r, _, _ := testRegistry(nil)
|
||||
_, err := r.Parse("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for bare unknown name")
|
||||
}
|
||||
if !errors.Is(err, ErrUnknownProvider) {
|
||||
t.Errorf("error = %v, want ErrUnknownProvider", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("alias loop", func(t *testing.T) {
|
||||
r, _, _ := testRegistry(nil)
|
||||
r.RegisterAlias("a", "b")
|
||||
r.RegisterAlias("b", "a")
|
||||
_, err := r.Parse("a")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for alias loop")
|
||||
}
|
||||
if !errors.Is(err, ErrAliasLoop) {
|
||||
t.Errorf("error = %v, want ErrAliasLoop", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deep alias chain within limit", func(t *testing.T) {
|
||||
r, alpha, _ := testRegistry(nil)
|
||||
// chain: a0 → a1 → a2 → ... → a9 → alpha/model-a (depth 10 is fine, >10 is not)
|
||||
r.RegisterAlias("a9", "alpha/model-a")
|
||||
for i := 8; i >= 0; i-- {
|
||||
r.RegisterAlias(
|
||||
"a"+string(rune('0'+i)),
|
||||
"a"+string(rune('0'+i+1)),
|
||||
)
|
||||
}
|
||||
m, err := r.Parse("a0")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for deep but valid chain: %v", err)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if alpha.lastModel != "model-a" {
|
||||
t.Errorf("model = %q, want %q", alpha.lastModel, "model-a")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RegisterProvider replaces existing", func(t *testing.T) {
|
||||
replaced := &recordingProvider{}
|
||||
r, _, _ := testRegistry(nil)
|
||||
r.RegisterProvider(ProviderInfo{
|
||||
Name: "alpha",
|
||||
DisplayName: "Alpha Replaced",
|
||||
New: func(_ string, _ ...ClientOption) *Client {
|
||||
return NewClient(replaced)
|
||||
},
|
||||
})
|
||||
m, err := r.Parse("alpha/new-model")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if replaced.lastModel != "new-model" {
|
||||
t.Errorf("model = %q, want %q", replaced.lastModel, "new-model")
|
||||
}
|
||||
// Verify order is preserved (alpha is still at index 0).
|
||||
providers := r.Providers()
|
||||
if providers[0].Name != "alpha" {
|
||||
t.Errorf("first provider = %q, want %q (order preserved)", providers[0].Name, "alpha")
|
||||
}
|
||||
if providers[0].DisplayName != "Alpha Replaced" {
|
||||
t.Errorf("display name = %q, want %q", providers[0].DisplayName, "Alpha Replaced")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RegisterProvider adds new provider", func(t *testing.T) {
|
||||
gamma := &recordingProvider{}
|
||||
r, _, _ := testRegistry(nil)
|
||||
r.RegisterProvider(ProviderInfo{
|
||||
Name: "gamma",
|
||||
DisplayName: "Gamma",
|
||||
New: func(_ string, _ ...ClientOption) *Client {
|
||||
return NewClient(gamma)
|
||||
},
|
||||
})
|
||||
m, err := r.Parse("gamma/g-model")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_, _ = m.Complete(context.Background(), nil)
|
||||
if gamma.lastModel != "g-model" {
|
||||
t.Errorf("model = %q, want %q", gamma.lastModel, "g-model")
|
||||
}
|
||||
// New provider should be last.
|
||||
providers := r.Providers()
|
||||
last := providers[len(providers)-1]
|
||||
if last.Name != "gamma" {
|
||||
t.Errorf("last provider = %q, want %q", last.Name, "gamma")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DSN with unknown scheme", func(t *testing.T) {
|
||||
envFn := func(key string) string {
|
||||
if key == "LLM_X" {
|
||||
return "nope://tok@host.com"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
r, _, _ := testRegistry(envFn)
|
||||
_, err := r.Parse("x/model")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown DSN scheme")
|
||||
}
|
||||
if !errors.Is(err, ErrUnknownProvider) {
|
||||
t.Errorf("error = %v, want ErrUnknownProvider", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DSN with invalid format", func(t *testing.T) {
|
||||
envFn := func(key string) string {
|
||||
if key == "LLM_BAD" {
|
||||
return "no-scheme"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
r, _, _ := testRegistry(envFn)
|
||||
_, err := r.Parse("bad/model")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid DSN")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidDSN) {
|
||||
t.Errorf("error = %v, want ErrInvalidDSN", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("createClient reads env key", func(t *testing.T) {
|
||||
var capturedKey string
|
||||
r, _, _ := testRegistry(func(key string) string {
|
||||
if key == "MY_KEY" {
|
||||
return "secret-123"
|
||||
}
|
||||
return ""
|
||||
})
|
||||
rec := &recordingProvider{}
|
||||
r.RegisterProvider(ProviderInfo{
|
||||
Name: "keyed",
|
||||
EnvKey: "MY_KEY",
|
||||
New: func(apiKey string, opts ...ClientOption) *Client {
|
||||
capturedKey = apiKey
|
||||
return NewClient(rec)
|
||||
},
|
||||
})
|
||||
_, err := r.Parse("keyed/m")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if capturedKey != "secret-123" {
|
||||
t.Errorf("apiKey = %q, want %q", capturedKey, "secret-123")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Backward compatibility: package-level Providers() and ProviderByName()
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBackwardCompat(t *testing.T) {
|
||||
t.Run("Providers includes foreman", func(t *testing.T) {
|
||||
providers := Providers()
|
||||
found := false
|
||||
for _, p := range providers {
|
||||
if p.Name == "foreman" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Providers() does not include foreman")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ProviderByName openai", func(t *testing.T) {
|
||||
p := ProviderByName("openai")
|
||||
if p == nil {
|
||||
t.Fatal("ProviderByName(\"openai\") returned nil")
|
||||
}
|
||||
if p.DisplayName != "OpenAI" {
|
||||
t.Errorf("DisplayName = %q, want %q", p.DisplayName, "OpenAI")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ProviderByName foreman", func(t *testing.T) {
|
||||
p := ProviderByName("foreman")
|
||||
if p == nil {
|
||||
t.Fatal("ProviderByName(\"foreman\") returned nil")
|
||||
}
|
||||
if p.DisplayName != "Foreman" {
|
||||
t.Errorf("DisplayName = %q, want %q", p.DisplayName, "Foreman")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ProviderByName unknown", func(t *testing.T) {
|
||||
p := ProviderByName("does-not-exist")
|
||||
if p != nil {
|
||||
t.Errorf("expected nil for unknown provider, got %+v", p)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Providers count", func(t *testing.T) {
|
||||
providers := Providers()
|
||||
// providerRegistry has 9 entries (openai, anthropic, google, deepseek,
|
||||
// moonshot, xai, groq, ollama, ollama-cloud) + foreman = 10
|
||||
if len(providers) < 10 {
|
||||
t.Errorf("len(Providers()) = %d, want >= 10", len(providers))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NewRegistry isolation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewRegistryIsolation(t *testing.T) {
|
||||
r1 := NewRegistry()
|
||||
r2 := NewRegistry()
|
||||
|
||||
r1.RegisterAlias("only-in-r1", "openai/gpt-4o")
|
||||
|
||||
// r2 should not see r1's alias
|
||||
r2.mu.RLock()
|
||||
_, found := r2.aliases["only-in-r1"]
|
||||
r2.mu.RUnlock()
|
||||
if found {
|
||||
t.Error("alias registered in r1 should not appear in r2")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user