package llm

import (
	"context"
	"errors"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// recordingProvider is a network-free provider for the Parse tests. It records
// the model name from the most recent Complete or Stream request so tests can
// verify which model Parse resolved to.
type recordingProvider struct {
	// lastModel is req.Model from the most recent Complete or Stream call.
	lastModel string
	// err, when non-nil, is returned from Complete so failover tests can drive
	// a comma-Parse'd chain through a failover decision. Defaults to nil (success).
	err error
}

// Complete records req.Model. It returns p.err when set; otherwise it returns
// a fixed "ok" response.
func (p *recordingProvider) Complete(_ context.Context, req provider.Request) (provider.Response, error) {
	p.lastModel = req.Model
	if p.err != nil {
		return provider.Response{}, p.err
	}
	return provider.Response{Text: "ok"}, nil
}

// Stream records req.Model (mirroring Complete, so stream-path resolution can
// also be asserted on), closes events without sending anything, and returns
// nil. p.err is intentionally not consulted here, to keep the stream path's
// success behavior unchanged for existing callers.
func (p *recordingProvider) Stream(_ context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	p.lastModel = req.Model
	close(events)
	return nil
}

// testRegistry builds a Registry with two mock providers ("alpha" and "beta")
// and an injectable envLookup. No real API keys or network access required.
func testRegistry(envFn func(string) string) (*Registry, *recordingProvider, *recordingProvider) { alpha := &recordingProvider{} beta := &recordingProvider{} r := &Registry{ providers: map[string]ProviderInfo{ "alpha": { Name: "alpha", DisplayName: "Alpha", EnvKey: "ALPHA_API_KEY", Models: []string{"model-a"}, New: func(apiKey string, opts ...ClientOption) *Client { return NewClient(alpha) }, }, "beta": { Name: "beta", DisplayName: "Beta", EnvKey: "", Models: []string{"model-b"}, New: func(_ string, opts ...ClientOption) *Client { return NewClient(beta) }, }, }, order: []string{"alpha", "beta"}, aliases: make(map[string]string), envLookup: envFn, } return r, alpha, beta } // --------------------------------------------------------------------------- // splitReasoning // --------------------------------------------------------------------------- func TestSplitReasoning(t *testing.T) { tests := []struct { input string wantBase string wantLevel ReasoningLevel }{ {"gpt-4o", "gpt-4o", ""}, {"gpt-4o:high", "gpt-4o", ReasoningHigh}, {"gpt-4o:low", "gpt-4o", ReasoningLow}, {"gpt-4o:medium", "gpt-4o", ReasoningMedium}, {"qwen3:30b", "qwen3:30b", ""}, // Ollama tag, not a level {"qwen3:30b:high", "qwen3:30b", ReasoningHigh}, // tag + reasoning {"model:", "model:", ""}, // trailing colon, empty suffix {"", "", ""}, // empty string {"a:b:c:high", "a:b:c", ReasoningHigh}, // multiple colons {"a:b:c:30b", "a:b:c:30b", ""}, // multiple colons, non-level } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { base, level := splitReasoning(tt.input) if base != tt.wantBase { t.Errorf("splitReasoning(%q) base = %q, want %q", tt.input, base, tt.wantBase) } if level != tt.wantLevel { t.Errorf("splitReasoning(%q) level = %q, want %q", tt.input, level, tt.wantLevel) } }) } } // --------------------------------------------------------------------------- // ParseDSN // --------------------------------------------------------------------------- func TestParseDSN(t *testing.T) { 
tests := []struct { name string raw string want DSN wantErr bool }{ { name: "full DSN with token", raw: "foreman://test-token@foreman-m5.orgrimmar.dudenhoeffer.casa", want: DSN{Scheme: "foreman", Token: "test-token", Host: "foreman-m5.orgrimmar.dudenhoeffer.casa"}, }, { name: "no token", raw: "ollama://localhost:11434", want: DSN{Scheme: "ollama", Token: "", Host: "localhost:11434"}, }, { name: "trailing slash stripped", raw: "foreman://tok@host.com/", want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com"}, }, { name: "with path", raw: "foreman://tok@host.com/v1/api", want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com/v1/api"}, }, { name: "missing scheme", raw: "no-scheme-here", wantErr: true, }, { name: "missing host", raw: "foreman://token@", wantErr: true, }, { name: "empty string", raw: "", wantErr: true, }, { name: "scheme only", raw: "foreman://", wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := ParseDSN(tt.raw) if tt.wantErr { if err == nil { t.Fatalf("ParseDSN(%q) expected error, got nil", tt.raw) } if !errors.Is(err, ErrInvalidDSN) { t.Errorf("ParseDSN(%q) error = %v, want ErrInvalidDSN", tt.raw, err) } return } if err != nil { t.Fatalf("ParseDSN(%q) unexpected error: %v", tt.raw, err) } if got != tt.want { t.Errorf("ParseDSN(%q) = %+v, want %+v", tt.raw, got, tt.want) } }) } } // --------------------------------------------------------------------------- // Registry.Parse // --------------------------------------------------------------------------- func TestRegistryParse(t *testing.T) { t.Run("basic provider/model", func(t *testing.T) { r, alpha, _ := testRegistry(nil) m, err := r.Parse("alpha/model-a") if err != nil { t.Fatalf("unexpected error: %v", err) } if m == nil { t.Fatal("expected non-nil model") } // Verify the model name by exercising it. 
_, _ = m.Complete(context.Background(), nil) if alpha.lastModel != "model-a" { t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") } }) t.Run("provider/model with reasoning suffix", func(t *testing.T) { r, alpha, _ := testRegistry(nil) m, err := r.Parse("alpha/model-a:high") if err != nil { t.Fatalf("unexpected error: %v", err) } if m.defaultReasoning != ReasoningHigh { t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh) } _, _ = m.Complete(context.Background(), nil) if alpha.lastModel != "model-a" { t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") } }) t.Run("ollama-style tag preserved", func(t *testing.T) { r, _, beta := testRegistry(nil) m, err := r.Parse("beta/qwen3:30b") if err != nil { t.Fatalf("unexpected error: %v", err) } if m.defaultReasoning != "" { t.Errorf("reasoning = %q, want empty (30b is not a level)", m.defaultReasoning) } _, _ = m.Complete(context.Background(), nil) if beta.lastModel != "qwen3:30b" { t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b") } }) t.Run("ollama-style tag with reasoning", func(t *testing.T) { r, _, beta := testRegistry(nil) m, err := r.Parse("beta/qwen3:30b:high") if err != nil { t.Fatalf("unexpected error: %v", err) } if m.defaultReasoning != ReasoningHigh { t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh) } _, _ = m.Complete(context.Background(), nil) if beta.lastModel != "qwen3:30b" { t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b") } }) t.Run("static alias", func(t *testing.T) { r, alpha, _ := testRegistry(nil) r.RegisterAlias("fast", "alpha/model-a") m, err := r.Parse("fast") if err != nil { t.Fatalf("unexpected error: %v", err) } _, _ = m.Complete(context.Background(), nil) if alpha.lastModel != "model-a" { t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") } }) t.Run("alias with embedded reasoning", func(t *testing.T) { r, alpha, _ := testRegistry(nil) r.RegisterAlias("thinking", "alpha/model-a:high") m, err := 
r.Parse("thinking") if err != nil { t.Fatalf("unexpected error: %v", err) } if m.defaultReasoning != ReasoningHigh { t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh) } _, _ = m.Complete(context.Background(), nil) if alpha.lastModel != "model-a" { t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") } }) t.Run("user reasoning overrides alias default", func(t *testing.T) { r, _, _ := testRegistry(nil) r.RegisterAlias("thinking", "alpha/model-a:high") m, err := r.Parse("thinking:low") if err != nil { t.Fatalf("unexpected error: %v", err) } if m.defaultReasoning != ReasoningLow { t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow) } }) t.Run("dynamic resolver", func(t *testing.T) { r, alpha, _ := testRegistry(nil) r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) { if name == "custom" { return "alpha/resolved-model", "", true } return "", "", false })) m, err := r.Parse("custom") if err != nil { t.Fatalf("unexpected error: %v", err) } _, _ = m.Complete(context.Background(), nil) if alpha.lastModel != "resolved-model" { t.Errorf("model = %q, want %q", alpha.lastModel, "resolved-model") } }) t.Run("resolver with default reasoning", func(t *testing.T) { r, _, _ := testRegistry(nil) r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) { if name == "smart" { return "alpha/model-a", ReasoningHigh, true } return "", "", false })) m, err := r.Parse("smart") if err != nil { t.Fatalf("unexpected error: %v", err) } if m.defaultReasoning != ReasoningHigh { t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh) } }) t.Run("resolver default reasoning overridden by user", func(t *testing.T) { r, _, _ := testRegistry(nil) r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) { if name == "smart" { return "alpha/model-a", ReasoningHigh, true } return "", "", false })) m, err := r.Parse("smart:low") if err != nil { 
t.Fatalf("unexpected error: %v", err) } if m.defaultReasoning != ReasoningLow { t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow) } }) t.Run("resolver returns empty spec", func(t *testing.T) { r, _, _ := testRegistry(nil) r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) { if name == "bad" { return "", "", true } return "", "", false })) _, err := r.Parse("bad") if err == nil { t.Fatal("expected error for empty resolver spec") } }) t.Run("LLM_X env var with token", func(t *testing.T) { envFn := func(key string) string { if key == "LLM_M5" { return "alpha://mytoken@myhost.com" } return "" } r, alpha, _ := testRegistry(envFn) m, err := r.Parse("m5/qwen3:30b") if err != nil { t.Fatalf("unexpected error: %v", err) } _, _ = m.Complete(context.Background(), nil) if alpha.lastModel != "qwen3:30b" { t.Errorf("model = %q, want %q", alpha.lastModel, "qwen3:30b") } }) t.Run("LLM_X env var without token", func(t *testing.T) { envFn := func(key string) string { if key == "LLM_LOCAL" { return "beta://localhost:11434" } return "" } r, _, beta := testRegistry(envFn) m, err := r.Parse("local/llama3.2") if err != nil { t.Fatalf("unexpected error: %v", err) } _, _ = m.Complete(context.Background(), nil) if beta.lastModel != "llama3.2" { t.Errorf("model = %q, want %q", beta.lastModel, "llama3.2") } }) t.Run("LLM_X with hyphenated name", func(t *testing.T) { envFn := func(key string) string { if key == "LLM_MY_BOX" { return "alpha://tok@host.com" } return "" } r, alpha, _ := testRegistry(envFn) m, err := r.Parse("my-box/some-model") if err != nil { t.Fatalf("unexpected error: %v", err) } _, _ = m.Complete(context.Background(), nil) if alpha.lastModel != "some-model" { t.Errorf("model = %q, want %q", alpha.lastModel, "some-model") } }) t.Run("unknown provider no env var", func(t *testing.T) { r, _, _ := testRegistry(func(string) string { return "" }) _, err := r.Parse("unknown/model") if err == nil { t.Fatal("expected 
error for unknown provider") } if !errors.Is(err, ErrUnknownProvider) { t.Errorf("error = %v, want ErrUnknownProvider", err) } }) t.Run("bare unknown name no slash", func(t *testing.T) { r, _, _ := testRegistry(nil) _, err := r.Parse("nonexistent") if err == nil { t.Fatal("expected error for bare unknown name") } if !errors.Is(err, ErrUnknownProvider) { t.Errorf("error = %v, want ErrUnknownProvider", err) } }) t.Run("alias loop", func(t *testing.T) { r, _, _ := testRegistry(nil) r.RegisterAlias("a", "b") r.RegisterAlias("b", "a") _, err := r.Parse("a") if err == nil { t.Fatal("expected error for alias loop") } if !errors.Is(err, ErrAliasLoop) { t.Errorf("error = %v, want ErrAliasLoop", err) } }) t.Run("deep alias chain within limit", func(t *testing.T) { r, alpha, _ := testRegistry(nil) // chain: a0 → a1 → a2 → ... → a9 → alpha/model-a (depth 10 is fine, >10 is not) r.RegisterAlias("a9", "alpha/model-a") for i := 8; i >= 0; i-- { r.RegisterAlias( "a"+string(rune('0'+i)), "a"+string(rune('0'+i+1)), ) } m, err := r.Parse("a0") if err != nil { t.Fatalf("unexpected error for deep but valid chain: %v", err) } _, _ = m.Complete(context.Background(), nil) if alpha.lastModel != "model-a" { t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") } }) t.Run("RegisterProvider replaces existing", func(t *testing.T) { replaced := &recordingProvider{} r, _, _ := testRegistry(nil) r.RegisterProvider(ProviderInfo{ Name: "alpha", DisplayName: "Alpha Replaced", New: func(_ string, _ ...ClientOption) *Client { return NewClient(replaced) }, }) m, err := r.Parse("alpha/new-model") if err != nil { t.Fatalf("unexpected error: %v", err) } _, _ = m.Complete(context.Background(), nil) if replaced.lastModel != "new-model" { t.Errorf("model = %q, want %q", replaced.lastModel, "new-model") } // Verify order is preserved (alpha is still at index 0). 
providers := r.Providers() if providers[0].Name != "alpha" { t.Errorf("first provider = %q, want %q (order preserved)", providers[0].Name, "alpha") } if providers[0].DisplayName != "Alpha Replaced" { t.Errorf("display name = %q, want %q", providers[0].DisplayName, "Alpha Replaced") } }) t.Run("RegisterProvider adds new provider", func(t *testing.T) { gamma := &recordingProvider{} r, _, _ := testRegistry(nil) r.RegisterProvider(ProviderInfo{ Name: "gamma", DisplayName: "Gamma", New: func(_ string, _ ...ClientOption) *Client { return NewClient(gamma) }, }) m, err := r.Parse("gamma/g-model") if err != nil { t.Fatalf("unexpected error: %v", err) } _, _ = m.Complete(context.Background(), nil) if gamma.lastModel != "g-model" { t.Errorf("model = %q, want %q", gamma.lastModel, "g-model") } // New provider should be last. providers := r.Providers() last := providers[len(providers)-1] if last.Name != "gamma" { t.Errorf("last provider = %q, want %q", last.Name, "gamma") } }) t.Run("DSN with unknown scheme", func(t *testing.T) { envFn := func(key string) string { if key == "LLM_X" { return "nope://tok@host.com" } return "" } r, _, _ := testRegistry(envFn) _, err := r.Parse("x/model") if err == nil { t.Fatal("expected error for unknown DSN scheme") } if !errors.Is(err, ErrUnknownProvider) { t.Errorf("error = %v, want ErrUnknownProvider", err) } }) t.Run("DSN with invalid format", func(t *testing.T) { envFn := func(key string) string { if key == "LLM_BAD" { return "no-scheme" } return "" } r, _, _ := testRegistry(envFn) _, err := r.Parse("bad/model") if err == nil { t.Fatal("expected error for invalid DSN") } if !errors.Is(err, ErrInvalidDSN) { t.Errorf("error = %v, want ErrInvalidDSN", err) } }) t.Run("createClient reads env key", func(t *testing.T) { var capturedKey string r, _, _ := testRegistry(func(key string) string { if key == "MY_KEY" { return "secret-123" } return "" }) rec := &recordingProvider{} r.RegisterProvider(ProviderInfo{ Name: "keyed", EnvKey: "MY_KEY", New: 
func(apiKey string, opts ...ClientOption) *Client { capturedKey = apiKey return NewClient(rec) }, }) _, err := r.Parse("keyed/m") if err != nil { t.Fatalf("unexpected error: %v", err) } if capturedKey != "secret-123" { t.Errorf("apiKey = %q, want %q", capturedKey, "secret-123") } }) } // --------------------------------------------------------------------------- // Backward compatibility: package-level Providers() and ProviderByName() // --------------------------------------------------------------------------- func TestBackwardCompat(t *testing.T) { t.Run("Providers includes foreman", func(t *testing.T) { providers := Providers() found := false for _, p := range providers { if p.Name == "foreman" { found = true break } } if !found { t.Error("Providers() does not include foreman") } }) t.Run("ProviderByName openai", func(t *testing.T) { p := ProviderByName("openai") if p == nil { t.Fatal("ProviderByName(\"openai\") returned nil") } if p.DisplayName != "OpenAI" { t.Errorf("DisplayName = %q, want %q", p.DisplayName, "OpenAI") } }) t.Run("ProviderByName foreman", func(t *testing.T) { p := ProviderByName("foreman") if p == nil { t.Fatal("ProviderByName(\"foreman\") returned nil") } if p.DisplayName != "Foreman" { t.Errorf("DisplayName = %q, want %q", p.DisplayName, "Foreman") } }) t.Run("ProviderByName unknown", func(t *testing.T) { p := ProviderByName("does-not-exist") if p != nil { t.Errorf("expected nil for unknown provider, got %+v", p) } }) t.Run("Providers count", func(t *testing.T) { providers := Providers() // providerRegistry has 9 entries (openai, anthropic, google, deepseek, // moonshot, xai, groq, ollama, ollama-cloud) + foreman = 10 if len(providers) < 10 { t.Errorf("len(Providers()) = %d, want >= 10", len(providers)) } }) } // --------------------------------------------------------------------------- // NewRegistry isolation // --------------------------------------------------------------------------- func TestNewRegistryIsolation(t *testing.T) { r1 := 
NewRegistry() r2 := NewRegistry() r1.RegisterAlias("only-in-r1", "openai/gpt-4o") // r2 should not see r1's alias r2.mu.RLock() _, found := r2.aliases["only-in-r1"] r2.mu.RUnlock() if found { t.Error("alias registered in r1 should not appear in r2") } } // --------------------------------------------------------------------------- // Comma-separated failover chains // --------------------------------------------------------------------------- func TestParse_CommaProducesFailover(t *testing.T) { resetHealthForTest() r, alpha, beta := testRegistry(func(string) string { return "" }) m, err := r.Parse("alpha/model-a,beta/model-b") if err != nil { t.Fatalf("Parse failover spec: %v", err) } fp, ok := m.provider.(*failoverProvider) if !ok { t.Fatalf("expected *failoverProvider, got %T", m.provider) } if len(fp.entries) != 2 { t.Fatalf("expected 2 entries, got %d", len(fp.entries)) } if fp.entries[0].specKey != "alpha/model-a" || fp.entries[1].specKey != "beta/model-b" { t.Errorf("unexpected specKeys: %q, %q", fp.entries[0].specKey, fp.entries[1].specKey) } // Complete routes to the first provider and passes the bare model name. _, err = m.Complete(context.Background(), []Message{{Role: RoleUser, Content: Content{Text: "hi"}}}) if err != nil { t.Fatal(err) } if alpha.lastModel != "model-a" { t.Errorf("alpha got model %q, want model-a", alpha.lastModel) } if beta.lastModel != "" { t.Errorf("beta should not have been called, got model %q", beta.lastModel) } } func TestParse_NoCommaUnchanged(t *testing.T) { resetHealthForTest() r, _, _ := testRegistry(func(string) string { return "" }) m, err := r.Parse("alpha/model-a") if err != nil { t.Fatal(err) } if _, ok := m.provider.(*failoverProvider); ok { t.Error("single (comma-free) spec must NOT produce a failover provider") } } func TestParse_CommaSinglePartFallsThrough(t *testing.T) { resetHealthForTest() r, _, _ := testRegistry(func(string) string { return "" }) // Trailing comma / whitespace collapses to a single real part. 
m, err := r.Parse("alpha/model-a, ") if err != nil { t.Fatal(err) } if _, ok := m.provider.(*failoverProvider); ok { t.Error("a single effective part must not produce a failover provider") } } func TestParse_CommaFlattensNested(t *testing.T) { resetHealthForTest() r, _, _ := testRegistry(func(string) string { return "" }) // Register an alias that is itself a comma chain. r.RegisterAlias("pair", "alpha/model-a,beta/model-b") m, err := r.Parse("pair,beta/model-b2") if err != nil { t.Fatal(err) } fp, ok := m.provider.(*failoverProvider) if !ok { t.Fatalf("expected *failoverProvider, got %T", m.provider) } if len(fp.entries) != 3 { t.Errorf("expected flattened 3 entries, got %d", len(fp.entries)) } }