diff --git a/v2/CLAUDE.md b/v2/CLAUDE.md index 840ca28..9dc8d38 100644 --- a/v2/CLAUDE.md +++ b/v2/CLAUDE.md @@ -32,3 +32,15 @@ 6. Streaming via pull-based `StreamReader.Next()` 7. Middleware for logging, retry, timeout, usage tracking 8. Ollama uses the native `/api/chat` API rather than the OpenAI-compat `/v1` endpoint. Native API supports `think: false` for thinking-capable models, has more reliable tool calling, and is approximately 15-20% lower latency. Both local and cloud share the same provider; only the apiKey/baseURL differ. `llm.Ollama()` targets `http://localhost:11434` with no Authorization header; `llm.OllamaCloud(key)` targets `https://ollama.com` with `Authorization: Bearer `. + +### DD#9 — Parse() function and extensible Registry (2026-05-23) +**Context:** mort's ParseModelRequest resolves "provider/model" strings but is +mort-specific. Multi-instance providers (foreman) need named targets. +**Decision:** Add `llm.Parse(spec)` backed by an extensible `Registry`. Supports +aliases, dynamic resolvers, and `LLM_X` env var DSNs for named targets. +Provider/model syntax: `"openai/gpt-4o"`, aliases: `"fast"`, named targets: +`"m5/qwen3:30b"` (reads `LLM_M5` env var). Registry is extensible via +`RegisterProvider`, `RegisterAlias`, `RegisterResolver`. +**Consequence:** Any go-llm consumer gets model-string parsing. mort migrates by +registering its tier aliases as resolvers. Foreman instances are addressed via +`LLM_X` DSN env vars without code changes. diff --git a/v2/constructors.go b/v2/constructors.go index c819505..70c5994 100644 --- a/v2/constructors.go +++ b/v2/constructors.go @@ -119,6 +119,22 @@ func Ollama(opts ...ClientOption) *Client { return NewClient(ollamaProvider.New("", cfg.baseURL)) } +// Foreman creates a client targeting a foreman daemon (a private, authenticated +// Ollama endpoint with queuing and observability). The token is sent as a Bearer +// token; pass "" for unauthenticated (network-trusted) deployments. Use +// WithBaseURL to set the foreman host URL. +// +// Example: +// +// model := llm.Foreman("my-token", llm.WithBaseURL("https://foreman.local")).Model("qwen3:30b") +func Foreman(token string, opts ...ClientOption) *Client { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + return NewClient(ollamaProvider.New(token, cfg.baseURL)) +} + // OllamaCloud creates a client targeting Ollama Cloud (https://ollama.com). // The apiKey is required and is sent as `Authorization: Bearer `. Use // WithBaseURL to point at a private Ollama deployment that requires auth. diff --git a/v2/errors.go b/v2/errors.go index 52db82c..1fe5c2b 100644 --- a/v2/errors.go +++ b/v2/errors.go @@ -17,4 +17,16 @@ var ( // ErrNoStructuredOutput is returned when the model did not return a structured output tool call. ErrNoStructuredOutput = errors.New("model did not return structured output") + + // ErrAliasLoop is returned when alias resolution exceeds the maximum depth (10), + // indicating a cycle such as "a" → "b" → "a". + ErrAliasLoop = errors.New("alias resolution loop detected (depth > 10)") + + // ErrUnknownProvider is returned when a spec references a provider name that + // is not registered and has no corresponding LLM_X environment variable. + ErrUnknownProvider = errors.New("unknown provider") + + // ErrInvalidDSN is returned when a DSN string (from an LLM_X env var) cannot + // be parsed. Expected format: scheme://[token@]host[/path]. + ErrInvalidDSN = errors.New("invalid DSN") ) diff --git a/v2/parse.go b/v2/parse.go new file mode 100644 index 0000000..3b2b5d3 --- /dev/null +++ b/v2/parse.go @@ -0,0 +1,226 @@ +package llm + +import ( + "fmt" + "os" + "strings" +) + +// DSN represents a parsed Data Source Name for an LLM provider endpoint. +// Format: scheme://[token@]host[/path] +// +// Why: multi-instance providers (e.g., multiple foreman daemons) need a compact +// way to encode scheme, credentials, and host in a single env var. +// What: holds the three components after parsing. +// Test: ParseDSN("foreman://tok@host") → {Scheme:"foreman", Token:"tok", Host:"host"}. +type DSN struct { + Scheme string // provider type: "foreman", "ollama", etc. + Token string // API key / bearer token; empty = none + Host string // hostname[:port][/path], no scheme prefix +} + +// ParseDSN parses a raw DSN string into its components. +// Expected format: scheme://[token@]host[/path] +// +// Why: LLM_X env vars encode provider type, optional credentials, and host +// in a single string; this function decodes that. +// What: splits on "://", then optional "@" for token, remainder is host. +// Test: valid DSNs parse correctly, missing scheme or host returns ErrInvalidDSN. +func ParseDSN(raw string) (DSN, error) { + schemeEnd := strings.Index(raw, "://") + if schemeEnd < 0 { + return DSN{}, fmt.Errorf("%w: missing scheme://: %q", ErrInvalidDSN, raw) + } + scheme := raw[:schemeEnd] + rest := raw[schemeEnd+3:] + + var token, host string + if atIdx := strings.Index(rest, "@"); atIdx >= 0 { + token = rest[:atIdx] + host = rest[atIdx+1:] + } else { + host = rest + } + host = strings.TrimRight(host, "/") + if host == "" { + return DSN{}, fmt.Errorf("%w: missing host: %q", ErrInvalidDSN, raw) + } + return DSN{Scheme: scheme, Token: token, Host: host}, nil +} + +// splitReasoning strips a trailing ":low", ":medium", or ":high" reasoning +// suffix from a spec string. Only the last colon-delimited segment is +// considered, and only if it matches a known ReasoningLevel. This preserves +// Ollama-style tags like ":30b" or ":14b" which are not reasoning levels. +// +// Why: reasoning level is encoded as a suffix in the spec grammar, but +// colons also appear in Ollama model tags — this function disambiguates. +// What: returns the base string and the extracted level (empty if none). +// Test: "gpt-4o:high" → ("gpt-4o", high); "qwen3:30b" → ("qwen3:30b", ""). +func splitReasoning(s string) (string, ReasoningLevel) { + idx := strings.LastIndex(s, ":") + if idx < 0 || idx == len(s)-1 { + return s, "" + } + suffix := ReasoningLevel(s[idx+1:]) + switch suffix { + case ReasoningLow, ReasoningMedium, ReasoningHigh: + return s[:idx], suffix + } + return s, "" +} + +// Parse resolves a spec string to a ready-to-use *Model using the +// DefaultRegistry. See Registry.Parse for the full grammar and resolution +// order. +// +// Why: provides a one-call entry point for resolving model strings without +// requiring callers to interact with the Registry directly. +// What: delegates to DefaultRegistry.Parse. +// Test: Parse("openai/gpt-4o") returns a non-nil *Model. +func Parse(spec string) (*Model, error) { + return DefaultRegistry.Parse(spec) +} + +// Parse resolves a spec string to a ready-to-use *Model. +// +// Spec grammar: +// +// spec = alias | provider "/" model | envname "/" model +// alias = name (registered via RegisterAlias or matched by a Resolver) +// provider = registered-name (e.g., "openai", "foreman") +// envname = name (resolved via LLM_{UPPER(name)} env var containing a DSN) +// model = everything after the first "/" +// +// Any spec may carry a ":reasoning" suffix (":low", ":medium", ":high") after +// the last colon. Ollama-style tags like ":30b" are NOT consumed as reasoning. +// +// Resolution order: +// 1. Strip reasoning suffix +// 2. Check static aliases → recurse +// 3. Check dynamic resolvers → recurse +// 4. Split on first "/" → provider/model +// 5. Look up provider in registry +// 6. Look up LLM_{UPPER(left)} env var → parse DSN → create client +// 7. Return ErrUnknownProvider +// +// Why: consumers need a single function to go from a user-supplied string +// (CLI flag, config file, database row) to a ready-to-use Model. +// What: walks the resolution chain and returns the final *Model with reasoning applied. +// Test: see parse_test.go for comprehensive table-driven tests. +func (r *Registry) Parse(spec string) (*Model, error) { + m, level, err := r.parse(spec, 0) + if err != nil { + return nil, err + } + if level != "" { + m = m.WithReasoning(level) + } + return m, nil +} + +// parse is the internal recursive resolver. depth is bounded to prevent +// alias loops. +func (r *Registry) parse(spec string, depth int) (*Model, ReasoningLevel, error) { + if depth > 10 { + return nil, "", ErrAliasLoop + } + + // 1. Strip reasoning suffix. + base, userLevel := splitReasoning(spec) + + // 2. Check static aliases. + r.mu.RLock() + target, isAlias := r.aliases[base] + r.mu.RUnlock() + if isAlias { + m, aliasLevel, err := r.parse(target, depth+1) + if err != nil { + return nil, "", err + } + if userLevel != "" { + return m, userLevel, nil + } + return m, aliasLevel, nil + } + + // 3. Check dynamic resolvers (copy slice under lock to avoid holding + // the lock while calling back — resolvers may access the registry). + r.mu.RLock() + resolvers := make([]Resolver, len(r.resolvers)) + copy(resolvers, r.resolvers) + r.mu.RUnlock() + for _, res := range resolvers { + if resolved, defaultLevel, ok := res.Resolve(base); ok { + if resolved == "" { + return nil, "", fmt.Errorf("resolver returned empty spec for %q", base) + } + m, resolvedLevel, err := r.parse(resolved, depth+1) + if err != nil { + return nil, "", err + } + level := resolvedLevel + if defaultLevel != "" && level == "" { + level = defaultLevel + } + if userLevel != "" { + level = userLevel + } + return m, level, nil + } + } + + // 4. Split on first "/". + slashIdx := strings.Index(base, "/") + if slashIdx < 0 { + return nil, "", fmt.Errorf("%w: %q is not an alias and has no provider/ prefix", ErrUnknownProvider, spec) + } + left := base[:slashIdx] + right := base[slashIdx+1:] + + // 5. Look up in provider registry. + if info := r.ProviderByName(left); info != nil { + client := r.createClient(info) + return client.Model(right), userLevel, nil + } + + // 6. Check LLM_{UPPER(left)} env var. + envKey := "LLM_" + strings.ToUpper(strings.ReplaceAll(left, "-", "_")) + lookup := r.envLookup + if lookup == nil { + lookup = os.Getenv + } + envVal := lookup(envKey) + if envVal == "" { + return nil, "", fmt.Errorf("%w: %q (checked registry and %s env var)", ErrUnknownProvider, left, envKey) + } + dsn, err := ParseDSN(envVal) + if err != nil { + return nil, "", fmt.Errorf("parse %s: %w", envKey, err) + } + schemeInfo := r.ProviderByName(dsn.Scheme) + if schemeInfo == nil { + return nil, "", fmt.Errorf("%w: DSN scheme %q in %s is not a registered provider", ErrUnknownProvider, dsn.Scheme, envKey) + } + url := "https://" + dsn.Host + client := schemeInfo.New(dsn.Token, WithBaseURL(url)) + return client.Model(right), userLevel, nil +} + +// createClient builds a Client from a ProviderInfo, reading the API key from +// the environment when the provider specifies an EnvKey. +// +// Why: avoids duplicating env-var lookup logic across parse paths. +// What: reads the env var (if any), calls info.New with the key. +// Test: indirectly tested via Parse("openai/gpt-4o") with injected envLookup. +func (r *Registry) createClient(info *ProviderInfo) *Client { + apiKey := "" + if info.EnvKey != "" { + lookup := r.envLookup + if lookup == nil { + lookup = os.Getenv + } + apiKey = lookup(info.EnvKey) + } + return info.New(apiKey) +} diff --git a/v2/parse_test.go b/v2/parse_test.go new file mode 100644 index 0000000..73c15b4 --- /dev/null +++ b/v2/parse_test.go @@ -0,0 +1,639 @@ +package llm + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// recordingProvider captures the model name passed to Complete so tests can +// verify that Parse resolved to the correct model without network calls. +type recordingProvider struct { + lastModel string +} + +func (p *recordingProvider) Complete(_ context.Context, req provider.Request) (provider.Response, error) { + p.lastModel = req.Model + return provider.Response{Text: "ok"}, nil +} + +func (p *recordingProvider) Stream(_ context.Context, _ provider.Request, events chan<- provider.StreamEvent) error { + close(events) + return nil +} + +// testRegistry builds a Registry with two mock providers ("alpha" and "beta") +// and an injectable envLookup. No real API keys or network access required. +func testRegistry(envFn func(string) string) (*Registry, *recordingProvider, *recordingProvider) { + alpha := &recordingProvider{} + beta := &recordingProvider{} + + r := &Registry{ + providers: map[string]ProviderInfo{ + "alpha": { + Name: "alpha", + DisplayName: "Alpha", + EnvKey: "ALPHA_API_KEY", + Models: []string{"model-a"}, + New: func(apiKey string, opts ...ClientOption) *Client { + return NewClient(alpha) + }, + }, + "beta": { + Name: "beta", + DisplayName: "Beta", + EnvKey: "", + Models: []string{"model-b"}, + New: func(_ string, opts ...ClientOption) *Client { + return NewClient(beta) + }, + }, + }, + order: []string{"alpha", "beta"}, + aliases: make(map[string]string), + envLookup: envFn, + } + return r, alpha, beta +} + +// --------------------------------------------------------------------------- +// splitReasoning +// --------------------------------------------------------------------------- + +func TestSplitReasoning(t *testing.T) { + tests := []struct { + input string + wantBase string + wantLevel ReasoningLevel + }{ + {"gpt-4o", "gpt-4o", ""}, + {"gpt-4o:high", "gpt-4o", ReasoningHigh}, + {"gpt-4o:low", "gpt-4o", ReasoningLow}, + {"gpt-4o:medium", "gpt-4o", ReasoningMedium}, + {"qwen3:30b", "qwen3:30b", ""}, // Ollama tag, not a level + {"qwen3:30b:high", "qwen3:30b", ReasoningHigh}, // tag + reasoning + {"model:", "model:", ""}, // trailing colon, empty suffix + {"", "", ""}, // empty string + {"a:b:c:high", "a:b:c", ReasoningHigh}, // multiple colons + {"a:b:c:30b", "a:b:c:30b", ""}, // multiple colons, non-level + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + base, level := splitReasoning(tt.input) + if base != tt.wantBase { + t.Errorf("splitReasoning(%q) base = %q, want %q", tt.input, base, tt.wantBase) + } + if level != tt.wantLevel { + t.Errorf("splitReasoning(%q) level = %q, want %q", tt.input, level, tt.wantLevel) + } + }) + } +} + +// --------------------------------------------------------------------------- +// ParseDSN +// --------------------------------------------------------------------------- + +func TestParseDSN(t *testing.T) { + tests := []struct { + name string + raw string + want DSN + wantErr bool + }{ + { + name: "full DSN with token", + raw: "foreman://test-token@foreman-m5.orgrimmar.dudenhoeffer.casa", + want: DSN{Scheme: "foreman", Token: "test-token", Host: "foreman-m5.orgrimmar.dudenhoeffer.casa"}, + }, + { + name: "no token", + raw: "ollama://localhost:11434", + want: DSN{Scheme: "ollama", Token: "", Host: "localhost:11434"}, + }, + { + name: "trailing slash stripped", + raw: "foreman://tok@host.com/", + want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com"}, + }, + { + name: "with path", + raw: "foreman://tok@host.com/v1/api", + want: DSN{Scheme: "foreman", Token: "tok", Host: "host.com/v1/api"}, + }, + { + name: "missing scheme", + raw: "no-scheme-here", + wantErr: true, + }, + { + name: "missing host", + raw: "foreman://token@", + wantErr: true, + }, + { + name: "empty string", + raw: "", + wantErr: true, + }, + { + name: "scheme only", + raw: "foreman://", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseDSN(tt.raw) + if tt.wantErr { + if err == nil { + t.Fatalf("ParseDSN(%q) expected error, got nil", tt.raw) + } + if !errors.Is(err, ErrInvalidDSN) { + t.Errorf("ParseDSN(%q) error = %v, want ErrInvalidDSN", tt.raw, err) + } + return + } + if err != nil { + t.Fatalf("ParseDSN(%q) unexpected error: %v", tt.raw, err) + } + if got != tt.want { + t.Errorf("ParseDSN(%q) = %+v, want %+v", tt.raw, got, tt.want) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Registry.Parse +// --------------------------------------------------------------------------- + +func TestRegistryParse(t *testing.T) { + t.Run("basic provider/model", func(t *testing.T) { + r, alpha, _ := testRegistry(nil) + m, err := r.Parse("alpha/model-a") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m == nil { + t.Fatal("expected non-nil model") + } + // Verify the model name by exercising it. + _, _ = m.Complete(context.Background(), nil) + if alpha.lastModel != "model-a" { + t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") + } + }) + + t.Run("provider/model with reasoning suffix", func(t *testing.T) { + r, alpha, _ := testRegistry(nil) + m, err := r.Parse("alpha/model-a:high") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m.defaultReasoning != ReasoningHigh { + t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh) + } + _, _ = m.Complete(context.Background(), nil) + if alpha.lastModel != "model-a" { + t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") + } + }) + + t.Run("ollama-style tag preserved", func(t *testing.T) { + r, _, beta := testRegistry(nil) + m, err := r.Parse("beta/qwen3:30b") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m.defaultReasoning != "" { + t.Errorf("reasoning = %q, want empty (30b is not a level)", m.defaultReasoning) + } + _, _ = m.Complete(context.Background(), nil) + if beta.lastModel != "qwen3:30b" { + t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b") + } + }) + + t.Run("ollama-style tag with reasoning", func(t *testing.T) { + r, _, beta := testRegistry(nil) + m, err := r.Parse("beta/qwen3:30b:high") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m.defaultReasoning != ReasoningHigh { + t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh) + } + _, _ = m.Complete(context.Background(), nil) + if beta.lastModel != "qwen3:30b" { + t.Errorf("model = %q, want %q", beta.lastModel, "qwen3:30b") + } + }) + + t.Run("static alias", func(t *testing.T) { + r, alpha, _ := testRegistry(nil) + r.RegisterAlias("fast", "alpha/model-a") + m, err := r.Parse("fast") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, _ = m.Complete(context.Background(), nil) + if alpha.lastModel != "model-a" { + t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") + } + }) + + t.Run("alias with embedded reasoning", func(t *testing.T) { + r, alpha, _ := testRegistry(nil) + r.RegisterAlias("thinking", "alpha/model-a:high") + m, err := r.Parse("thinking") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m.defaultReasoning != ReasoningHigh { + t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh) + } + _, _ = m.Complete(context.Background(), nil) + if alpha.lastModel != "model-a" { + t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") + } + }) + + t.Run("user reasoning overrides alias default", func(t *testing.T) { + r, _, _ := testRegistry(nil) + r.RegisterAlias("thinking", "alpha/model-a:high") + m, err := r.Parse("thinking:low") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m.defaultReasoning != ReasoningLow { + t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow) + } + }) + + t.Run("dynamic resolver", func(t *testing.T) { + r, alpha, _ := testRegistry(nil) + r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) { + if name == "custom" { + return "alpha/resolved-model", "", true + } + return "", "", false + })) + m, err := r.Parse("custom") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, _ = m.Complete(context.Background(), nil) + if alpha.lastModel != "resolved-model" { + t.Errorf("model = %q, want %q", alpha.lastModel, "resolved-model") + } + }) + + t.Run("resolver with default reasoning", func(t *testing.T) { + r, _, _ := testRegistry(nil) + r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) { + if name == "smart" { + return "alpha/model-a", ReasoningHigh, true + } + return "", "", false + })) + m, err := r.Parse("smart") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m.defaultReasoning != ReasoningHigh { + t.Errorf("reasoning = %q, want %q", m.defaultReasoning, ReasoningHigh) + } + }) + + t.Run("resolver default reasoning overridden by user", func(t *testing.T) { + r, _, _ := testRegistry(nil) + r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) { + if name == "smart" { + return "alpha/model-a", ReasoningHigh, true + } + return "", "", false + })) + m, err := r.Parse("smart:low") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m.defaultReasoning != ReasoningLow { + t.Errorf("reasoning = %q, want %q (user override)", m.defaultReasoning, ReasoningLow) + } + }) + + t.Run("resolver returns empty spec", func(t *testing.T) { + r, _, _ := testRegistry(nil) + r.RegisterResolver(ResolverFunc(func(name string) (string, ReasoningLevel, bool) { + if name == "bad" { + return "", "", true + } + return "", "", false + })) + _, err := r.Parse("bad") + if err == nil { + t.Fatal("expected error for empty resolver spec") + } + }) + + t.Run("LLM_X env var with token", func(t *testing.T) { + envFn := func(key string) string { + if key == "LLM_M5" { + return "alpha://mytoken@myhost.com" + } + return "" + } + r, alpha, _ := testRegistry(envFn) + m, err := r.Parse("m5/qwen3:30b") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, _ = m.Complete(context.Background(), nil) + if alpha.lastModel != "qwen3:30b" { + t.Errorf("model = %q, want %q", alpha.lastModel, "qwen3:30b") + } + }) + + t.Run("LLM_X env var without token", func(t *testing.T) { + envFn := func(key string) string { + if key == "LLM_LOCAL" { + return "beta://localhost:11434" + } + return "" + } + r, _, beta := testRegistry(envFn) + m, err := r.Parse("local/llama3.2") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, _ = m.Complete(context.Background(), nil) + if beta.lastModel != "llama3.2" { + t.Errorf("model = %q, want %q", beta.lastModel, "llama3.2") + } + }) + + t.Run("LLM_X with hyphenated name", func(t *testing.T) { + envFn := func(key string) string { + if key == "LLM_MY_BOX" { + return "alpha://tok@host.com" + } + return "" + } + r, alpha, _ := testRegistry(envFn) + m, err := r.Parse("my-box/some-model") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, _ = m.Complete(context.Background(), nil) + if alpha.lastModel != "some-model" { + t.Errorf("model = %q, want %q", alpha.lastModel, "some-model") + } + }) + + t.Run("unknown provider no env var", func(t *testing.T) { + r, _, _ := testRegistry(func(string) string { return "" }) + _, err := r.Parse("unknown/model") + if err == nil { + t.Fatal("expected error for unknown provider") + } + if !errors.Is(err, ErrUnknownProvider) { + t.Errorf("error = %v, want ErrUnknownProvider", err) + } + }) + + t.Run("bare unknown name no slash", func(t *testing.T) { + r, _, _ := testRegistry(nil) + _, err := r.Parse("nonexistent") + if err == nil { + t.Fatal("expected error for bare unknown name") + } + if !errors.Is(err, ErrUnknownProvider) { + t.Errorf("error = %v, want ErrUnknownProvider", err) + } + }) + + t.Run("alias loop", func(t *testing.T) { + r, _, _ := testRegistry(nil) + r.RegisterAlias("a", "b") + r.RegisterAlias("b", "a") + _, err := r.Parse("a") + if err == nil { + t.Fatal("expected error for alias loop") + } + if !errors.Is(err, ErrAliasLoop) { + t.Errorf("error = %v, want ErrAliasLoop", err) + } + }) + + t.Run("deep alias chain within limit", func(t *testing.T) { + r, alpha, _ := testRegistry(nil) + // chain: a0 → a1 → a2 → ... → a9 → alpha/model-a (depth 10 is fine, >10 is not) + r.RegisterAlias("a9", "alpha/model-a") + for i := 8; i >= 0; i-- { + r.RegisterAlias( + "a"+string(rune('0'+i)), + "a"+string(rune('0'+i+1)), + ) + } + m, err := r.Parse("a0") + if err != nil { + t.Fatalf("unexpected error for deep but valid chain: %v", err) + } + _, _ = m.Complete(context.Background(), nil) + if alpha.lastModel != "model-a" { + t.Errorf("model = %q, want %q", alpha.lastModel, "model-a") + } + }) + + t.Run("RegisterProvider replaces existing", func(t *testing.T) { + replaced := &recordingProvider{} + r, _, _ := testRegistry(nil) + r.RegisterProvider(ProviderInfo{ + Name: "alpha", + DisplayName: "Alpha Replaced", + New: func(_ string, _ ...ClientOption) *Client { + return NewClient(replaced) + }, + }) + m, err := r.Parse("alpha/new-model") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, _ = m.Complete(context.Background(), nil) + if replaced.lastModel != "new-model" { + t.Errorf("model = %q, want %q", replaced.lastModel, "new-model") + } + // Verify order is preserved (alpha is still at index 0). + providers := r.Providers() + if providers[0].Name != "alpha" { + t.Errorf("first provider = %q, want %q (order preserved)", providers[0].Name, "alpha") + } + if providers[0].DisplayName != "Alpha Replaced" { + t.Errorf("display name = %q, want %q", providers[0].DisplayName, "Alpha Replaced") + } + }) + + t.Run("RegisterProvider adds new provider", func(t *testing.T) { + gamma := &recordingProvider{} + r, _, _ := testRegistry(nil) + r.RegisterProvider(ProviderInfo{ + Name: "gamma", + DisplayName: "Gamma", + New: func(_ string, _ ...ClientOption) *Client { + return NewClient(gamma) + }, + }) + m, err := r.Parse("gamma/g-model") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + _, _ = m.Complete(context.Background(), nil) + if gamma.lastModel != "g-model" { + t.Errorf("model = %q, want %q", gamma.lastModel, "g-model") + } + // New provider should be last. + providers := r.Providers() + last := providers[len(providers)-1] + if last.Name != "gamma" { + t.Errorf("last provider = %q, want %q", last.Name, "gamma") + } + }) + + t.Run("DSN with unknown scheme", func(t *testing.T) { + envFn := func(key string) string { + if key == "LLM_X" { + return "nope://tok@host.com" + } + return "" + } + r, _, _ := testRegistry(envFn) + _, err := r.Parse("x/model") + if err == nil { + t.Fatal("expected error for unknown DSN scheme") + } + if !errors.Is(err, ErrUnknownProvider) { + t.Errorf("error = %v, want ErrUnknownProvider", err) + } + }) + + t.Run("DSN with invalid format", func(t *testing.T) { + envFn := func(key string) string { + if key == "LLM_BAD" { + return "no-scheme" + } + return "" + } + r, _, _ := testRegistry(envFn) + _, err := r.Parse("bad/model") + if err == nil { + t.Fatal("expected error for invalid DSN") + } + if !errors.Is(err, ErrInvalidDSN) { + t.Errorf("error = %v, want ErrInvalidDSN", err) + } + }) + + t.Run("createClient reads env key", func(t *testing.T) { + var capturedKey string + r, _, _ := testRegistry(func(key string) string { + if key == "MY_KEY" { + return "secret-123" + } + return "" + }) + rec := &recordingProvider{} + r.RegisterProvider(ProviderInfo{ + Name: "keyed", + EnvKey: "MY_KEY", + New: func(apiKey string, opts ...ClientOption) *Client { + capturedKey = apiKey + return NewClient(rec) + }, + }) + _, err := r.Parse("keyed/m") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if capturedKey != "secret-123" { + t.Errorf("apiKey = %q, want %q", capturedKey, "secret-123") + } + }) +} + +// --------------------------------------------------------------------------- +// Backward compatibility: package-level Providers() and ProviderByName() +// --------------------------------------------------------------------------- + +func TestBackwardCompat(t *testing.T) { + t.Run("Providers includes foreman", func(t *testing.T) { + providers := Providers() + found := false + for _, p := range providers { + if p.Name == "foreman" { + found = true + break + } + } + if !found { + t.Error("Providers() does not include foreman") + } + }) + + t.Run("ProviderByName openai", func(t *testing.T) { + p := ProviderByName("openai") + if p == nil { + t.Fatal("ProviderByName(\"openai\") returned nil") + } + if p.DisplayName != "OpenAI" { + t.Errorf("DisplayName = %q, want %q", p.DisplayName, "OpenAI") + } + }) + + t.Run("ProviderByName foreman", func(t *testing.T) { + p := ProviderByName("foreman") + if p == nil { + t.Fatal("ProviderByName(\"foreman\") returned nil") + } + if p.DisplayName != "Foreman" { + t.Errorf("DisplayName = %q, want %q", p.DisplayName, "Foreman") + } + }) + + t.Run("ProviderByName unknown", func(t *testing.T) { + p := ProviderByName("does-not-exist") + if p != nil { + t.Errorf("expected nil for unknown provider, got %+v", p) + } + }) + + t.Run("Providers count", func(t *testing.T) { + providers := Providers() + // providerRegistry has 9 entries (openai, anthropic, google, deepseek, + // moonshot, xai, groq, ollama, ollama-cloud) + foreman = 10 + if len(providers) < 10 { + t.Errorf("len(Providers()) = %d, want >= 10", len(providers)) + } + }) +} + +// --------------------------------------------------------------------------- +// NewRegistry isolation +// --------------------------------------------------------------------------- + +func TestNewRegistryIsolation(t *testing.T) { + r1 := NewRegistry() + r2 := NewRegistry() + + r1.RegisterAlias("only-in-r1", "openai/gpt-4o") + + // r2 should not see r1's alias + r2.mu.RLock() + _, found := r2.aliases["only-in-r1"] + r2.mu.RUnlock() + if found { + t.Error("alias registered in r1 should not appear in r2") + } +} diff --git a/v2/registry.go b/v2/registry.go index 4c9f59e..a76dad0 100644 --- a/v2/registry.go +++ b/v2/registry.go @@ -1,6 +1,9 @@ package llm import ( + "os" + "sync" + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/deepseek" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/groq" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/moonshot" @@ -38,7 +41,7 @@ type ProviderInfo struct { // providerRegistry is the in-process list of known providers. Order is // intentional: the three original providers first, then OpenAI-compatible -// additions in the order they were added. +// additions in the order they were added. This slice seeds NewRegistry(). var providerRegistry = []ProviderInfo{ { Name: "openai", @@ -151,24 +154,189 @@ var providerRegistry = []ProviderInfo{ }, New: OllamaCloud, }, + { + Name: "foreman", + DisplayName: "Foreman", + EnvKey: "", // no single env key; discovered via LLM_* DSNs + DefaultURL: "", // always requires a URL + Models: []string{}, + New: Foreman, + }, } +// --------------------------------------------------------------------------- +// Resolver — dynamic model-spec resolution +// --------------------------------------------------------------------------- + +// Resolver resolves an alias or short name to a full spec string. Consumers +// register resolvers for dynamic lookups (e.g., database-backed tier aliases). +type Resolver interface { + // Resolve returns the resolved spec and an optional default reasoning + // level. ok is false when the resolver does not handle this name. + Resolve(name string) (spec string, defaultReasoning ReasoningLevel, ok bool) +} + +// ResolverFunc adapts a plain function to the Resolver interface. +type ResolverFunc func(name string) (string, ReasoningLevel, bool) + +// Resolve implements Resolver by calling the underlying function. +func (f ResolverFunc) Resolve(name string) (string, ReasoningLevel, bool) { + return f(name) +} + +// --------------------------------------------------------------------------- +// Registry — extensible provider/alias/resolver store +// --------------------------------------------------------------------------- + +// Registry holds providers, static aliases, and dynamic resolvers. Use +// NewRegistry to create one pre-populated with the built-in providers, or +// use the package-level DefaultRegistry. +type Registry struct { + mu sync.RWMutex + providers map[string]ProviderInfo + order []string // insertion order for Providers() + aliases map[string]string + resolvers []Resolver + envLookup func(string) string // defaults to os.Getenv +} + +// NewRegistry creates a Registry pre-populated with all built-in providers +// (the same set returned by the providerRegistry package variable). +// +// Why: provides a fresh, isolated registry for testing or multi-tenant +// scenarios while reusing the canonical provider list. +// What: copies every entry from providerRegistry into a new Registry. +// Test: call NewRegistry(), verify Providers() length matches providerRegistry +// and ProviderByName("openai") is non-nil. +func NewRegistry() *Registry { + r := &Registry{ + providers: make(map[string]ProviderInfo, len(providerRegistry)), + order: make([]string, 0, len(providerRegistry)), + aliases: make(map[string]string), + envLookup: os.Getenv, + } + for _, info := range providerRegistry { + r.providers[info.Name] = info + r.order = append(r.order, info.Name) + } + return r +} + +// DefaultRegistry is the package-level registry used by the convenience +// functions Parse, Providers, ProviderByName, RegisterProvider, RegisterAlias, +// and RegisterResolver. Initialized in init() with all built-in providers. +var DefaultRegistry *Registry + +func init() { + DefaultRegistry = NewRegistry() +} + +// RegisterProvider adds or replaces a provider in the registry. When +// replacing, the provider keeps its original position in the ordered list. +// +// Why: allows consumers to override built-in factories (e.g., wrapping with +// middleware) or add entirely new providers at runtime. +// What: upserts info by Name into the provider map and order slice. +// Test: register a custom "openai" factory, verify ProviderByName returns it. +func (r *Registry) RegisterProvider(info ProviderInfo) { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.providers[info.Name]; !exists { + r.order = append(r.order, info.Name) + } + r.providers[info.Name] = info +} + +// RegisterAlias maps a short name to a full spec string. The spec is resolved +// recursively by Parse, so an alias can point to another alias or to a +// "provider/model" string. +// +// Why: lets consumers define convenient shortcuts like "fast" → "openai/gpt-4o-mini". +// What: stores name→spec in the alias map. +// Test: register "fast" → "openai/gpt-4o-mini", parse "fast", verify model. +func (r *Registry) RegisterAlias(name, spec string) { + r.mu.Lock() + defer r.mu.Unlock() + r.aliases[name] = spec +} + +// RegisterResolver appends a dynamic resolver. Resolvers are checked in +// registration order after static aliases. A resolver may return a spec +// string that is itself an alias or "provider/model" — it will be recursed. +// +// Why: supports dynamic alias sources (databases, remote config) without +// requiring static registration of every possible name. +// What: appends res to the resolver list. +// Test: register a ResolverFunc, parse a name it handles, verify resolution. +func (r *Registry) RegisterResolver(res Resolver) { + r.mu.Lock() + defer r.mu.Unlock() + r.resolvers = append(r.resolvers, res) +} + +// ProviderByName returns the registered ProviderInfo with the given name, or +// nil if no such provider is registered. Name matching is exact. +// +// Why: callers need to look up provider metadata by name for factory calls, +// discovery, and DSN scheme resolution. +// What: returns a copy of the ProviderInfo or nil. +// Test: verify ProviderByName("openai") is non-nil, ProviderByName("nope") is nil. +func (r *Registry) ProviderByName(name string) *ProviderInfo { + r.mu.RLock() + defer r.mu.RUnlock() + + if info, ok := r.providers[name]; ok { + return &info + } + return nil +} + +// Providers returns a copy of all registered providers in insertion order. +// +// Why: CLI pickers and admin tools need the full list for display. +// What: returns a freshly allocated slice of ProviderInfo copies. +// Test: verify length matches expected count after registration. +func (r *Registry) Providers() []ProviderInfo { + r.mu.RLock() + defer r.mu.RUnlock() + + out := make([]ProviderInfo, 0, len(r.order)) + for _, name := range r.order { + if info, ok := r.providers[name]; ok { + out = append(out, info) + } + } + return out +} + +// --------------------------------------------------------------------------- +// Package-level convenience functions — delegate to DefaultRegistry +// --------------------------------------------------------------------------- + // Providers returns a copy of the registered provider list so callers cannot // mutate library state. func Providers() []ProviderInfo { - out := make([]ProviderInfo, len(providerRegistry)) - copy(out, providerRegistry) - return out + return DefaultRegistry.Providers() } // ProviderByName returns the registered ProviderInfo with the given name, or // nil if no such provider is registered. Name matching is exact. func ProviderByName(name string) *ProviderInfo { - for i := range providerRegistry { - if providerRegistry[i].Name == name { - p := providerRegistry[i] - return &p - } - } - return nil + return DefaultRegistry.ProviderByName(name) +} + +// RegisterProvider adds or replaces a provider in the DefaultRegistry. +func RegisterProvider(info ProviderInfo) { + DefaultRegistry.RegisterProvider(info) +} + +// RegisterAlias maps a short name to a full spec in the DefaultRegistry. +func RegisterAlias(name, spec string) { + DefaultRegistry.RegisterAlias(name, spec) +} + +// RegisterResolver appends a dynamic resolver to the DefaultRegistry. +func RegisterResolver(res Resolver) { + DefaultRegistry.RegisterResolver(res) }