diff --git a/CLAUDE.md b/CLAUDE.md index 6c39425..5ed4ecf 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -27,9 +27,15 @@ majordomo Registry, Parse, env-DSN loading, chain executor, re-exports llm/ canonical contract: Message/Part/Request/Response/Option, Tool/Toolbox, Capabilities, Stream, Model, Provider, errors health/ clock-injected health tracker (bench/backoff) - media/ image normalization to target capabilities (Phase 3) + media/ image normalization to target capabilities (sniff real + format, downscale, transcode, byte ladder; ErrUnsupported + for what can't fit) — chains normalize PER TARGET provider/fake/ scriptable in-memory provider for hermetic tests - provider/{openai,anthropic,ollama,google}/ (Phases 3-4) + provider/openai/ Chat Completions client (+ all OpenAI-compat targets) + provider/anthropic/ Messages API client (+ Anthropic-compat targets) + provider/ollama/ one native /api/chat client serving the ollama, + ollama-cloud, and foreman built-ins via presets + provider/google/ Gemini on the official genai SDK (Phase 4) agent/ Agent run loop (Phase 5) skill/ Skill interface + composition (Phase 6) examples/ one runnable example per hard requirement (Phase 7-8) diff --git a/README.md b/README.md index 8356ed8..cccdbbf 100644 --- a/README.md +++ b/README.md @@ -100,12 +100,25 @@ Chains are health-tracked per target: | Provider | Spec name | Key env var | Default endpoint | |----------|-----------|-------------|------------------| -| OpenAI (+compatible) | `openai` | `OPENAI_API_KEY` | api.openai.com *(pending)* | -| Anthropic (+compatible) | `anthropic` | `ANTHROPIC_API_KEY` | api.anthropic.com *(pending)* | +| OpenAI (+compatible) | `openai` | `OPENAI_API_KEY` | https://api.openai.com/v1 | +| Anthropic (+compatible) | `anthropic` | `ANTHROPIC_API_KEY` | https://api.anthropic.com | | Google (Gemini) | `google` | `GOOGLE_API_KEY` / `GEMINI_API_KEY` | Gen AI API *(pending)* | -| Ollama Cloud | `ollama-cloud` | `OLLAMA_API_KEY` | https://ollama.com *(pending)* | -| Ollama (local) | `ollama` | — | `OLLAMA_HOST` or http://localhost:11434 *(pending)* | -| foreman | `foreman` | — (token via DSN) | requires DSN/base URL *(pending)* | +| Ollama Cloud | `ollama-cloud` | `OLLAMA_API_KEY` | https://ollama.com | +| Ollama (local) | `ollama` | — | `OLLAMA_HOST` or http://localhost:11434 | +| foreman | `foreman` | — (token via DSN) | requires an LLM_* DSN or `ollama.Foreman(url, token)` | + +OpenAI-compatible / Anthropic-compatible endpoints: construct the provider +with a name and base URL and register it — + +```go +reg.RegisterProvider(openai.New( + openai.WithName("groq"), + openai.WithBaseURL("https://api.groq.com/openai/v1"), + openai.WithAPIKey(key), + // openai.WithLegacyMaxTokens(), // for servers that only honor max_tokens +)) +// now "groq/llama-3.3-70b" works in Parse, chains, and aliases +``` ### `LLM_*` env-DSN provider definitions @@ -139,11 +152,15 @@ Implement the two-method `Provider` interface and register it: reg.RegisterProvider(myProvider) // now "myprovider/model-x" parses, chains, aliases ``` -## Multimodality *(pending — Phase 3)* +## Multimodality -Attach images without knowing the target's limits; majordomo normalizes -(downscale, re-encode, count/size limits) against the resolved target's -declared capabilities and rejects clearly what cannot fit. +Attach images without knowing the target's limits. Before each attempt the +request is normalized against the **actual serving target's** declared +capabilities: the real format is sniffed from the bytes, oversize images +are downscaled (aspect preserved), disallowed formats are re-encoded, and +byte budgets are enforced by a quality ladder. What cannot be made to fit +is rejected with a clear `ErrUnsupported` error — and in a chain, the +request simply advances to the next (e.g. vision-capable) element. ```go resp, err := m.Generate(ctx, majordomo.Request{ @@ -154,7 +171,7 @@ resp, err := m.Generate(ctx, majordomo.Request{ }) ``` -## Tool calls *(canonical API ready; provider wiring pending — Phase 3)* +## Tool calls ```go weather := majordomo.Tool{ @@ -171,14 +188,20 @@ resp, _ := m.Generate(ctx, req, majordomo.WithTools(weather)) // resp.ToolCalls → execute → append ToolResultsMessage → continue ``` -## Structured output *(canonical API ready; provider wiring pending — Phase 3)* +Each provider maps this one shape to its native function-calling format +(OpenAI tools/tool_calls, Anthropic tool_use/tool_result, Ollama tools with +object arguments). Tool-call ids are synthesized when a backend omits them; +streaming buffers tool-call arguments until they parse. + +## Structured output ```go resp, _ := m.Generate(ctx, req, majordomo.WithSchema(schemaJSON, "answer")) ``` -A generic `Generate[T]` helper (schema from your struct, unmarshal into it) -lands with the agent phase. +Maps to OpenAI `response_format: json_schema`, Anthropic +`output_config.format`, and Ollama `format`. A generic `Generate[T]` helper +(schema from your struct, unmarshal into it) lands with the agent phase. ## Agents & skills *(pending — Phases 5–6)* @@ -189,17 +212,25 @@ skills = reusable instruction+tool bundles attachable to any agent. | Provider | Resolve/Parse | Chat | Streaming | Tools | Structured | Images | Env DSN | |----------------------|:---:|:---:|:---:|:---:|:---:|:---:|:---:| -| OpenAI (+compatible) | ✅ | pending | pending | pending | pending | pending | ✅ | -| Anthropic (+compat) | ✅ | pending | pending | pending | pending | pending | ✅ | +| OpenAI (+compatible) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Anthropic (+compat) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | Google (Gemini) | ✅ | pending | pending | pending | pending | pending | ✅ | -| Ollama Cloud | ✅ | pending | pending | pending | pending | pending | ✅ | -| Ollama (local) | ✅ | pending | pending | pending | pending | pending | ✅ | -| foreman | ✅ | pending | pending | pending | pending | pending | ✅ | +| Ollama Cloud | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Ollama (local) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| foreman | ✅ | ✅ | ✅¹ | ✅ | ✅ | ✅ | ✅ | | fake (testing) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | — | +¹ foreman's daemon currently buffers sync chat responses (no token-by-token +streaming); majordomo's stream API works against it and delivers the +response as a single delta plus final event. + +Notes: Ollama has no native tool_choice — `"none"` drops the tools; +`"required"`/named choices are best-effort ignored there. + Cross-cutting: Parse grammar ✅ · aliases/tiers ✅ · failover chains ✅ · -health tracking/backoff ✅ · LLM_* env DSNs ✅ · media pipeline pending · -agent loop pending · skills pending · `Generate[T]` pending. +health tracking/backoff ✅ · LLM_* env DSNs ✅ · media pipeline ✅ +(per-target normalization in chains) · agent loop pending · skills pending +· `Generate[T]` pending. ## Development diff --git a/builtin.go b/builtin.go index 0b83908..0109097 100644 --- a/builtin.go +++ b/builtin.go @@ -3,8 +3,12 @@ package majordomo import ( "context" "fmt" + "net/http" "gitea.stevedudenhoeffer.com/steve/majordomo/llm" + "gitea.stevedudenhoeffer.com/steve/majordomo/provider/anthropic" + "gitea.stevedudenhoeffer.com/steve/majordomo/provider/ollama" + "gitea.stevedudenhoeffer.com/steve/majordomo/provider/openai" ) // Built-in provider names. Real client implementations land per-phase @@ -20,23 +24,78 @@ const ( ) // registerBuiltins installs the built-in providers and env-DSN scheme -// factories into a fresh registry. -func registerBuiltins(r *Registry) { - stub := func(kind string) SchemeFactory { - return func(name string, dsn DSN) (llm.Provider, error) { - return &stubProvider{name: name, kind: kind, baseURL: dsn.BaseURL(), token: dsn.Token}, nil +// factories into a fresh registry. httpClient, when non-nil, is used by +// every provider and factory the registry itself constructs. +func registerBuiltins(r *Registry, httpClient *http.Client) { + ollamaOpts := func(extra ...ollama.Option) []ollama.Option { + if httpClient != nil { + extra = append(extra, ollama.WithHTTPClient(httpClient)) } + return extra } - for _, kind := range []string{ - ProviderOpenAI, ProviderAnthropic, ProviderGoogle, - ProviderOllama, ProviderOllamaCloud, ProviderForeman, - } { - r.providers[kind] = &stubProvider{name: kind, kind: kind} - r.schemes[kind] = stub(kind) + // Native-Ollama family: three names over one client with presets. + r.providers[ProviderOllama] = ollama.Local(ollamaOpts()...) + r.providers[ProviderOllamaCloud] = ollama.Cloud(ollamaOpts()...) + // foreman has no default URL; the no-DSN registration resolves but + // errors on use with a clear message (use an LLM_* DSN or + // ollama.Foreman(...) + RegisterProvider). + r.providers[ProviderForeman] = ollama.New(ollamaOpts(ollama.WithName(ProviderForeman))...) + + ollamaScheme := func(name string, dsn DSN) (llm.Provider, error) { + return ollama.New(ollamaOpts( + ollama.WithName(name), + ollama.WithBaseURL(dsn.BaseURL()), + ollama.WithToken(dsn.Token), + )...), nil } + r.schemes[ProviderOllama] = ollamaScheme + r.schemes[ProviderOllamaCloud] = ollamaScheme + r.schemes[ProviderForeman] = ollamaScheme + + // OpenAI and OpenAI-compatible endpoints. + openaiOpts := func(extra ...openai.Option) []openai.Option { + if httpClient != nil { + extra = append(extra, openai.WithHTTPClient(httpClient)) + } + return extra + } + r.providers[ProviderOpenAI] = openai.New(openaiOpts()...) + r.schemes[ProviderOpenAI] = func(name string, dsn DSN) (llm.Provider, error) { + return openai.New(openaiOpts( + openai.WithName(name), + openai.WithBaseURL(dsn.BaseURL()), + openai.WithAPIKey(dsn.Token), + )...), nil + } + + // Anthropic and Anthropic-compatible endpoints. + anthropicOpts := func(extra ...anthropic.Option) []anthropic.Option { + if httpClient != nil { + extra = append(extra, anthropic.WithHTTPClient(httpClient)) + } + return extra + } + r.providers[ProviderAnthropic] = anthropic.New(anthropicOpts()...) + r.schemes[ProviderAnthropic] = func(name string, dsn DSN) (llm.Provider, error) { + return anthropic.New(anthropicOpts( + anthropic.WithName(name), + anthropic.WithBaseURL(dsn.BaseURL()), + anthropic.WithAPIKey(dsn.Token), + )...), nil + } + + // Google lands in its own phase; stub until then. + r.providers[ProviderGoogle] = &stubProvider{name: ProviderGoogle, kind: ProviderGoogle} + r.schemes[ProviderGoogle] = stubScheme(ProviderGoogle) // "gemini" is an alternate scheme for the Google provider. - r.schemes["gemini"] = stub(ProviderGoogle) + r.schemes["gemini"] = stubScheme(ProviderGoogle) +} + +func stubScheme(kind string) SchemeFactory { + return func(name string, dsn DSN) (llm.Provider, error) { + return &stubProvider{name: name, kind: kind, baseURL: dsn.BaseURL(), token: dsn.Token}, nil + } } // stubProvider stands in for a provider implementation that lands in a diff --git a/chain.go b/chain.go index 42eb2e0..26928de 100644 --- a/chain.go +++ b/chain.go @@ -7,6 +7,7 @@ import ( "gitea.stevedudenhoeffer.com/steve/majordomo/health" "gitea.stevedudenhoeffer.com/steve/majordomo/llm" + "gitea.stevedudenhoeffer.com/steve/majordomo/media" ) // ErrChainExhausted reports that every element of a failover chain failed @@ -65,8 +66,8 @@ func (c *chain) Capabilities() llm.Capabilities { // Generate tries each target per the chain semantics above. func (c *chain) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) { req = req.Apply(opts...) - return chainDo(ctx, c, func(ctx context.Context, t chainTarget) (*llm.Response, error) { - return t.model.Generate(ctx, req) + return chainDo(ctx, c, req, func(ctx context.Context, t chainTarget, nreq llm.Request) (*llm.Response, error) { + return t.model.Generate(ctx, nreq) }) } @@ -76,14 +77,17 @@ func (c *chain) Generate(ctx context.Context, req llm.Request, opts ...llm.Optio // (replaying half-delivered output would duplicate content). func (c *chain) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) { req = req.Apply(opts...) - return chainDo(ctx, c, func(ctx context.Context, t chainTarget) (llm.Stream, error) { - return t.model.Stream(ctx, req) + return chainDo(ctx, c, req, func(ctx context.Context, t chainTarget, nreq llm.Request) (llm.Stream, error) { + return t.model.Stream(ctx, nreq) }) } // chainDo runs the head-to-tail failover algorithm around an attempt -// function, generic over the result type (response vs stream). -func chainDo[T any](ctx context.Context, c *chain, attempt func(context.Context, chainTarget) (T, error)) (T, error) { +// function, generic over the result type (response vs stream). Before each +// target is tried, the request's media is normalized against THAT target's +// capabilities (ADR-0008/0009) — a request that cannot be made to fit one +// target advances to the next without a health penalty. +func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func(context.Context, chainTarget, llm.Request) (T, error)) (T, error) { var zero T var failures []error @@ -94,12 +98,20 @@ func chainDo[T any](ctx context.Context, c *chain, attempt func(context.Context, continue } + nreq, err := media.Normalize(req, t.model.Capabilities()) + if err != nil { + // Always ErrUnsupported-wrapped: this target cannot take the + // request by declaration. Advance, no health penalty. + failures = append(failures, fmt.Errorf("%s: %w", t.key, err)) + continue + } + retries := c.cfg.retries() for attemptN := 0; ; attemptN++ { if err := ctx.Err(); err != nil { return zero, err } - result, err := attempt(ctx, t) + result, err := attempt(ctx, t, nreq) if err == nil { c.tracker.ReportSuccess(t.key) return result, nil diff --git a/docs/adr/0009-multimodal-strategy.md b/docs/adr/0009-multimodal-strategy.md new file mode 100644 index 0000000..ca32db4 --- /dev/null +++ b/docs/adr/0009-multimodal-strategy.md @@ -0,0 +1,57 @@ +# ADR-0009: Multimodal strategy — normalize per target, enforce at the provider + +**Status:** Accepted — 2026-06-10 + +## Context + +Every provider (and some models) imposes different image rules: max +dimensions/bytes, allowed MIME types, max images per request. A caller must +be able to attach an image without knowing the eventual target — especially +with failover chains, where the serving target isn't known until runtime. + +## Decision + +Two cooperating layers: + +1. **`media.Normalize(req, caps)`** — the transformation point. The chain + executor calls it **per target, per attempt**, against the actual + target's capabilities, before the provider sees the request: + - The real format is **sniffed from magic bytes** and wins over the + declared MIME (callers lie; jpeg/png/gif/webp recognized). + - Already-fitting images pass through untouched (fast path: zero copies). + - Oversize dimensions downscale (aspect-preserving) with a hand-rolled + box-filter — stdlib has no scaler and `x/image` stays out per + ADR-0007; box-average quality is ample for vision input. + - Disallowed MIME re-encodes: original format if allowed, else JPEG + (q85), else PNG, else the first allowed encodable type. + - Byte budgets enforce via a quality ladder (jpeg 85→65→45→30) then + dimension halving; ~6 attempts before giving up. + - WebP cannot be decoded by stdlib: it passes through when it fits and + is allowed; any needed transform is a clear error. + - Everything that cannot be made to fit errors **wrapping + `llm.ErrUnsupported`** — never silently dropped. +2. **Provider backstop** — each provider cheaply enforces its effective + capabilities at request time (image count/MIME/bytes, plus + tools/structured/streaming support flags) and rejects with + `ErrUnsupported`. This keeps providers honest for expert callers who + build models directly without the registry. + +Chain semantics: a normalization failure for one target **advances** to the +next element with no health penalty (the target isn't sick, it's just +incapable) — so `fp/text-only,fp/vision` serves an image request from the +vision element automatically. + +Canonical image content stays **bytes + MIME** (ADR-0002); no URL fetching. + +## Consequences + +- A 100×50 PNG sent at a 32px-cap target arrives as a 32×16 PNG; the same + request served by an 8000px target arrives untouched. +- Conditional provider rules (e.g. Anthropic's 2000px cap above 20 images) + are approximated by the flat declared caps — conservative and simple. + +## Alternatives considered + +- Normalize once against chain-intersection caps: over-restricts every + request for the sake of rarely-used fallbacks. Rejected (ADR-0008). +- `x/image/draw` scalers: a dependency for one function. Rejected. diff --git a/docs/adr/0010-tools-structured-output-mapping.md b/docs/adr/0010-tools-structured-output-mapping.md new file mode 100644 index 0000000..94d4751 --- /dev/null +++ b/docs/adr/0010-tools-structured-output-mapping.md @@ -0,0 +1,49 @@ +# ADR-0010: Tools and structured output — one canonical shape, native mappings + +**Status:** Accepted — 2026-06-10 + +## Context + +Tool calling and schema-constrained output exist on every target but with +different wire shapes (verified against current docs, June 2026; shapes +recorded in each provider's package doc). The canonical API must hide all +of it. + +## Decision + +Canonical: `Tool{Name, Description, Parameters (JSON Schema), Handler}`; +`Response.ToolCalls[]{ID, Name, Arguments json.RawMessage}`; results return +as `ToolResultsMessage(ToolResult{ID, Name, Content, IsError})`. Structured +output via `WithSchema(schema, name)`. Per-provider mapping: + +| Concern | OpenAI(+compat) | Anthropic(+compat) | Ollama/foreman | Google (Phase 4) | +|---|---|---|---|---| +| Tool def | `tools[].function{name,description,parameters}` | `tools[]{name,description,input_schema}` | `tools[].function` | `FunctionDeclaration.ParametersJsonSchema` | +| Call args | JSON **string** → RawMessage | `tool_use.input` object | `arguments` **object** | `FunctionCall.Args` map | +| Results | one `role:tool` msg per result (`tool_call_id`) | one **user** msg, `tool_result` blocks (`is_error` native) | `role:tool` + `tool_name` | `FunctionResponse` parts | +| IsError | `"ERROR: "` content prefix | `is_error: true` | `"ERROR: "` prefix | response payload field | +| Forced choice | `tool_choice` string / named object | `{"type":"any"/"tool"/"none"}` | none → drop tools; others best-effort ignored | `FunctionCallingConfig` modes | +| Structured | `response_format json_schema` (no strict flag) | `output_config.format json_schema` (GA mechanism) | `format: ` | `ResponseJsonSchema` + JSON MIME | + +Cross-cutting decisions: + +- **Missing call ids are synthesized** (`call_`) — Ollama and some + compat servers omit them; the agent loop needs ids to match results. +- **Streaming buffers tool-call arguments to completion** (ADR-0002): + OpenAI fragments accumulate by index, Anthropic `input_json_delta` + fragments accumulate per block; consumers only ever see parseable calls. +- **No strict-mode flag is sent** to OpenAI: strict mode imposes schema + constraints (every property required, additionalProperties:false) that + caller-supplied schemas may not satisfy. The `Generate[T]` reflector + (Phase 5) emits strict-compatible schemas anyway. +- `SchemaName` feeds providers that need a name (OpenAI; default + "response"); others ignore it. +- Tool handlers never panic the loop: `Toolbox.Execute`/`ExecuteTool` + recover panics and JSON-encode results (ADR to agent loop, Phase 5). + +## Consequences + +- One test matrix per provider asserts the exact wire JSON both directions; + drift is caught by httptest fixtures, not in production. +- Ollama's missing tool_choice means "required" cannot be enforced there — + documented in the README matrix rather than emulated. diff --git a/docs/adr/README.md b/docs/adr/README.md index 6346769..d7bf294 100644 --- a/docs/adr/README.md +++ b/docs/adr/README.md @@ -12,3 +12,5 @@ One decision per file, append-only; supersede rather than rewrite. | [0006](0006-health-and-backoff.md) | Model health tracking and backoff | Accepted | | [0007](0007-dependency-policy.md) | Dependency policy — stdlib-first, hand-rolled REST clients | Accepted | | [0008](0008-chain-semantics.md) | Failover-chain execution semantics | Accepted | +| [0009](0009-multimodal-strategy.md) | Multimodal strategy — normalize per target, enforce at provider | Accepted | +| [0010](0010-tools-structured-output-mapping.md) | Tools and structured output — canonical shape, native mappings | Accepted | diff --git a/env_test.go b/env_test.go index dda67bd..fcbede9 100644 --- a/env_test.go +++ b/env_test.go @@ -1,10 +1,17 @@ package majordomo import ( + "context" + "encoding/json" "errors" + "io" + "net/http" + "net/http/httptest" "slices" "strings" "testing" + + "gitea.stevedudenhoeffer.com/steve/majordomo/provider/ollama" ) func TestParseDSN(t *testing.T) { @@ -71,19 +78,16 @@ func TestLoadEnvForeman(t *testing.T) { if !ok { t.Fatalf("provider %q not registered", name) } - sp, ok := p.(*stubProvider) + op, ok := p.(*ollama.Provider) if !ok { - t.Fatalf("provider %q is %T, want *stubProvider (phase 1)", name, p) + t.Fatalf("provider %q is %T, want *ollama.Provider (foreman scheme)", name, p) } - if sp.kind != ProviderForeman { - t.Errorf("provider %q kind = %q, want foreman", name, sp.kind) + if op.Name() != name { + t.Errorf("provider name = %q, want %q", op.Name(), name) } wantURL := "https://foreman-" + name + ".orgrimmar.dudenhoeffer.casa" - if sp.baseURL != wantURL { - t.Errorf("provider %q baseURL = %q, want %q", name, sp.baseURL, wantURL) - } - if sp.token != "test-token-change-me" { - t.Errorf("provider %q token = %q, want the DSN userinfo", name, sp.token) + if op.BaseURL() != wantURL { + t.Errorf("provider %q baseURL = %q, want %q", name, op.BaseURL(), wantURL) } } @@ -193,3 +197,54 @@ func TestNewLoadsProcessEnv(t *testing.T) { t.Error("New() should eagerly load LLM_ENVTEST from the process environment") } } + +// TestEnvForemanChatRoundTrip is the required end-to-end case: an LLM_* +// foreman DSN resolves through Parse and serves a real chat over the wire +// (TLS test server, since env DSNs always dial https), with the DSN token +// arriving as the bearer credential. +func TestEnvForemanChatRoundTrip(t *testing.T) { + var gotAuth, gotPath string + var gotModel string + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + gotPath = r.URL.Path + var body struct { + Model string `json:"model"` + } + _ = json.NewDecoder(r.Body).Decode(&body) + gotModel = body.Model + _, _ = io.WriteString(w, `{"message":{"role":"assistant","content":"hi from foreman"},"done":true,"done_reason":"stop","prompt_eval_count":2,"eval_count":3}`) + })) + defer ts.Close() + + host := strings.TrimPrefix(ts.URL, "https://") + r := New( + WithoutEnvProviders(), + WithEnvLookup(func(string) string { return "" }), + WithHTTPClient(ts.Client()), + ) + if err := r.LoadEnv(map[string]string{"LLM_FM": "foreman://round-trip-token@" + host}); err != nil { + t.Fatalf("LoadEnv: %v", err) + } + + m, err := r.Parse("fm/qwen3:30b") + if err != nil { + t.Fatalf("Parse: %v", err) + } + resp, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hello")}}) + if err != nil { + t.Fatalf("Generate: %v", err) + } + if resp.Text() != "hi from foreman" { + t.Errorf("text = %q", resp.Text()) + } + if resp.Model != "fm/qwen3:30b" { + t.Errorf("resp.Model = %q, want fm/qwen3:30b", resp.Model) + } + if gotAuth != "Bearer round-trip-token" { + t.Errorf("auth = %q, want the DSN token as bearer", gotAuth) + } + if gotPath != "/api/chat" || gotModel != "qwen3:30b" { + t.Errorf("path=%q model=%q", gotPath, gotModel) + } +} diff --git a/failover_test.go b/failover_test.go index dc0a976..625c21c 100644 --- a/failover_test.go +++ b/failover_test.go @@ -1,8 +1,12 @@ package majordomo import ( + "bytes" "context" "errors" + "image" + "image/color" + "image/png" "slices" "strings" "sync" @@ -325,6 +329,95 @@ func TestSingleTargetGetsChainSemantics(t *testing.T) { } } +// pngImage encodes a width×height PNG for media tests. +func pngImage(t *testing.T, width, height int) []byte { + t.Helper() + img := image.NewRGBA(image.Rect(0, 0, width, height)) + for y := range height { + for x := range width { + img.Set(x, y, color.RGBA{R: uint8(x), G: uint8(y), B: 128, A: 255}) + } + } + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + t.Fatalf("encode png: %v", err) + } + return buf.Bytes() +} + +// TestChainNormalizesMediaPerTarget: the request's image is downscaled to +// the capabilities of the target that actually serves it. +func TestChainNormalizesMediaPerTarget(t *testing.T) { + r := newTestRegistry(t) + fp := fake.New("fp", + fake.WithModelCapabilities("small-vision", llm.Capabilities{ + MaxImagesPerReq: 2, + MaxImageDimension: 32, + AllowedImageMIME: []string{"image/png"}, + }), + ) + r.RegisterProvider(fp) + + m, _ := r.Parse("fp/small-vision") + _, err := m.Generate(context.Background(), Request{Messages: []Message{ + UserParts(Text("describe"), Image("image/png", pngImage(t, 100, 50))), + }}) + if err != nil { + t.Fatalf("Generate: %v", err) + } + + calls := fp.Calls() + if len(calls) != 1 { + t.Fatalf("calls = %d", len(calls)) + } + var img llm.ImagePart + for _, part := range calls[0].Request.Messages[0].Parts { + if ip, ok := part.(llm.ImagePart); ok { + img = ip + } + } + if img.Data == nil { + t.Fatal("no image reached the provider") + } + cfg, err := png.DecodeConfig(bytes.NewReader(img.Data)) + if err != nil { + t.Fatalf("decode delivered image: %v", err) + } + if cfg.Width != 32 || cfg.Height != 16 { + t.Errorf("delivered image = %dx%d, want 32x16 (downscaled to target cap)", cfg.Width, cfg.Height) + } +} + +// TestChainAdvancesPastImagelessTarget: a text-only head can't take an +// image request; the chain advances to a vision-capable element with no +// health penalty. +func TestChainAdvancesPastImagelessTarget(t *testing.T) { + r := newTestRegistry(t) + fp := fake.New("fp", + fake.WithModelCapabilities("text-only", llm.Capabilities{SupportsTools: true}), + fake.WithModelCapabilities("vision", llm.Capabilities{MaxImagesPerReq: 4}), + ) + r.RegisterProvider(fp) + fp.Enqueue("vision", fake.Reply("a tasteful png")) + + m, _ := r.Parse("fp/text-only,fp/vision") + resp, err := m.Generate(context.Background(), Request{Messages: []Message{ + UserParts(Text("what is this?"), Image("image/png", pngImage(t, 8, 8))), + }}) + if err != nil { + t.Fatalf("Generate: %v", err) + } + if resp.Text() != "a tasteful png" { + t.Errorf("text = %q", resp.Text()) + } + if got := fp.CallCount("text-only"); got != 0 { + t.Errorf("text-only target saw %d calls, want 0 (normalization rejects pre-send)", got) + } + if !r.Health().Available("fp/text-only") { + t.Error("media rejection must not penalize health") + } +} + // TestHTTP529ClassifiedTransient: Anthropic's "overloaded" status fails // over like any other transient error. func TestHTTP529FailsOver(t *testing.T) { diff --git a/majordomo.go b/majordomo.go index 9a77942..d8c4f4d 100644 --- a/majordomo.go +++ b/majordomo.go @@ -80,26 +80,26 @@ var ErrModelNotFound = llm.ErrModelNotFound var ErrUnsupported = llm.ErrUnsupported // Re-exported content and message constructors. -func Text(s string) Part { return llm.Text(s) } -func Image(mime string, data []byte) Part { return llm.Image(mime, data) } -func SystemText(s string) Message { return llm.SystemText(s) } -func UserText(s string) Message { return llm.UserText(s) } -func UserParts(parts ...Part) Message { return llm.UserParts(parts...) } -func AssistantText(s string) Message { return llm.AssistantText(s) } +func Text(s string) Part { return llm.Text(s) } +func Image(mime string, data []byte) Part { return llm.Image(mime, data) } +func SystemText(s string) Message { return llm.SystemText(s) } +func UserText(s string) Message { return llm.UserText(s) } +func UserParts(parts ...Part) Message { return llm.UserParts(parts...) } +func AssistantText(s string) Message { return llm.AssistantText(s) } func ToolResultsMessage(results ...ToolResult) Message { return llm.ToolResultsMessage(results...) } -func NewToolbox(name string, tools ...Tool) *Toolbox { return llm.NewToolbox(name, tools...) } +func NewToolbox(name string, tools ...Tool) *Toolbox { return llm.NewToolbox(name, tools...) } // Re-exported request options. -func WithSystem(s string) Option { return llm.WithSystem(s) } -func WithTools(tools ...Tool) Option { return llm.WithTools(tools...) } -func WithToolbox(b *Toolbox) Option { return llm.WithToolbox(b) } -func WithToolChoice(choice string) Option { return llm.WithToolChoice(choice) } +func WithSystem(s string) Option { return llm.WithSystem(s) } +func WithTools(tools ...Tool) Option { return llm.WithTools(tools...) } +func WithToolbox(b *Toolbox) Option { return llm.WithToolbox(b) } +func WithToolChoice(choice string) Option { return llm.WithToolChoice(choice) } func WithSchema(schema json.RawMessage, name string) Option { return llm.WithSchema(schema, name) } -func WithTemperature(t float64) Option { return llm.WithTemperature(t) } -func WithTopP(p float64) Option { return llm.WithTopP(p) } -func WithMaxTokens(n int) Option { return llm.WithMaxTokens(n) } -func WithStopSequences(stops ...string) Option { return llm.WithStopSequences(stops...) } -func WithReasoningEffort(level string) Option { return llm.WithReasoningEffort(level) } +func WithTemperature(t float64) Option { return llm.WithTemperature(t) } +func WithTopP(p float64) Option { return llm.WithTopP(p) } +func WithMaxTokens(n int) Option { return llm.WithMaxTokens(n) } +func WithStopSequences(stops ...string) Option { return llm.WithStopSequences(stops...) } +func WithReasoningEffort(level string) Option { return llm.WithReasoningEffort(level) } // WithModelCapabilities re-exports llm.WithCapabilities for Provider.Model // calls made through this package. diff --git a/media/media.go b/media/media.go new file mode 100644 index 0000000..e6e0bc8 --- /dev/null +++ b/media/media.go @@ -0,0 +1,293 @@ +// Package media fits request images to a target's declared capabilities. +// +// Normalize sniffs each image's real format from magic bytes (declared MIME +// types lie), corrects the part's MIME, and passes through anything that +// already satisfies the target's llm.Capabilities. Images that do not fit +// are decoded, downscaled (never upscaled), and re-encoded into an allowed +// format and byte budget. Anything that cannot honestly be made to fit — +// undecodable formats, impossible byte budgets, too many images, images for +// a text-only target — fails with an error wrapping llm.ErrUnsupported so a +// failover chain can advance to a more capable target without a health +// penalty. +// +// Why a separate package: every provider would otherwise duplicate the same +// decode/scale/encode pipeline. Providers keep only a cheap capability +// enforcement backstop; this package performs the actual transformation, +// once, against whichever target a chain is currently attempting. +package media + +import ( + "bytes" + "fmt" + "image" + "image/gif" + "image/jpeg" + "image/png" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// Normalize returns a copy of req whose images fit caps, transforming +// (downscale, re-encode) where needed. The input request is never mutated. +// +// Fast paths: a request with no image parts, or whose images already satisfy +// caps, is returned unchanged with all underlying slices shared. When any +// image transforms, the Messages slice and the Parts slices of affected +// messages are copied (copy-on-write); untouched parts stay shared. +// +// Images that cannot be made to fit return an error wrapping +// llm.ErrUnsupported. +func Normalize(req llm.Request, caps llm.Capabilities) (llm.Request, error) { + total := 0 + for i := range req.Messages { + for _, p := range req.Messages[i].Parts { + if _, ok := p.(llm.ImagePart); ok { + total++ + } + } + } + if total == 0 { + return req, nil + } + if !caps.SupportsImages() { + return llm.Request{}, fmt.Errorf("media: %w: target does not accept image input (request carries %d image(s))", llm.ErrUnsupported, total) + } + // Why error instead of dropping the overflow: silently removing an image + // changes the question the caller asked; the honest move is to refuse and + // let a chain try a roomier target. + if total > caps.MaxImagesPerReq { + return llm.Request{}, fmt.Errorf("media: %w: request carries %d images, target allows at most %d per request", llm.ErrUnsupported, total, caps.MaxImagesPerReq) + } + + out := req + copiedMessages := false + for mi := range req.Messages { + copiedParts := false + for pi, part := range req.Messages[mi].Parts { + ip, ok := part.(llm.ImagePart) + if !ok { + continue + } + norm, changed, err := normalizeImage(ip, caps) + if err != nil { + return llm.Request{}, fmt.Errorf("media: message %d, part %d: %w", mi, pi, err) + } + if !changed { + continue + } + if !copiedMessages { + out.Messages = make([]llm.Message, len(req.Messages)) + copy(out.Messages, req.Messages) + copiedMessages = true + } + if !copiedParts { + parts := make([]llm.Part, len(req.Messages[mi].Parts)) + copy(parts, req.Messages[mi].Parts) + out.Messages[mi].Parts = parts + copiedParts = true + } + out.Messages[mi].Parts[pi] = norm + } + } + return out, nil +} + +// Info reports an image part's sniffed format ("jpeg", "png", "gif", or +// "webp") and pixel dimensions. It is a cheap metadata read — the pixels are +// never decoded. webp is recognized by signature but not decodable with the +// standard library, so it reports format "webp" with zero dimensions and a +// nil error. +func Info(p llm.ImagePart) (format string, width, height int, err error) { + format = sniff(p.Data) + switch format { + case "": + return "", 0, 0, fmt.Errorf("media: image bytes match no known format (jpeg, png, gif, webp)") + case "webp": + return "webp", 0, 0, nil + } + cfg, _, err := image.DecodeConfig(bytes.NewReader(p.Data)) + if err != nil { + return format, 0, 0, fmt.Errorf("media: decode %s config: %w", format, err) + } + return format, cfg.Width, cfg.Height, nil +} + +// normalizeImage fits one image part to caps. It returns the (possibly +// transformed) part and whether it differs from the input. A corrected MIME +// with untouched bytes still counts as changed so Normalize copy-on-writes +// the containing slices. +func normalizeImage(p llm.ImagePart, caps llm.Capabilities) (llm.ImagePart, bool, error) { + // Why sniff instead of trusting p.MIME: callers routinely mislabel image + // bytes, and providers reject mismatches; the bytes are the truth. + format := sniff(p.Data) + if format == "" { + return p, false, fmt.Errorf("%w: image bytes (declared %q) match no known format (jpeg, png, gif, webp)", llm.ErrUnsupported, p.MIME) + } + realMIME := "image/" + format + changed := false + if p.MIME != realMIME { + p.MIME = realMIME + changed = true + } + + mimeOK := caps.MIMEAllowed(realMIME) + fitsBytes := caps.MaxImageBytes == 0 || len(p.Data) <= caps.MaxImageBytes + fitsDims := true + if caps.MaxImageDimension > 0 && format != "webp" { + // Cheap header-only dimension read; a failure forces the transform + // path, which surfaces the real decode error. + cfg, _, err := image.DecodeConfig(bytes.NewReader(p.Data)) + if err != nil { + fitsDims = false + } else { + fitsDims = cfg.Width <= caps.MaxImageDimension && cfg.Height <= caps.MaxImageDimension + } + } + // Why webp skips the dimension check: the stdlib cannot read webp + // headers, so dimensions are unverifiable; if MIME and bytes fit we pass + // it through rather than reject a possibly-fine image. + if mimeOK && fitsBytes && fitsDims { + return p, changed, nil + } + + // Transformation required from here on, which needs a real decode. + if format == "webp" { + return p, false, fmt.Errorf("%w: image is webp (%d bytes), which the Go standard library cannot decode; provide jpeg, png, or gif instead", llm.ErrUnsupported, len(p.Data)) + } + img, _, err := image.Decode(bytes.NewReader(p.Data)) + if err != nil { + return p, false, fmt.Errorf("%w: cannot decode %s image for transformation: %v", llm.ErrUnsupported, format, err) + } + + if caps.MaxImageDimension > 0 { + b := img.Bounds() + if b.Dx() > caps.MaxImageDimension || b.Dy() > caps.MaxImageDimension { + nw, nh := fitDims(b.Dx(), b.Dy(), caps.MaxImageDimension) + img = downscale(img, nw, nh) + } + } + + target, err := targetMIME(realMIME, caps) + if err != nil { + return p, false, err + } + data, err := encodeFit(img, target, caps.MaxImageBytes) + if err != nil { + return p, false, err + } + return llm.ImagePart{MIME: target, Data: data}, true, nil +} + +// sniff identifies an image format from its magic bytes, returning "jpeg", +// "png", "gif", "webp", or "" when nothing matches. +func sniff(data []byte) string { + switch { + case len(data) >= 3 && data[0] == 0xFF && data[1] == 0xD8 && data[2] == 0xFF: + return "jpeg" + case len(data) >= 4 && data[0] == 0x89 && data[1] == 'P' && data[2] == 'N' && data[3] == 'G': + return "png" + case len(data) >= 4 && string(data[:4]) == "GIF8": + return "gif" + case len(data) >= 12 && string(data[:4]) == "RIFF" && string(data[8:12]) == "WEBP": + return "webp" + default: + return "" + } +} + +// encodableMIME reports whether the stdlib can encode the given image type. +func encodableMIME(mime string) bool { + switch mime { + case "image/jpeg", "image/png", "image/gif": + return true + } + return false +} + +// targetMIME picks the re-encode format: the original when allowed, else +// jpeg, else png, else the first allowed encodable type (gif). When nothing +// allowed is encodable (e.g. only webp), it errors with llm.ErrUnsupported. +func targetMIME(original string, caps llm.Capabilities) (string, error) { + if encodableMIME(original) && caps.MIMEAllowed(original) { + return original, nil + } + // Why jpeg before png: vision inputs are photographs more often than + // screenshots, and jpeg's quality knob is the only size lever we have + // for the byte-budget loop. + for _, m := range []string{"image/jpeg", "image/png"} { + if caps.MIMEAllowed(m) { + return m, nil + } + } + // An empty allow-list permits everything and was caught above, so the + // list is non-empty here: take its first encodable entry. + for _, m := range caps.AllowedImageMIME { + if encodableMIME(m) { + return m, nil + } + } + return "", fmt.Errorf("%w: none of the allowed image types %v can be encoded with the Go standard library", llm.ErrUnsupported, caps.AllowedImageMIME) +} + +// encodeFit encodes img as mime within maxBytes (0 = no limit), trading +// jpeg quality first and then resolution for size. The ladder is fixed +// (jpeg: q85, q65, q45, q30, then half and quarter dimensions at q65; +// png/gif: full, half, quarter dimensions) — at most six attempts, since an +// image that survives a 16x pixel reduction over budget will not be saved +// by further fiddling. +func encodeFit(img image.Image, mime string, maxBytes int) ([]byte, error) { + type attempt struct { + div int // divide both dimensions by this + quality int // jpeg quality; ignored for png/gif + } + var ladder []attempt + if mime == "image/jpeg" { + ladder = []attempt{{1, 85}, {1, 65}, {1, 45}, {1, 30}, {2, 65}, {4, 65}} + } else { + ladder = []attempt{{1, 0}, {2, 0}, {4, 0}} + } + + scaled := map[int]image.Image{1: img} + smallest := -1 + for _, a := range ladder { + cur, ok := scaled[a.div] + if !ok { + b := img.Bounds() + nw, nh := max(b.Dx()/a.div, 1), max(b.Dy()/a.div, 1) + cur = downscale(img, nw, nh) + scaled[a.div] = cur + } + data, err := encodeImage(cur, mime, a.quality) + if err != nil { + return nil, fmt.Errorf("encode %s: %w", mime, err) + } + if maxBytes == 0 || len(data) <= maxBytes { + return data, nil + } + if smallest == -1 || len(data) < smallest { + smallest = len(data) + } + } + return nil, fmt.Errorf("%w: image cannot be reduced to the %d-byte limit; smallest achievable %s encoding is %d bytes", llm.ErrUnsupported, maxBytes, mime, smallest) +} + +// encodeImage encodes img into the given MIME type. quality applies to jpeg +// only. +func encodeImage(img image.Image, mime string, quality int) ([]byte, error) { + var buf bytes.Buffer + var err error + switch mime { + case "image/jpeg": + err = jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}) + case "image/png": + err = png.Encode(&buf, img) + case "image/gif": + err = gif.Encode(&buf, img, nil) + default: + return nil, fmt.Errorf("no stdlib encoder for %q", mime) + } + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} diff --git a/media/media_test.go b/media/media_test.go new file mode 100644 index 0000000..2e3e92c --- /dev/null +++ b/media/media_test.go @@ -0,0 +1,513 @@ +package media + +import ( + "bytes" + "errors" + "image" + "image/color" + "image/gif" + "image/jpeg" + "image/png" + "math/rand/v2" + "strings" + "testing" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// --- test image builders ------------------------------------------------- + +// gradient builds a smooth w x h RGBA image (compresses well). +func gradient(w, h int) *image.RGBA { + img := image.NewRGBA(image.Rect(0, 0, w, h)) + for y := 0; y < h; y++ { + for x := 0; x < w; x++ { + img.SetRGBA(x, y, color.RGBA{ + R: uint8(x * 255 / max(w-1, 1)), + G: uint8(y * 255 / max(h-1, 1)), + B: 128, + A: 255, + }) + } + } + return img +} + +// noisy builds a w x h image of deterministic random pixels (compresses +// terribly — ideal for exercising the byte-budget ladder). +func noisy(w, h int) *image.RGBA { + rng := rand.New(rand.NewPCG(1, 2)) + img := image.NewRGBA(image.Rect(0, 0, w, h)) + for i := range img.Pix { + img.Pix[i] = uint8(rng.UintN(256)) + } + return img +} + +func encPNG(t *testing.T, img image.Image) []byte { + t.Helper() + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + t.Fatalf("png encode: %v", err) + } + return buf.Bytes() +} + +func encJPEG(t *testing.T, img image.Image) []byte { + t.Helper() + var buf bytes.Buffer + if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 90}); err != nil { + t.Fatalf("jpeg encode: %v", err) + } + return buf.Bytes() +} + +func encGIF(t *testing.T, img image.Image) []byte { + t.Helper() + var buf bytes.Buffer + if err := gif.Encode(&buf, img, nil); err != nil { + t.Fatalf("gif encode: %v", err) + } + return buf.Bytes() +} + +// webpBlob is a minimal byte sequence carrying the RIFF/WEBP signature. +// The stdlib cannot decode webp, so sniffing is all that ever reads it. +func webpBlob() []byte { + b := []byte("RIFF") + b = append(b, 0x1a, 0x00, 0x00, 0x00) + b = append(b, "WEBPVP8 "...) + b = append(b, make([]byte, 18)...) + return b +} + +func imgReq(parts ...llm.Part) llm.Request { + return llm.Request{Messages: []llm.Message{llm.UserParts(parts...)}} +} + +// firstImage returns the first image part in the request. +func firstImage(t *testing.T, req llm.Request) llm.ImagePart { + t.Helper() + for _, m := range req.Messages { + for _, p := range m.Parts { + if ip, ok := p.(llm.ImagePart); ok { + return ip + } + } + } + t.Fatal("no image part in request") + return llm.ImagePart{} +} + +// --- fast paths ----------------------------------------------------------- + +func TestNormalizeFastPathNoImages(t *testing.T) { + req := llm.Request{Messages: []llm.Message{llm.UserText("hello")}} + got, err := Normalize(req, llm.Capabilities{}) // even a no-image target + if err != nil { + t.Fatalf("Normalize: %v", err) + } + if &got.Messages[0] != &req.Messages[0] { + t.Error("messages slice was copied on the no-image fast path") + } + if &got.Messages[0].Parts[0] != &req.Messages[0].Parts[0] { + t.Error("parts slice was copied on the no-image fast path") + } +} + +func TestNormalizeFastPathFittingImages(t *testing.T) { + data := encPNG(t, gradient(20, 10)) + req := imgReq(llm.Text("look:"), llm.Image("image/png", data)) + caps := llm.Capabilities{ + MaxImagesPerReq: 4, + MaxImageBytes: len(data) + 100, + MaxImageDimension: 64, + AllowedImageMIME: []string{"image/png"}, + } + got, err := Normalize(req, caps) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + if &got.Messages[0] != &req.Messages[0] { + t.Error("messages slice was copied although every image already fits") + } + if &got.Messages[0].Parts[1] != &req.Messages[0].Parts[1] { + t.Error("parts slice was copied although every image already fits") + } +} + +// --- rejection paths ------------------------------------------------------ + +func TestNormalizeImagesUnsupported(t *testing.T) { + req := imgReq(llm.Image("image/png", encPNG(t, gradient(4, 4)))) + _, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 0}) + if !errors.Is(err, llm.ErrUnsupported) { + t.Fatalf("err = %v, want ErrUnsupported", err) + } + if !strings.Contains(err.Error(), "does not accept image input") { + t.Errorf("err message %q lacks explanation", err) + } +} + +func TestNormalizeTooManyImages(t *testing.T) { + img := llm.Image("image/png", encPNG(t, gradient(4, 4))) + req := llm.Request{Messages: []llm.Message{ + llm.UserParts(img, img), + llm.UserParts(img), + }} + _, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 2}) + if !errors.Is(err, llm.ErrUnsupported) { + t.Fatalf("err = %v, want ErrUnsupported", err) + } + if !strings.Contains(err.Error(), "3 images") || !strings.Contains(err.Error(), "at most 2") { + t.Errorf("err message %q lacks the counts", err) + } +} + +func TestNormalizeGarbageBytes(t *testing.T) { + req := imgReq(llm.Image("image/png", []byte("certainly not an image"))) + _, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 1}) + if !errors.Is(err, llm.ErrUnsupported) { + t.Fatalf("err = %v, want ErrUnsupported", err) + } + if !strings.Contains(err.Error(), "no known format") { + t.Errorf("err message %q lacks a clear explanation", err) + } +} + +// --- MIME sniffing & correction -------------------------------------------- + +func TestNormalizeMIMECorrection(t *testing.T) { + data := encPNG(t, gradient(8, 8)) + req := imgReq(llm.Image("image/jpeg", data)) // caller lies: bytes are png + got, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 1}) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + ip := firstImage(t, got) + if ip.MIME != "image/png" { + t.Errorf("MIME = %q, want sniff-corrected %q", ip.MIME, "image/png") + } + if !bytes.Equal(ip.Data, data) { + t.Error("image bytes changed although only the MIME needed correcting") + } + if orig := firstImage(t, req); orig.MIME != "image/jpeg" { + t.Errorf("input request mutated: MIME now %q", orig.MIME) + } +} + +func TestNormalizeCopyOnWrite(t *testing.T) { + data := encPNG(t, gradient(8, 8)) + req := llm.Request{Messages: []llm.Message{ + llm.UserText("untouched message"), + llm.UserParts(llm.Text("untouched part"), llm.Image("image/jpeg", data)), + }} + got, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 1}) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + if &got.Messages[0] == &req.Messages[0] { + t.Error("messages slice shared although a part changed (mutation hazard)") + } + if &got.Messages[0].Parts[0] != &req.Messages[0].Parts[0] { + t.Error("parts slice of the untouched message was copied") + } + if &got.Messages[1].Parts[0] == &req.Messages[1].Parts[0] { + t.Error("parts slice of the changed message is still shared (mutation hazard)") + } +} + +// --- dimension capping ------------------------------------------------------ + +func TestNormalizeDownscale(t *testing.T) { + req := imgReq(llm.Image("image/png", encPNG(t, gradient(200, 100)))) + caps := llm.Capabilities{MaxImagesPerReq: 1, MaxImageDimension: 50} + got, err := Normalize(req, caps) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + format, w, h, err := Info(firstImage(t, got)) + if err != nil { + t.Fatalf("Info: %v", err) + } + if format != "png" { + t.Errorf("format = %q, want original format %q preserved", format, "png") + } + if w != 50 || h != 25 { + t.Errorf("dimensions = %dx%d, want 50x25 (aspect preserved)", w, h) + } +} + +func TestNormalizeDownscalePortrait(t *testing.T) { + req := imgReq(llm.Image("image/png", encPNG(t, gradient(100, 200)))) + caps := llm.Capabilities{MaxImagesPerReq: 1, MaxImageDimension: 50} + got, err := Normalize(req, caps) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + _, w, h, err := Info(firstImage(t, got)) + if err != nil { + t.Fatalf("Info: %v", err) + } + if w != 25 || h != 50 { + t.Errorf("dimensions = %dx%d, want 25x50 (aspect preserved)", w, h) + } +} + +// --- transcoding ------------------------------------------------------------- + +func TestNormalizeTranscode(t *testing.T) { + tests := []struct { + name string + data []byte + mime string + allowed []string + want string + }{ + {"png to jpeg", encPNG(t, gradient(16, 16)), "image/png", []string{"image/jpeg"}, "jpeg"}, + {"jpeg to png", encJPEG(t, gradient(16, 16)), "image/jpeg", []string{"image/png"}, "png"}, + {"gif to png", encGIF(t, gradient(16, 16)), "image/gif", []string{"image/png"}, "png"}, + {"png to gif fallback", encPNG(t, gradient(16, 16)), "image/png", []string{"image/gif"}, "gif"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := imgReq(llm.Image(tt.mime, tt.data)) + caps := llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: tt.allowed} + got, err := Normalize(req, caps) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + ip := firstImage(t, got) + if ip.MIME != "image/"+tt.want { + t.Errorf("MIME = %q, want %q", ip.MIME, "image/"+tt.want) + } + format, w, h, err := Info(ip) + if err != nil { + t.Fatalf("Info: %v", err) + } + if format != tt.want { + t.Errorf("sniffed format = %q, want %q", format, tt.want) + } + if w != 16 || h != 16 { + t.Errorf("dimensions = %dx%d, want 16x16 (no resize needed)", w, h) + } + }) + } +} + +func TestNormalizeNoEncodableAllowedType(t *testing.T) { + // png needs transcoding but the only allowed type is webp, which the + // stdlib cannot encode. + req := imgReq(llm.Image("image/png", encPNG(t, gradient(8, 8)))) + caps := llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/webp"}} + _, err := Normalize(req, caps) + if !errors.Is(err, llm.ErrUnsupported) { + t.Fatalf("err = %v, want ErrUnsupported", err) + } + if !strings.Contains(err.Error(), "image/webp") { + t.Errorf("err message %q does not name the unencodable allowed types", err) + } +} + +// --- byte budget --------------------------------------------------------------- + +func TestNormalizeByteBudgetFits(t *testing.T) { + // Random noise defeats q85 jpeg at full size; the ladder must walk down + // quality and then resolution until the encoding fits. + req := imgReq(llm.Image("image/png", encPNG(t, noisy(256, 256)))) + caps := llm.Capabilities{ + MaxImagesPerReq: 1, + AllowedImageMIME: []string{"image/jpeg"}, + MaxImageBytes: 8 * 1024, + } + got, err := Normalize(req, caps) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + ip := firstImage(t, got) + if len(ip.Data) > caps.MaxImageBytes { + t.Errorf("len(Data) = %d, exceeds budget %d", len(ip.Data), caps.MaxImageBytes) + } + if format, _, _, err := Info(ip); err != nil || format != "jpeg" { + t.Errorf("Info = %q, %v; want jpeg, nil", format, err) + } +} + +func TestNormalizeByteBudgetImpossible(t *testing.T) { + req := imgReq(llm.Image("image/png", encPNG(t, noisy(256, 256)))) + caps := llm.Capabilities{ + MaxImagesPerReq: 1, + AllowedImageMIME: []string{"image/jpeg"}, + MaxImageBytes: 10, // no image fits in 10 bytes + } + _, err := Normalize(req, caps) + if !errors.Is(err, llm.ErrUnsupported) { + t.Fatalf("err = %v, want ErrUnsupported", err) + } + if !strings.Contains(err.Error(), "10-byte limit") { + t.Errorf("err message %q lacks the budget", err) + } + if !strings.Contains(err.Error(), "smallest achievable") { + t.Errorf("err message %q lacks the achieved size", err) + } +} + +// --- webp --------------------------------------------------------------------- + +func TestNormalizeWebPPassThrough(t *testing.T) { + data := webpBlob() + req := imgReq(llm.Image("image/webp", data)) + caps := llm.Capabilities{ + MaxImagesPerReq: 1, + MaxImageBytes: 1024, + MaxImageDimension: 50, // unverifiable for webp; must not force a transform + AllowedImageMIME: []string{"image/webp"}, + } + got, err := Normalize(req, caps) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + if &got.Messages[0] != &req.Messages[0] { + t.Error("request copied although the webp image passes through") + } + if ip := firstImage(t, got); !bytes.Equal(ip.Data, data) { + t.Error("webp bytes changed on pass-through") + } +} + +func TestNormalizeWebPNeedsTransform(t *testing.T) { + req := imgReq(llm.Image("image/webp", webpBlob())) + caps := llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/jpeg"}} + _, err := Normalize(req, caps) + if !errors.Is(err, llm.ErrUnsupported) { + t.Fatalf("err = %v, want ErrUnsupported", err) + } + if !strings.Contains(err.Error(), "webp") { + t.Errorf("err message %q does not name the format", err) + } + if !strings.Contains(err.Error(), "jpeg, png, or gif") { + t.Errorf("err message %q does not say what to provide instead", err) + } +} + +// --- input immutability ---------------------------------------------------------- + +func TestNormalizeInputNotMutated(t *testing.T) { + data := encPNG(t, gradient(200, 100)) + snapshot := bytes.Clone(data) + req := llm.Request{ + System: "sys", + Messages: []llm.Message{ + llm.UserParts(llm.Text("scale me"), llm.Image("image/jpeg", data)), + }, + } + caps := llm.Capabilities{ + MaxImagesPerReq: 1, + MaxImageDimension: 50, + AllowedImageMIME: []string{"image/jpeg"}, + } + got, err := Normalize(req, caps) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + orig := firstImage(t, req) + if orig.MIME != "image/jpeg" { + t.Errorf("input MIME mutated to %q", orig.MIME) + } + if !bytes.Equal(orig.Data, snapshot) { + t.Error("input image bytes mutated") + } + if txt := req.Messages[0].Parts[0].(llm.TextPart); txt.Text != "scale me" { + t.Errorf("input text part mutated: %q", txt.Text) + } + if ip := firstImage(t, got); bytes.Equal(ip.Data, snapshot) { + t.Error("output image was expected to transform but is byte-identical") + } +} + +// --- alpha handling ---------------------------------------------------------------- + +func TestNormalizeAlphaPNGToJPEG(t *testing.T) { + img := image.NewRGBA(image.Rect(0, 0, 32, 32)) + for y := 0; y < 32; y++ { + for x := 0; x < 32; x++ { + img.SetRGBA(x, y, color.RGBA{R: 200, G: 60, B: 30, A: uint8(x * 8)}) + } + } + req := imgReq(llm.Image("image/png", encPNG(t, img))) + caps := llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/jpeg"}} + got, err := Normalize(req, caps) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + ip := firstImage(t, got) + decoded, err := jpeg.Decode(bytes.NewReader(ip.Data)) + if err != nil { + t.Fatalf("decoding transcoded jpeg: %v", err) + } + if b := decoded.Bounds(); b.Dx() != 32 || b.Dy() != 32 { + t.Errorf("decoded dimensions = %dx%d, want 32x32", b.Dx(), b.Dy()) + } +} + +// --- Info ---------------------------------------------------------------------------- + +func TestInfo(t *testing.T) { + pngData := encPNG(t, gradient(10, 7)) + jpegData := encJPEG(t, gradient(5, 9)) + gifData := encGIF(t, gradient(6, 4)) + + tests := []struct { + name string + part llm.ImagePart + format string + w, h int + wantErr bool + }{ + {"png", llm.ImagePart{MIME: "image/png", Data: pngData}, "png", 10, 7, false}, + {"jpeg", llm.ImagePart{MIME: "image/jpeg", Data: jpegData}, "jpeg", 5, 9, false}, + {"gif", llm.ImagePart{MIME: "image/gif", Data: gifData}, "gif", 6, 4, false}, + {"mislabeled png", llm.ImagePart{MIME: "image/jpeg", Data: pngData}, "png", 10, 7, false}, + {"webp", llm.ImagePart{MIME: "image/webp", Data: webpBlob()}, "webp", 0, 0, false}, + {"garbage", llm.ImagePart{MIME: "image/png", Data: []byte("nope")}, "", 0, 0, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + format, w, h, err := Info(tt.part) + if tt.wantErr { + if err == nil { + t.Fatal("Info: expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("Info: %v", err) + } + if format != tt.format || w != tt.w || h != tt.h { + t.Errorf("Info = %q, %d, %d; want %q, %d, %d", format, w, h, tt.format, tt.w, tt.h) + } + }) + } +} + +// --- byte-cap pass-through interplay ---------------------------------------------------- + +func TestNormalizeOversizeBytesTriggersTransform(t *testing.T) { + // A fitting MIME and dimension but an over-budget payload must re-encode, + // not pass through. + data := encPNG(t, noisy(64, 64)) + req := imgReq(llm.Image("image/png", data)) + caps := llm.Capabilities{ + MaxImagesPerReq: 1, + MaxImageBytes: len(data) / 2, + AllowedImageMIME: []string{"image/png", "image/jpeg"}, + } + got, err := Normalize(req, caps) + if err != nil { + t.Fatalf("Normalize: %v", err) + } + ip := firstImage(t, got) + if len(ip.Data) > caps.MaxImageBytes { + t.Errorf("len(Data) = %d, exceeds budget %d", len(ip.Data), caps.MaxImageBytes) + } +} diff --git a/media/scale.go b/media/scale.go new file mode 100644 index 0000000..f21a8ab --- /dev/null +++ b/media/scale.go @@ -0,0 +1,54 @@ +package media + +import "image" + +// fitDims scales (w, h) so the longer side equals limit, preserving aspect +// ratio with round-half-up on the shorter side, floored at 1 pixel. +func fitDims(w, h, limit int) (int, int) { + if w >= h { + return limit, max((h*limit+w/2)/w, 1) + } + return max((w*limit+h/2)/h, 1), limit +} + +// downscale resizes src to dw x dh using area averaging (a box filter): each +// destination pixel is the mean of its corresponding source region. +// +// Why hand-rolled: the stdlib has no scaler and ADR-0007 bars +// golang.org/x/image without a new ADR. Area averaging is dependency-free, +// alias-resistant when shrinking (every source pixel contributes exactly +// once), and entirely adequate quality for vision-model input. It is only +// ever called to shrink — Normalize never upscales. +func downscale(src image.Image, dw, dh int) *image.RGBA { + b := src.Bounds() + sw, sh := b.Dx(), b.Dy() + dst := image.NewRGBA(image.Rect(0, 0, dw, dh)) + for dy := 0; dy < dh; dy++ { + // Integer box edges: destination pixel dy covers source rows + // [dy*sh/dh, (dy+1)*sh/dh), widened to at least one row. + sy0 := dy * sh / dh + sy1 := max((dy+1)*sh/dh, sy0+1) + for dx := 0; dx < dw; dx++ { + sx0 := dx * sw / dw + sx1 := max((dx+1)*sw/dw, sx0+1) + var r, g, bl, a uint64 + for sy := sy0; sy < sy1; sy++ { + for sx := sx0; sx < sx1; sx++ { + pr, pg, pb, pa := src.At(b.Min.X+sx, b.Min.Y+sy).RGBA() + r += uint64(pr) + g += uint64(pg) + bl += uint64(pb) + a += uint64(pa) + } + } + n := uint64((sy1 - sy0) * (sx1 - sx0)) + i := dst.PixOffset(dx, dy) + // RGBA() returns 16-bit channels; average, then drop to 8 bits. + dst.Pix[i+0] = uint8(r / n >> 8) + dst.Pix[i+1] = uint8(g / n >> 8) + dst.Pix[i+2] = uint8(bl / n >> 8) + dst.Pix[i+3] = uint8(a / n >> 8) + } + } + return dst +} diff --git a/parse_test.go b/parse_test.go index 6b52847..53d2bc3 100644 --- a/parse_test.go +++ b/parse_test.go @@ -63,8 +63,8 @@ func TestParseModelIDIsVerbatim(t *testing.T) { // Everything after the first slash, up to the next comma, is the model // id: colons and additional slashes pass through untouched. for spec, want := range map[string]string{ - "ollama-cloud/minimax-m3:cloud": "ollama-cloud/minimax-m3:cloud", - "google/models/gemini-3.0-pro": "google/models/gemini-3.0-pro", + "ollama-cloud/minimax-m3:cloud": "ollama-cloud/minimax-m3:cloud", + "google/models/gemini-3.0-pro": "google/models/gemini-3.0-pro", "ollama-cloud/qwen3-coder:480b-cloud": "ollama-cloud/qwen3-coder:480b-cloud", } { m, err := r.Parse(spec) diff --git a/progress.md b/progress.md index f332bb1..982a020 100644 --- a/progress.md +++ b/progress.md @@ -1,5 +1,45 @@ # progress +## 2026-06-10 — Phase 3: REST providers (OpenAI, Anthropic, Ollama×3) + media + +**Landed:** +- `provider/openai`: Chat Completions client for OpenAI and every + OpenAI-compatible endpoint (tools with string-arguments mapping, strict + SSE streaming incl. by-index tool-call assembly and the empty-choices + usage chunk, response_format json_schema, max_completion_tokens with a + WithLegacyMaxTokens compat option, reasoning_effort). +- `provider/anthropic`: Messages API client (anthropic-version 2023-06-01, + required-max_tokens defaulting, tool_use/tool_result blocks with native + is_error, GA structured output via output_config.format, full SSE event + parser with input_json_delta buffering, 529-overloaded classified + transient, usage sums cache tokens). +- `provider/ollama`: ONE native /api/chat client serving ollama (local, + OLLAMA_HOST normalization), ollama-cloud (https://ollama.com + bearer + OLLAMA_API_KEY), and foreman (base URL + bearer; tolerates its + buffered-single-object "streaming"). Object tool arguments, tool_name + results, format-schema structured output, think-level mapping, NDJSON + streaming with 16MB lines. +- `media/`: normalization pipeline per ADR-0009 (magic-byte sniffing, + box-filter downscale, transcode preference ladder, byte-budget quality + ladder, webp passthrough-or-reject, copy-on-write, everything-unfittable + wraps ErrUnsupported). +- Chain executor now normalizes media PER TARGET before each attempt and + advances penalty-free past targets that can't take the request (proven: + text-only head + vision fallback; per-target downscale assertions). +- Registry: real providers + scheme factories wired for openai, anthropic, + ollama, ollama-cloud, foreman (google still stubbed, Phase 4); + WithHTTPClient registry option; required env-foreman TLS chat round-trip + test (LLM_FM=foreman://token@host → Parse("fm/qwen3:30b") → bearer + arrives, chat answers). +- ADR-0009 (multimodal), ADR-0010 (tools/structured mapping); README + matrix flipped to ✅ for the four landed provider families; ~70 new + hermetic tests across the three provider packages + media. +- Run note: openai/anthropic/media were built by three parallel + subagents against the frozen llm contract; ollama/foreman, chain wiring, + and registry integration done in the main line. All gates green. + +**Next:** Phase 4 — Google provider on google.golang.org/genai. + ## 2026-06-10 — Phase 2: health + failover chain, proven **Landed:** the full deterministic failover test matrix over the fake diff --git a/provider/anthropic/anthropic.go b/provider/anthropic/anthropic.go new file mode 100644 index 0000000..b9b0309 --- /dev/null +++ b/provider/anthropic/anthropic.go @@ -0,0 +1,319 @@ +// Package anthropic implements llm.Provider for the Anthropic Messages API +// and Anthropic-compatible endpoints. +// +// API surface targeted: POST {base}/v1/messages with headers x-api-key, +// anthropic-version: 2023-06-01, and content-type: application/json, per the +// platform.claude.com Messages API reference as of June 2026. Streaming uses +// the documented SSE event sequence (message_start, content_block_start, +// content_block_delta, content_block_stop, message_delta, message_stop). +// Structured output uses the GA output_config.format mechanism with +// {"type":"json_schema"}; the result arrives as JSON text in the first text +// content block. +// +// Why a hand-rolled client (no SDK): ADR-0007 — majordomo is stdlib-first, +// and the canonical llm contract needs only a narrow slice of the API. +package anthropic + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +const ( + defaultName = "anthropic" + defaultBaseURL = "https://api.anthropic.com" + + // apiVersion is the anthropic-version header value. 2023-06-01 remains + // the current (and only) stable version string as of June 2026. + apiVersion = "2023-06-01" + + // defaultMaxTokens is used when Request.MaxTokens is 0, because the + // Messages API requires max_tokens on every request. + defaultMaxTokens = 4096 +) + +// defaultCapabilities reflects the documented first-party API image limits: +// 100 images per request (200K-context models), 10 MB per image, 8000 px per +// side, and the four supported media types. +func defaultCapabilities() llm.Capabilities { + return llm.Capabilities{ + SupportsTools: true, + SupportsStructured: true, + SupportsStreaming: true, + MaxImagesPerReq: 100, + MaxImageBytes: 10 << 20, + MaxImageDimension: 8000, + AllowedImageMIME: []string{ + "image/jpeg", "image/png", "image/gif", "image/webp", + }, + } +} + +// Provider is an llm.Provider backed by the Anthropic Messages API. +type Provider struct { + name string + apiKey string + baseURL string + client *http.Client + caps llm.Capabilities + maxTokens int +} + +// Option configures the provider at construction. +type Option func(*Provider) + +// WithAPIKey sets the API key explicitly, bypassing the ANTHROPIC_API_KEY +// environment default. +func WithAPIKey(key string) Option { + return func(p *Provider) { p.apiKey = key } +} + +// WithBaseURL points the provider at an Anthropic-compatible endpoint. A +// trailing slash is trimmed; "/v1/messages" is appended per request. +func WithBaseURL(u string) Option { + return func(p *Provider) { p.baseURL = strings.TrimRight(u, "/") } +} + +// WithHTTPClient replaces the HTTP client (timeouts, proxies, test doubles). +func WithHTTPClient(c *http.Client) Option { + return func(p *Provider) { p.client = c } +} + +// WithName overrides the registry name. Why: an Anthropic-compatible +// endpoint registered under its own name must surface that name in +// Response.Model and errors, not "anthropic". +func WithName(name string) Option { + return func(p *Provider) { p.name = name } +} + +// WithDefaultCapabilities replaces the provider-default capabilities. +func WithDefaultCapabilities(caps llm.Capabilities) Option { + return func(p *Provider) { p.caps = caps } +} + +// WithDefaultMaxTokens overrides the max_tokens value used when +// Request.MaxTokens is 0. Why: the Messages API rejects requests without +// max_tokens, so the provider must always send something. +func WithDefaultMaxTokens(n int) Option { + return func(p *Provider) { p.maxTokens = n } +} + +// New creates an Anthropic provider. It never fails: a missing API key +// (no WithAPIKey and no ANTHROPIC_API_KEY in the environment) surfaces as a +// 401-style *llm.APIError at request time, not at construction. +func New(opts ...Option) *Provider { + p := &Provider{ + name: defaultName, + baseURL: defaultBaseURL, + client: http.DefaultClient, + caps: defaultCapabilities(), + maxTokens: defaultMaxTokens, + } + for _, opt := range opts { + opt(p) + } + if p.apiKey == "" { + p.apiKey = os.Getenv("ANTHROPIC_API_KEY") + } + return p +} + +// Name implements llm.Provider. +func (p *Provider) Name() string { return p.name } + +// Model implements llm.Provider. The id is passed through verbatim — it is +// never validated against a catalog. +func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) { + cfg := llm.ApplyModelOptions(opts) + caps := p.caps + if cfg.Capabilities != nil { + caps = *cfg.Capabilities + } + return &model{provider: p, id: id, caps: caps}, nil +} + +type model struct { + provider *Provider + id string + caps llm.Capabilities +} + +// Capabilities implements llm.Model. +func (m *model) Capabilities() llm.Capabilities { return m.caps } + +// fullName is the "provider/model" identifier used in Response.Model. +func (m *model) fullName() string { return m.provider.name + "/" + m.id } + +// Generate implements llm.Model. +func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) { + req = req.Apply(opts...) + if err := m.enforceCapabilities(req); err != nil { + return nil, err + } + httpResp, err := m.do(ctx, req, false) + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + if httpResp.StatusCode/100 != 2 { + return nil, m.apiError(httpResp) + } + var wr wireResponse + if err := json.NewDecoder(httpResp.Body).Decode(&wr); err != nil { + return nil, fmt.Errorf("%s: decode response: %w", m.provider.name, err) + } + return m.toResponse(&wr), nil +} + +// Stream implements llm.Model. A non-2xx status is returned as an error from +// Stream itself, before any events are delivered. +func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) { + req = req.Apply(opts...) + if err := m.enforceCapabilities(req); err != nil { + return nil, err + } + httpResp, err := m.do(ctx, req, true) + if err != nil { + return nil, err + } + if httpResp.StatusCode/100 != 2 { + defer httpResp.Body.Close() + return nil, m.apiError(httpResp) + } + return newStream(m, httpResp.Body), nil +} + +// enforceCapabilities is the honest backstop behind the media layer: it +// rejects (rather than silently mutates) requests the target cannot serve. +// Why: a separate media layer resizes/transcodes images BEFORE requests +// reach the provider, so anything still out of bounds here is a real error. +func (m *model) enforceCapabilities(req llm.Request) error { + images := 0 + for _, msg := range req.Messages { + for _, part := range msg.Parts { + img, ok := part.(llm.ImagePart) + if !ok { + continue + } + images++ + if !m.caps.SupportsImages() { + return fmt.Errorf("%w: %s does not accept image input", llm.ErrUnsupported, m.fullName()) + } + if !m.caps.MIMEAllowed(img.MIME) { + return fmt.Errorf("%w: %s does not accept image MIME %q", llm.ErrUnsupported, m.fullName(), img.MIME) + } + if m.caps.MaxImageBytes > 0 && len(img.Data) > m.caps.MaxImageBytes { + return fmt.Errorf("%w: image of %d bytes exceeds %s limit of %d bytes", + llm.ErrUnsupported, len(img.Data), m.fullName(), m.caps.MaxImageBytes) + } + } + } + if m.caps.MaxImagesPerReq > 0 && images > m.caps.MaxImagesPerReq { + return fmt.Errorf("%w: request carries %d images, %s allows at most %d", + llm.ErrUnsupported, images, m.fullName(), m.caps.MaxImagesPerReq) + } + return nil +} + +// do builds and executes one Messages API call. Transport errors are wrapped +// with context but NOT converted to *llm.APIError, so llm.Classify still +// sees the underlying net.Error / syscall errno. +func (m *model) do(ctx context.Context, req llm.Request, streaming bool) (*http.Response, error) { + p := m.provider + if p.apiKey == "" { + // Why request-time, not construction-time: New never fails by + // convention, and a 401-shaped APIError classifies permanent so + // chains fail fast past a misconfigured target. + return nil, &llm.APIError{ + Provider: p.name, + Model: m.id, + Status: http.StatusUnauthorized, + Code: "authentication_error", + Message: "no API key configured: set ANTHROPIC_API_KEY or use WithAPIKey", + } + } + + body, err := json.Marshal(buildWireRequest(m.id, req, p.maxTokens, streaming)) + if err != nil { + return nil, fmt.Errorf("%s: encode request: %w", p.name, err) + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/messages", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("%s: build request: %w", p.name, err) + } + httpReq.Header.Set("x-api-key", p.apiKey) + httpReq.Header.Set("anthropic-version", apiVersion) + httpReq.Header.Set("content-type", "application/json") + if streaming { + httpReq.Header.Set("accept", "text/event-stream") + } + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("%s: do request: %w", p.name, err) + } + return resp, nil +} + +// apiError converts a non-2xx response into *llm.APIError, filling Code and +// Message from the documented {"type":"error","error":{...}} body when it +// parses, and falling back to the raw body text when it does not. +func (m *model) apiError(resp *http.Response) error { + apiErr := &llm.APIError{ + Provider: m.provider.name, + Model: m.id, + Status: resp.StatusCode, + } + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return apiErr + } + var we wireErrorEnvelope + if json.Unmarshal(body, &we) == nil && we.Error.Type != "" { + apiErr.Code = we.Error.Type + apiErr.Message = we.Error.Message + } else { + apiErr.Message = strings.TrimSpace(string(body)) + } + return apiErr +} + +// toResponse maps a wire response onto the canonical llm.Response. Thinking +// and other unrecognized block types are tolerated and skipped — they are +// not part of the canonical content vocabulary. +func (m *model) toResponse(wr *wireResponse) *llm.Response { + resp := &llm.Response{ + FinishReason: mapStopReason(wr.StopReason), + Usage: wr.Usage.toUsage(), + Model: m.fullName(), + Raw: wr, + } + for _, block := range wr.Content { + switch block.Type { + case "text": + resp.Parts = append(resp.Parts, llm.TextPart{Text: block.Text}) + case "tool_use": + args := block.Input + if len(args) == 0 { + args = json.RawMessage("{}") + } + resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{ + ID: block.ID, + Name: block.Name, + Arguments: args, + }) + default: + // thinking, redacted_thinking, server-tool blocks, and any + // future types are skipped, not surfaced as parts. + } + } + return resp +} diff --git a/provider/anthropic/anthropic_test.go b/provider/anthropic/anthropic_test.go new file mode 100644 index 0000000..73dde5d --- /dev/null +++ b/provider/anthropic/anthropic_test.go @@ -0,0 +1,774 @@ +package anthropic + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// okBody is a minimal successful Messages API response. +const okBody = `{ + "id": "msg_01", + "type": "message", + "role": "assistant", + "model": "claude-test", + "content": [{"type": "text", "text": "ok"}], + "stop_reason": "end_turn", + "usage": {"input_tokens": 3, "output_tokens": 5} +}` + +// capture records the last request the test server received. +type capture struct { + mu sync.Mutex + hits int + method string + path string + header http.Header + body []byte +} + +func (c *capture) handler(status int, respBody string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + c.mu.Lock() + c.hits++ + c.method = r.Method + c.path = r.URL.Path + c.header = r.Header.Clone() + c.body = body + c.mu.Unlock() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _, _ = w.Write([]byte(respBody)) + } +} + +// bodyMap decodes the captured request body for key-presence assertions. +func (c *capture) bodyMap(t *testing.T) map[string]any { + t.Helper() + c.mu.Lock() + defer c.mu.Unlock() + var m map[string]any + if err := json.Unmarshal(c.body, &m); err != nil { + t.Fatalf("decode captured body: %v\nbody: %s", err, c.body) + } + return m +} + +// newTestProvider spins up an httptest server and a provider pointed at it. +func newTestProvider(t *testing.T, h http.Handler, opts ...Option) *Provider { + t.Helper() + srv := httptest.NewServer(h) + t.Cleanup(srv.Close) + return New(append([]Option{WithAPIKey("test-key"), WithBaseURL(srv.URL)}, opts...)...) +} + +func mustModel(t *testing.T, p *Provider, id string, opts ...llm.ModelOption) llm.Model { + t.Helper() + m, err := p.Model(id, opts...) + if err != nil { + t.Fatalf("Model(%q): %v", id, err) + } + return m +} + +func generate(t *testing.T, m llm.Model, req llm.Request, opts ...llm.Option) *llm.Response { + t.Helper() + resp, err := m.Generate(context.Background(), req, opts...) + if err != nil { + t.Fatalf("Generate: %v", err) + } + return resp +} + +func TestRequestHeadersAndPath(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + m := mustModel(t, p, "claude-test") + + generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + + if c.method != http.MethodPost { + t.Errorf("method = %q, want POST", c.method) + } + if c.path != "/v1/messages" { + t.Errorf("path = %q, want /v1/messages", c.path) + } + for header, want := range map[string]string{ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + "content-type": "application/json", + } { + if got := c.header.Get(header); got != want { + t.Errorf("header %s = %q, want %q", header, got, want) + } + } +} + +func TestSystemFold(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + m := mustModel(t, p, "claude-test") + + generate(t, m, llm.Request{ + System: "base prompt", + Messages: []llm.Message{ + llm.SystemText("first extra"), + llm.UserText("hi"), + llm.SystemText("second extra"), + }, + }) + + body := c.bodyMap(t) + if got, want := body["system"], "base prompt\n\nfirst extra\n\nsecond extra"; got != want { + t.Errorf("system = %q, want %q", got, want) + } + msgs := body["messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("messages length = %d, want 1 (system messages must be excluded)", len(msgs)) + } + if role := msgs[0].(map[string]any)["role"]; role != "user" { + t.Errorf("remaining message role = %q, want user", role) + } +} + +func TestNoSystemOmitsField(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + m := mustModel(t, p, "claude-test") + + generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + + if _, ok := c.bodyMap(t)["system"]; ok { + t.Error("system key present, want omitted when empty") + } +} + +func TestMaxTokens(t *testing.T) { + t.Run("default 4096", func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if got := c.bodyMap(t)["max_tokens"].(float64); got != 4096 { + t.Errorf("max_tokens = %v, want 4096", got) + } + }) + t.Run("explicit wins", func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + generate(t, mustModel(t, p, "m"), llm.Request{ + Messages: []llm.Message{llm.UserText("hi")}, + MaxTokens: 123, + }) + if got := c.bodyMap(t)["max_tokens"].(float64); got != 123 { + t.Errorf("max_tokens = %v, want 123", got) + } + }) + t.Run("WithDefaultMaxTokens overrides default", func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithDefaultMaxTokens(99)) + generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if got := c.bodyMap(t)["max_tokens"].(float64); got != 99 { + t.Errorf("max_tokens = %v, want 99", got) + } + }) +} + +func TestImageBlock(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + m := mustModel(t, p, "claude-test") + + raw := []byte{0x01, 0x02, 0x03} + generate(t, m, llm.Request{Messages: []llm.Message{ + llm.UserParts(llm.Text("look at this"), llm.Image("image/png", raw)), + }}) + + msgs := c.bodyMap(t)["messages"].([]any) + content := msgs[0].(map[string]any)["content"].([]any) + if len(content) != 2 { + t.Fatalf("content blocks = %d, want 2", len(content)) + } + img := content[1].(map[string]any) + if img["type"] != "image" { + t.Fatalf("block type = %v, want image", img["type"]) + } + src := img["source"].(map[string]any) + if src["type"] != "base64" { + t.Errorf("source type = %v, want base64", src["type"]) + } + if src["media_type"] != "image/png" { + t.Errorf("media_type = %v, want image/png", src["media_type"]) + } + if want := base64.StdEncoding.EncodeToString(raw); src["data"] != want { + t.Errorf("data = %v, want %q", src["data"], want) + } +} + +func TestToolUseToolResultRoundTrip(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + m := mustModel(t, p, "claude-test") + + generate(t, m, llm.Request{Messages: []llm.Message{ + llm.UserText("weather?"), + { + Role: llm.RoleAssistant, + Parts: []llm.Part{llm.Text("checking")}, + ToolCalls: []llm.ToolCall{ + {ID: "toolu_1", Name: "get_weather", Arguments: json.RawMessage(`{"location":"Paris"}`)}, + {ID: "toolu_2", Name: "noop"}, // empty args must become {} + }, + }, + llm.ToolResultsMessage( + llm.ToolResult{ID: "toolu_1", Name: "get_weather", Content: "72F and sunny"}, + llm.ToolResult{ID: "toolu_2", Name: "noop", Content: "boom", IsError: true}, + ), + }}) + + msgs := c.bodyMap(t)["messages"].([]any) + if len(msgs) != 3 { + t.Fatalf("messages = %d, want 3", len(msgs)) + } + + asst := msgs[1].(map[string]any) + if asst["role"] != "assistant" { + t.Errorf("messages[1].role = %v, want assistant", asst["role"]) + } + asstContent := asst["content"].([]any) + if len(asstContent) != 3 { + t.Fatalf("assistant blocks = %d, want 3 (text + 2 tool_use)", len(asstContent)) + } + tu := asstContent[1].(map[string]any) + if tu["type"] != "tool_use" || tu["id"] != "toolu_1" || tu["name"] != "get_weather" { + t.Errorf("tool_use block = %v", tu) + } + if loc := tu["input"].(map[string]any)["location"]; loc != "Paris" { + t.Errorf("tool_use input.location = %v, want Paris", loc) + } + if input := asstContent[2].(map[string]any)["input"].(map[string]any); len(input) != 0 { + t.Errorf("empty-args tool_use input = %v, want {}", input) + } + + // RoleTool → ONE user message with one tool_result block per result. + toolMsg := msgs[2].(map[string]any) + if toolMsg["role"] != "user" { + t.Errorf("messages[2].role = %v, want user", toolMsg["role"]) + } + results := toolMsg["content"].([]any) + if len(results) != 2 { + t.Fatalf("tool_result blocks = %d, want 2", len(results)) + } + first := results[0].(map[string]any) + if first["type"] != "tool_result" || first["tool_use_id"] != "toolu_1" || first["content"] != "72F and sunny" { + t.Errorf("first tool_result = %v", first) + } + if _, ok := first["is_error"]; ok { + t.Error("first tool_result has is_error, want omitted when false") + } + second := results[1].(map[string]any) + if second["tool_use_id"] != "toolu_2" || second["is_error"] != true { + t.Errorf("second tool_result = %v, want is_error true", second) + } +} + +func TestToolDefinitions(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + m := mustModel(t, p, "claude-test") + + schema := json.RawMessage(`{"type":"object","properties":{"q":{"type":"string"}},"required":["q"]}`) + generate(t, m, llm.Request{ + Messages: []llm.Message{llm.UserText("hi")}, + Tools: []llm.Tool{ + {Name: "search", Description: "Search the web.", Parameters: schema}, + {Name: "ping"}, // nil Parameters → default empty object schema + }, + }) + + tools := c.bodyMap(t)["tools"].([]any) + if len(tools) != 2 { + t.Fatalf("tools = %d, want 2", len(tools)) + } + search := tools[0].(map[string]any) + if search["name"] != "search" || search["description"] != "Search the web." { + t.Errorf("tool[0] = %v", search) + } + if typ := search["input_schema"].(map[string]any)["type"]; typ != "object" { + t.Errorf("input_schema.type = %v, want object", typ) + } + ping := tools[1].(map[string]any) + if typ := ping["input_schema"].(map[string]any)["type"]; typ != "object" { + t.Errorf("nil-Parameters input_schema.type = %v, want object", typ) + } +} + +func TestToolChoiceForms(t *testing.T) { + cases := []struct { + choice string + wantType string // "" means the field must be absent + wantName string + }{ + {choice: "", wantType: ""}, + {choice: "auto", wantType: "auto"}, + {choice: "required", wantType: "any"}, + {choice: "none", wantType: "none"}, + {choice: "get_weather", wantType: "tool", wantName: "get_weather"}, + } + for _, tc := range cases { + t.Run("choice="+tc.choice, func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + generate(t, mustModel(t, p, "m"), llm.Request{ + Messages: []llm.Message{llm.UserText("hi")}, + ToolChoice: tc.choice, + }) + body := c.bodyMap(t) + raw, present := body["tool_choice"] + if tc.wantType == "" { + if present { + t.Fatalf("tool_choice present (%v), want omitted", raw) + } + return + } + choice := raw.(map[string]any) + if choice["type"] != tc.wantType { + t.Errorf("tool_choice.type = %v, want %q", choice["type"], tc.wantType) + } + if tc.wantName != "" && choice["name"] != tc.wantName { + t.Errorf("tool_choice.name = %v, want %q", choice["name"], tc.wantName) + } + }) + } +} + +func TestOutputConfigFormat(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + m := mustModel(t, p, "claude-test") + + schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"],"additionalProperties":false}`) + generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}}, + llm.WithSchema(schema, "person")) + + body := c.bodyMap(t) + format := body["output_config"].(map[string]any)["format"].(map[string]any) + if format["type"] != "json_schema" { + t.Errorf("output_config.format.type = %v, want json_schema", format["type"]) + } + // Normalize both sides through any → Marshal (sorted keys) to compare. + got, _ := json.Marshal(format["schema"]) + var want any + _ = json.Unmarshal(schema, &want) + wantJSON, _ := json.Marshal(want) + if string(got) != string(wantJSON) { + t.Errorf("schema = %s, want %s", got, wantJSON) + } +} + +func TestOutputConfigOmittedWithoutSchema(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if _, ok := c.bodyMap(t)["output_config"]; ok { + t.Error("output_config present, want omitted when Schema is nil") + } +} + +func TestSamplingKnobs(t *testing.T) { + t.Run("omitted when unset", func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + body := c.bodyMap(t) + if _, ok := body["temperature"]; ok { + t.Error("temperature present, want omitted when unset") + } + if _, ok := body["top_p"]; ok { + t.Error("top_p present, want omitted when unset") + } + if _, ok := body["stop_sequences"]; ok { + t.Error("stop_sequences present, want omitted when unset") + } + }) + t.Run("present when set", func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + generate(t, mustModel(t, p, "m"), + llm.Request{Messages: []llm.Message{llm.UserText("hi")}}, + llm.WithTemperature(0), // explicit zero must still be sent + llm.WithTopP(0.9), + llm.WithStopSequences("END")) + body := c.bodyMap(t) + if got, ok := body["temperature"]; !ok || got.(float64) != 0 { + t.Errorf("temperature = %v (present=%v), want explicit 0", got, ok) + } + if got := body["top_p"].(float64); got != 0.9 { + t.Errorf("top_p = %v, want 0.9", got) + } + stops := body["stop_sequences"].([]any) + if len(stops) != 1 || stops[0] != "END" { + t.Errorf("stop_sequences = %v, want [END]", stops) + } + }) +} + +func TestStreamFieldOmittedOnGenerate(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if _, ok := c.bodyMap(t)["stream"]; ok { + t.Error("stream key present on Generate, want omitted") + } +} + +func TestResponseParse(t *testing.T) { + const body = `{ + "id": "msg_02", + "type": "message", + "role": "assistant", + "model": "claude-test", + "content": [ + {"type": "thinking", "thinking": "pondering...", "signature": "sig"}, + {"type": "text", "text": "I'll check the weather."}, + {"type": "tool_use", "id": "toolu_9", "name": "get_weather", "input": {"location": "Paris"}} + ], + "stop_reason": "tool_use", + "usage": { + "input_tokens": 3, + "output_tokens": 7, + "cache_creation_input_tokens": 10, + "cache_read_input_tokens": 20 + } + }` + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, body)) + resp := generate(t, mustModel(t, p, "claude-test"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + + if len(resp.Parts) != 1 { + t.Fatalf("parts = %d, want 1 (thinking blocks must be skipped)", len(resp.Parts)) + } + if got := resp.Text(); got != "I'll check the weather." { + t.Errorf("text = %q", got) + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("tool calls = %d, want 1", len(resp.ToolCalls)) + } + call := resp.ToolCalls[0] + if call.ID != "toolu_9" || call.Name != "get_weather" { + t.Errorf("tool call = %+v", call) + } + var args map[string]any + if err := json.Unmarshal(call.Arguments, &args); err != nil || args["location"] != "Paris" { + t.Errorf("arguments = %s (err %v), want location Paris", call.Arguments, err) + } + if resp.FinishReason != llm.FinishToolCalls { + t.Errorf("finish = %q, want %q", resp.FinishReason, llm.FinishToolCalls) + } + // Total real input = input + cache_creation + cache_read. + if resp.Usage.InputTokens != 33 || resp.Usage.OutputTokens != 7 { + t.Errorf("usage = %+v, want {33 7}", resp.Usage) + } + if resp.Model != "anthropic/claude-test" { + t.Errorf("model = %q, want anthropic/claude-test", resp.Model) + } + if resp.Raw == nil { + t.Error("Raw = nil, want wire response") + } +} + +func TestStopReasonMapping(t *testing.T) { + cases := map[string]llm.FinishReason{ + "end_turn": llm.FinishStop, + "stop_sequence": llm.FinishStop, + "max_tokens": llm.FinishLength, + "model_context_window_exceeded": llm.FinishLength, + "tool_use": llm.FinishToolCalls, + "refusal": llm.FinishContentFilter, + "pause_turn": llm.FinishOther, + "some_future_reason": llm.FinishOther, + } + for stop, want := range cases { + if got := mapStopReason(stop); got != want { + t.Errorf("mapStopReason(%q) = %q, want %q", stop, got, want) + } + } +} + +func TestHTTPErrorMapping(t *testing.T) { + cases := []struct { + name string + status int + body string + wantCode string + wantClass llm.ErrorClass + }{ + { + name: "429 rate limit is transient", + status: http.StatusTooManyRequests, + body: `{"type":"error","error":{"type":"rate_limit_error","message":"slow down"}}`, + wantCode: "rate_limit_error", wantClass: llm.ClassTransient, + }, + { + name: "529 overloaded is transient", + status: 529, + body: `{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`, + wantCode: "overloaded_error", wantClass: llm.ClassTransient, + }, + { + name: "401 auth is permanent", + status: http.StatusUnauthorized, + body: `{"type":"error","error":{"type":"authentication_error","message":"invalid x-api-key"}}`, + wantCode: "authentication_error", wantClass: llm.ClassPermanent, + }, + { + name: "404 is permanent", + status: http.StatusNotFound, + body: `{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`, + wantCode: "not_found_error", wantClass: llm.ClassPermanent, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(tc.status, tc.body)) + _, err := mustModel(t, p, "claude-test").Generate(context.Background(), + llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err == nil { + t.Fatal("Generate succeeded, want error") + } + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("error %T (%v), want *llm.APIError", err, err) + } + if apiErr.Provider != "anthropic" || apiErr.Model != "claude-test" { + t.Errorf("provider/model = %s/%s", apiErr.Provider, apiErr.Model) + } + if apiErr.Status != tc.status { + t.Errorf("status = %d, want %d", apiErr.Status, tc.status) + } + if apiErr.Code != tc.wantCode { + t.Errorf("code = %q, want %q", apiErr.Code, tc.wantCode) + } + if apiErr.Message == "" { + t.Error("message empty, want provider message") + } + if got := llm.Classify(err); got != tc.wantClass { + t.Errorf("Classify = %v, want %v", got, tc.wantClass) + } + }) + } + + t.Run("404 unwraps to ErrModelNotFound", func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusNotFound, + `{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`)) + _, err := mustModel(t, p, "missing").Generate(context.Background(), + llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if !errors.Is(err, llm.ErrModelNotFound) { + t.Errorf("errors.Is(err, ErrModelNotFound) = false for %v", err) + } + }) + + t.Run("non-JSON error body falls back to raw text", func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusBadGateway, "upstream exploded")) + _, err := mustModel(t, p, "m").Generate(context.Background(), + llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("error %T, want *llm.APIError", err) + } + if apiErr.Status != http.StatusBadGateway || apiErr.Message != "upstream exploded" { + t.Errorf("apiErr = %+v", apiErr) + } + }) +} + +func TestMissingAPIKey(t *testing.T) { + t.Setenv("ANTHROPIC_API_KEY", "") // isolate from any real environment + + var c capture + srv := httptest.NewServer(c.handler(http.StatusOK, okBody)) + t.Cleanup(srv.Close) + + p := New(WithBaseURL(srv.URL)) // construction must not fail + _, err := mustModel(t, p, "claude-test").Generate(context.Background(), + llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("error %T (%v), want *llm.APIError", err, err) + } + if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" { + t.Errorf("apiErr = %+v, want 401 authentication_error", apiErr) + } + if llm.Classify(err) != llm.ClassPermanent { + t.Error("missing key must classify permanent") + } + if c.hits != 0 { + t.Errorf("server hits = %d, want 0 (no request without a key)", c.hits) + } +} + +func TestAPIKeyFromEnv(t *testing.T) { + t.Setenv("ANTHROPIC_API_KEY", "env-key") + + var c capture + srv := httptest.NewServer(c.handler(http.StatusOK, okBody)) + t.Cleanup(srv.Close) + + p := New(WithBaseURL(srv.URL)) + generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if got := c.header.Get("x-api-key"); got != "env-key" { + t.Errorf("x-api-key = %q, want env-key", got) + } +} + +func TestCapabilityEnforcement(t *testing.T) { + img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) } + + cases := []struct { + name string + caps *llm.Capabilities // nil = provider defaults + req llm.Request + }{ + { + name: "images unsupported", + caps: &llm.Capabilities{}, // MaxImagesPerReq 0 = no images + req: llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 4))}}, + }, + { + name: "too many images", + caps: &llm.Capabilities{MaxImagesPerReq: 1}, + req: llm.Request{Messages: []llm.Message{ + llm.UserParts(img("image/png", 4), img("image/png", 4)), + }}, + }, + { + name: "disallowed MIME", + req: llm.Request{Messages: []llm.Message{llm.UserParts(img("image/bmp", 4))}}, + }, + { + name: "image too large", + caps: &llm.Capabilities{MaxImagesPerReq: 1, MaxImageBytes: 2}, + req: llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 3))}}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + var opts []llm.ModelOption + if tc.caps != nil { + opts = append(opts, llm.WithCapabilities(*tc.caps)) + } + m := mustModel(t, p, "claude-test", opts...) + + _, err := m.Generate(context.Background(), tc.req) + if !errors.Is(err, llm.ErrUnsupported) { + t.Errorf("Generate err = %v, want ErrUnsupported", err) + } + _, err = m.Stream(context.Background(), tc.req) + if !errors.Is(err, llm.ErrUnsupported) { + t.Errorf("Stream err = %v, want ErrUnsupported", err) + } + if c.hits != 0 { + t.Errorf("server hits = %d, want 0 (rejected before sending)", c.hits) + } + }) + } + + t.Run("within limits passes", func(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody)) + generate(t, mustModel(t, p, "m"), llm.Request{ + Messages: []llm.Message{llm.UserParts(llm.Text("ok"), img("image/jpeg", 16))}, + }) + if c.hits != 1 { + t.Errorf("server hits = %d, want 1", c.hits) + } + }) +} + +func TestCompatEndpointWithNameAndBaseURL(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithName("compat")) + if p.Name() != "compat" { + t.Errorf("Name() = %q, want compat", p.Name()) + } + resp := generate(t, mustModel(t, p, "claude-test"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if resp.Model != "compat/claude-test" { + t.Errorf("resp.Model = %q, want compat/claude-test", resp.Model) + } + + var ec capture + pe := newTestProvider(t, ec.handler(http.StatusTooManyRequests, + `{"type":"error","error":{"type":"rate_limit_error","message":"x"}}`), WithName("compat")) + _, err := mustModel(t, pe, "m").Generate(context.Background(), + llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok || apiErr.Provider != "compat" { + t.Errorf("error provider = %v, want compat (err %v)", apiErr, err) + } +} + +func TestCapabilitiesDefaultsAndOverrides(t *testing.T) { + p := New(WithAPIKey("k")) + m := mustModel(t, p, "m") + caps := m.Capabilities() + if !caps.SupportsTools || !caps.SupportsStructured || !caps.SupportsStreaming { + t.Errorf("default feature flags = %+v, want all true", caps) + } + if caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 10<<20 || caps.MaxImageDimension != 8000 { + t.Errorf("default image limits = %+v", caps) + } + wantMIME := []string{"image/jpeg", "image/png", "image/gif", "image/webp"} + if len(caps.AllowedImageMIME) != len(wantMIME) { + t.Fatalf("AllowedImageMIME = %v, want %v", caps.AllowedImageMIME, wantMIME) + } + for i, mime := range wantMIME { + if caps.AllowedImageMIME[i] != mime { + t.Errorf("AllowedImageMIME[%d] = %q, want %q", i, caps.AllowedImageMIME[i], mime) + } + } + + custom := llm.Capabilities{SupportsStreaming: true, MaxImagesPerReq: 1} + p2 := New(WithAPIKey("k"), WithDefaultCapabilities(custom)) + if got := mustModel(t, p2, "m").Capabilities(); got.MaxImagesPerReq != 1 || got.SupportsTools { + t.Errorf("WithDefaultCapabilities not applied: %+v", got) + } + + perModel := llm.Capabilities{SupportsTools: true} + if got := mustModel(t, p2, "m", llm.WithCapabilities(perModel)).Capabilities(); !got.SupportsTools || got.MaxImagesPerReq != 0 { + t.Errorf("per-model capabilities not applied: %+v", got) + } +} + +func TestTransportErrorNotAPIError(t *testing.T) { + // Point at a server that is immediately closed: the connection failure + // must surface as a wrapped transport error, not *llm.APIError. + srv := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + url := srv.URL + srv.Close() + + p := New(WithAPIKey("k"), WithBaseURL(url)) + _, err := mustModel(t, p, "m").Generate(context.Background(), + llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err == nil { + t.Fatal("Generate succeeded, want transport error") + } + if _, ok := errors.AsType[*llm.APIError](err); ok { + t.Errorf("transport error wrapped in APIError: %v", err) + } + if llm.Classify(err) != llm.ClassTransient { + t.Errorf("connection failure must classify transient: %v", err) + } +} diff --git a/provider/anthropic/stream.go b/provider/anthropic/stream.go new file mode 100644 index 0000000..cd6cc2a --- /dev/null +++ b/provider/anthropic/stream.go @@ -0,0 +1,247 @@ +package anthropic + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "strings" + "sync" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// wireStreamEvent is the union of all SSE data payloads the Messages API +// emits. Dispatch is on Type (the data always carries one), so the SSE +// "event:" line is informational only. +type wireStreamEvent struct { + Type string `json:"type"` + Index int `json:"index"` + + // message_start + Message *struct { + Usage wireUsage `json:"usage"` + } `json:"message"` + + // content_block_start + ContentBlock *struct { + Type string `json:"type"` + ID string `json:"id"` + Name string `json:"name"` + } `json:"content_block"` + + // content_block_delta / message_delta + Delta struct { + Type string `json:"type"` + Text string `json:"text"` + PartialJSON string `json:"partial_json"` + StopReason string `json:"stop_reason"` + } `json:"delta"` + + // message_delta + Usage *wireUsage `json:"usage"` + + // error + Error *struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` +} + +// stream adapts the Messages API SSE stream to llm.Stream. +// +// Why single-threaded pull (no reader goroutine): Next is already the +// consumer's pull point, so parsing lazily inside Next keeps cancellation, +// buffering, and error propagation trivial — Close just closes the body and +// the next read fails. +type stream struct { + provider string + model string + full string // provider/model + body io.ReadCloser + scanner *bufio.Scanner + + // accumulated response + parts []llm.Part + toolCalls []llm.ToolCall + usage llm.Usage + finish llm.FinishReason + + // current content block state + blockType string + textBuf strings.Builder + toolID string + toolName string + argsBuf strings.Builder + + done bool // final Response event emitted + closeOnce sync.Once + closeErr error +} + +func newStream(m *model, body io.ReadCloser) *stream { + sc := bufio.NewScanner(body) + // Why a large limit: one SSE line carries one whole delta; default 64K + // can be exceeded by large structured-output or tool-argument deltas. + sc.Buffer(make([]byte, 0, 64*1024), 10*1024*1024) + return &stream{ + provider: m.provider.name, + model: m.id, + full: m.fullName(), + body: body, + scanner: sc, + finish: llm.FinishOther, + } +} + +// Close implements llm.Stream. Safe to call at any time and more than once. +func (s *stream) Close() error { + s.closeOnce.Do(func() { s.closeErr = s.body.Close() }) + return s.closeErr +} + +// Next implements llm.Stream. It emits TextDelta fragments as they arrive, +// fully-assembled ToolCalls at content_block_stop, exactly one final +// Response event at message_stop, then io.EOF. +func (s *stream) Next() (llm.StreamEvent, error) { + if s.done { + return llm.StreamEvent{}, io.EOF + } + for { + data, err := s.nextData() + if err != nil { + return llm.StreamEvent{}, err + } + var ev wireStreamEvent + if err := json.Unmarshal([]byte(data), &ev); err != nil { + return llm.StreamEvent{}, fmt.Errorf("%s: decode stream event: %w", s.provider, err) + } + + switch ev.Type { + case "message_start": + if ev.Message != nil { + s.usage = ev.Message.Usage.toUsage() + } + + case "content_block_start": + s.blockType = "" + s.textBuf.Reset() + s.argsBuf.Reset() + if ev.ContentBlock != nil { + s.blockType = ev.ContentBlock.Type + if s.blockType == "tool_use" { + s.toolID = ev.ContentBlock.ID + s.toolName = ev.ContentBlock.Name + } + } + + case "content_block_delta": + switch ev.Delta.Type { + case "text_delta": + s.textBuf.WriteString(ev.Delta.Text) + return llm.StreamEvent{TextDelta: ev.Delta.Text}, nil + case "input_json_delta": + // Buffer partial JSON internally; consumers never see it. + s.argsBuf.WriteString(ev.Delta.PartialJSON) + default: + // thinking_delta / signature_delta: tolerated, skipped. + } + + case "content_block_stop": + if event, ok := s.finishBlock(); ok { + return event, nil + } + + case "message_delta": + if ev.Delta.StopReason != "" { + s.finish = mapStopReason(ev.Delta.StopReason) + } + if ev.Usage != nil { + // Output tokens arrive cumulatively in the final delta; + // input tokens were reported in message_start. + s.usage.OutputTokens = ev.Usage.OutputTokens + } + + case "message_stop": + s.done = true + return llm.StreamEvent{Response: &llm.Response{ + Parts: s.parts, + ToolCalls: s.toolCalls, + FinishReason: s.finish, + Usage: s.usage, + Model: s.full, + }}, nil + + case "error": + // Mid-stream failure after the 200 (e.g. overloaded_error). + // Status stays 0: there is no HTTP status for it, and the + // default Classify treats it as transient, which fits overload. + apiErr := &llm.APIError{Provider: s.provider, Model: s.model} + if ev.Error != nil { + apiErr.Code = ev.Error.Type + apiErr.Message = ev.Error.Message + } + return llm.StreamEvent{}, apiErr + + default: + // ping and unknown event types: ignored. + } + } +} + +// finishBlock closes out the current content block, appending its result to +// the accumulated response. Tool-use blocks produce a stream event. +func (s *stream) finishBlock() (llm.StreamEvent, bool) { + defer func() { + s.blockType = "" + s.textBuf.Reset() + s.argsBuf.Reset() + }() + switch s.blockType { + case "text": + if s.textBuf.Len() > 0 { + s.parts = append(s.parts, llm.TextPart{Text: s.textBuf.String()}) + } + case "tool_use": + args := s.argsBuf.String() + if args == "" { + // A tool called with no arguments streams zero (or empty) + // input_json_delta fragments; the canonical form is "{}". + args = "{}" + } + call := llm.ToolCall{ID: s.toolID, Name: s.toolName, Arguments: json.RawMessage(args)} + s.toolCalls = append(s.toolCalls, call) + return llm.StreamEvent{ToolCall: &call}, true + } + return llm.StreamEvent{}, false +} + +// nextData reads SSE lines until one complete event's data is assembled +// (multi-line data fields are joined with "\n" per the SSE spec). "event:" +// lines and comments are ignored; dispatch keys off the JSON "type" field. +func (s *stream) nextData() (string, error) { + var data strings.Builder + for s.scanner.Scan() { + line := s.scanner.Text() + if line == "" { + if data.Len() > 0 { + return data.String(), nil + } + continue + } + if rest, ok := strings.CutPrefix(line, "data:"); ok { + if data.Len() > 0 { + data.WriteByte('\n') + } + data.WriteString(strings.TrimPrefix(rest, " ")) + } + } + if err := s.scanner.Err(); err != nil { + return "", fmt.Errorf("%s: read stream: %w", s.provider, err) + } + if data.Len() > 0 { + return data.String(), nil + } + // EOF before message_stop: the connection dropped mid-response. + return "", fmt.Errorf("%s: stream ended before message_stop: %w", s.provider, io.ErrUnexpectedEOF) +} diff --git a/provider/anthropic/stream_test.go b/provider/anthropic/stream_test.go new file mode 100644 index 0000000..1445471 --- /dev/null +++ b/provider/anthropic/stream_test.go @@ -0,0 +1,324 @@ +package anthropic + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "strings" + "testing" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// sse joins data payloads into an SSE body. Each payload becomes one event +// ("event:" name derived from the JSON type field is what the real API +// sends, but the client dispatches on the data, so a generic name is fine). +func sse(payloads ...string) string { + var b strings.Builder + for _, p := range payloads { + b.WriteString("event: event\n") + b.WriteString("data: ") + b.WriteString(p) + b.WriteString("\n\n") + } + return b.String() +} + +func sseServer(t *testing.T, c *capture, body string) *Provider { + t.Helper() + return newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + c.mu.Lock() + c.hits++ + c.header = r.Header.Clone() + c.body = raw + c.mu.Unlock() + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, body) + })) +} + +// drain collects all events until io.EOF, failing the test on any error. +func drain(t *testing.T, s llm.Stream) []llm.StreamEvent { + t.Helper() + var events []llm.StreamEvent + for { + ev, err := s.Next() + if err == io.EOF { + return events + } + if err != nil { + t.Fatalf("Next: %v", err) + } + events = append(events, ev) + } +} + +func openStream(t *testing.T, p *Provider, modelID string) llm.Stream { + t.Helper() + s, err := mustModel(t, p, modelID).Stream(context.Background(), + llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + t.Cleanup(func() { _ = s.Close() }) + return s +} + +func TestStreamTextDeltas(t *testing.T) { + body := sse( + `{"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"m","usage":{"input_tokens":10,"cache_creation_input_tokens":2,"cache_read_input_tokens":3,"output_tokens":1}}}`, + `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`, + `{"type":"ping"}`, + `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hel"}}`, + `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"lo"}}`, + `{"type":"content_block_stop","index":0}`, + `{"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}`, + `{"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":" world"}}`, + `{"type":"content_block_stop","index":1}`, + `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":12}}`, + `{"type":"message_stop"}`, + ) + var c capture + p := sseServer(t, &c, body) + s := openStream(t, p, "claude-test") + events := drain(t, s) + + if len(events) != 4 { + t.Fatalf("events = %d, want 4 (3 deltas + final response)", len(events)) + } + for i, want := range []string{"Hel", "lo", " world"} { + if events[i].TextDelta != want { + t.Errorf("event[%d].TextDelta = %q, want %q", i, events[i].TextDelta, want) + } + } + + final := events[3].Response + if final == nil { + t.Fatal("last event has no Response") + } + if len(final.Parts) != 2 { + t.Fatalf("final parts = %d, want 2 (one per text block)", len(final.Parts)) + } + if final.Text() != "Hello world" { + t.Errorf("final text = %q, want %q", final.Text(), "Hello world") + } + if final.FinishReason != llm.FinishStop { + t.Errorf("finish = %q, want stop", final.FinishReason) + } + // Input = 10+2+3 from message_start; output = 12 from message_delta. + if final.Usage.InputTokens != 15 || final.Usage.OutputTokens != 12 { + t.Errorf("usage = %+v, want {15 12}", final.Usage) + } + if final.Model != "anthropic/claude-test" { + t.Errorf("model = %q, want anthropic/claude-test", final.Model) + } + + // Past EOF, Next keeps returning io.EOF. + if _, err := s.Next(); err != io.EOF { + t.Errorf("Next after EOF = %v, want io.EOF", err) + } + + // The request must carry "stream": true. + if streamFlag := c.bodyMap(t)["stream"]; streamFlag != true { + t.Errorf("request stream = %v, want true", streamFlag) + } +} + +func TestStreamToolCallAssembly(t *testing.T) { + body := sse( + `{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":8,"output_tokens":1}}}`, + `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`, + `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Checking."}}`, + `{"type":"content_block_stop","index":0}`, + `{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_9","name":"get_weather","input":{}}}`, + `{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":""}}`, + `{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\":"}}`, + `{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" \"San Francisco, CA\"}"}}`, + `{"type":"content_block_stop","index":1}`, + `{"type":"content_block_start","index":2,"content_block":{"type":"tool_use","id":"toolu_10","name":"noop","input":{}}}`, + `{"type":"content_block_stop","index":2}`, + `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":21}}`, + `{"type":"message_stop"}`, + ) + var c capture + p := sseServer(t, &c, body) + events := drain(t, openStream(t, p, "claude-test")) + + if len(events) != 4 { + t.Fatalf("events = %d, want 4 (text, 2 tool calls, final)", len(events)) + } + if events[0].TextDelta != "Checking." { + t.Errorf("event[0] = %+v, want text delta", events[0]) + } + + call := events[1].ToolCall + if call == nil { + t.Fatal("event[1] has no ToolCall") + } + if call.ID != "toolu_9" || call.Name != "get_weather" { + t.Errorf("tool call = %+v", call) + } + var args map[string]any + if err := json.Unmarshal(call.Arguments, &args); err != nil { + t.Fatalf("assembled arguments invalid JSON: %v (%s)", err, call.Arguments) + } + if args["location"] != "San Francisco, CA" { + t.Errorf("arguments = %v", args) + } + + empty := events[2].ToolCall + if empty == nil || empty.ID != "toolu_10" { + t.Fatalf("event[2] = %+v, want second tool call", events[2]) + } + if string(empty.Arguments) != "{}" { + t.Errorf("empty tool call arguments = %s, want {}", empty.Arguments) + } + + final := events[3].Response + if final == nil { + t.Fatal("last event has no Response") + } + if len(final.ToolCalls) != 2 { + t.Errorf("final tool calls = %d, want 2", len(final.ToolCalls)) + } + if final.FinishReason != llm.FinishToolCalls { + t.Errorf("finish = %q, want tool_calls", final.FinishReason) + } + if final.Text() != "Checking." { + t.Errorf("final text = %q", final.Text()) + } + if final.Usage.InputTokens != 8 || final.Usage.OutputTokens != 21 { + t.Errorf("usage = %+v, want {8 21}", final.Usage) + } +} + +func TestStreamThinkingSkipped(t *testing.T) { + body := sse( + `{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`, + `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`, + `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":"hmm"}}`, + `{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":"sig"}}`, + `{"type":"content_block_stop","index":0}`, + `{"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}`, + `{"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":"hi"}}`, + `{"type":"content_block_stop","index":1}`, + `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":2}}`, + `{"type":"message_stop"}`, + ) + var c capture + p := sseServer(t, &c, body) + events := drain(t, openStream(t, p, "claude-test")) + + if len(events) != 2 { + t.Fatalf("events = %d, want 2 (thinking produces none)", len(events)) + } + if events[0].TextDelta != "hi" { + t.Errorf("event[0] = %+v, want TextDelta hi", events[0]) + } + final := events[1].Response + if final == nil || len(final.Parts) != 1 || final.Text() != "hi" { + t.Errorf("final = %+v, want single text part %q", final, "hi") + } +} + +func TestStreamMidStreamError(t *testing.T) { + body := sse( + `{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`, + `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`, + `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"par"}}`, + `{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`, + ) + var c capture + p := sseServer(t, &c, body) + s := openStream(t, p, "claude-test") + + ev, err := s.Next() + if err != nil || ev.TextDelta != "par" { + t.Fatalf("first Next = (%+v, %v), want text delta", ev, err) + } + _, err = s.Next() + if err == nil { + t.Fatal("second Next succeeded, want mid-stream error") + } + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("error %T (%v), want *llm.APIError", err, err) + } + if apiErr.Code != "overloaded_error" || apiErr.Message != "Overloaded" || apiErr.Status != 0 { + t.Errorf("apiErr = %+v", apiErr) + } + if llm.Classify(err) != llm.ClassTransient { + t.Error("overloaded_error must classify transient") + } +} + +func TestStreamHTTPErrorBeforeEvents(t *testing.T) { + var c capture + p := newTestProvider(t, c.handler(529, + `{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`)) + _, err := mustModel(t, p, "claude-test").Stream(context.Background(), + llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err == nil { + t.Fatal("Stream succeeded, want APIError before any events") + } + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("error %T (%v), want *llm.APIError", err, err) + } + if apiErr.Status != 529 || apiErr.Code != "overloaded_error" { + t.Errorf("apiErr = %+v, want 529 overloaded_error", apiErr) + } + if llm.Classify(err) != llm.ClassTransient { + t.Error("529 must classify transient") + } +} + +func TestStreamTruncatedBody(t *testing.T) { + // Stream ends without message_stop: Next must surface unexpected EOF. + body := sse( + `{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`, + `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`, + `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}`, + ) + var c capture + p := sseServer(t, &c, body) + s := openStream(t, p, "claude-test") + + if ev, err := s.Next(); err != nil || ev.TextDelta != "hi" { + t.Fatalf("first Next = (%+v, %v)", ev, err) + } + if _, err := s.Next(); !errors.Is(err, io.ErrUnexpectedEOF) { + t.Errorf("Next on truncated stream = %v, want io.ErrUnexpectedEOF", err) + } +} + +func TestStreamCloseIsSafe(t *testing.T) { + body := sse( + `{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`, + `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`, + `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}`, + `{"type":"content_block_stop","index":0}`, + `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":2}}`, + `{"type":"message_stop"}`, + ) + var c capture + p := sseServer(t, &c, body) + s := openStream(t, p, "claude-test") + + if err := s.Close(); err != nil { + t.Errorf("first Close: %v", err) + } + if err := s.Close(); err != nil { + t.Errorf("second Close: %v", err) + } + + // After EOF, Close is still fine. + s2 := openStream(t, p, "claude-test") + drain(t, s2) + if err := s2.Close(); err != nil { + t.Errorf("Close after EOF: %v", err) + } +} diff --git a/provider/anthropic/wire.go b/provider/anthropic/wire.go new file mode 100644 index 0000000..d80b5da --- /dev/null +++ b/provider/anthropic/wire.go @@ -0,0 +1,299 @@ +package anthropic + +import ( + "encoding/base64" + "encoding/json" + "strings" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// Wire types mirror the Messages API JSON shapes (June 2026 docs). Only the +// fields majordomo uses are modeled; unknown response fields are ignored by +// encoding/json. + +type wireRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System string `json:"system,omitempty"` + Messages []wireMessage `json:"messages"` + Stream bool `json:"stream,omitempty"` + Tools []wireTool `json:"tools,omitempty"` + ToolChoice *wireToolChoice `json:"tool_choice,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + OutputConfig *wireOutputConfig `json:"output_config,omitempty"` +} + +type wireMessage struct { + Role string `json:"role"` + Content []wireBlock `json:"content"` +} + +// wireBlock is a request-side content block. Exactly one shape is populated +// per block, keyed by Type: text, image, tool_use, or tool_result. +type wireBlock struct { + Type string `json:"type"` + + // text + Text string `json:"text,omitempty"` + + // image + Source *wireImageSource `json:"source,omitempty"` + + // tool_use + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input json.RawMessage `json:"input,omitempty"` + + // tool_result + ToolUseID string `json:"tool_use_id,omitempty"` + Content string `json:"content,omitempty"` + IsError bool `json:"is_error,omitempty"` +} + +type wireImageSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +type wireTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema"` +} + +type wireToolChoice struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` +} + +type wireOutputConfig struct { + Format *wireOutputFormat `json:"format,omitempty"` +} + +type wireOutputFormat struct { + Type string `json:"type"` + Schema json.RawMessage `json:"schema"` +} + +type wireResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Model string `json:"model"` + Content []wireRespBlock `json:"content"` + StopReason string `json:"stop_reason"` + Usage wireUsage `json:"usage"` +} + +type wireRespBlock struct { + Type string `json:"type"` + Text string `json:"text"` + ID string `json:"id"` + Name string `json:"name"` + Input json.RawMessage `json:"input"` +} + +type wireUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` +} + +// toUsage maps API token accounting onto the canonical Usage. Why the sum: +// the API's input_tokens counts only tokens after the last cache breakpoint; +// real total input is input + cache_creation + cache_read. +func (u wireUsage) toUsage() llm.Usage { + return llm.Usage{ + InputTokens: u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens, + OutputTokens: u.OutputTokens, + } +} + +type wireErrorEnvelope struct { + Type string `json:"type"` + Error struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` +} + +// buildWireRequest translates the canonical request into the Messages API +// shape. +// +// Request.ReasoningEffort is intentionally ignored: the current Messages API +// has no low/medium/high reasoning knob — thinking is adaptive on current +// models, and the legacy budget/disable parameters 400 on them. The llm +// contract says providers ignore ReasoningEffort where no mapping exists. +// +// Request.SchemaName is likewise ignored: output_config.format takes a bare +// schema with no name field. +func buildWireRequest(modelID string, req llm.Request, defaultMax int, stream bool) wireRequest { + maxTokens := req.MaxTokens + if maxTokens == 0 { + // max_tokens is required by the API; 0 means "provider default". + maxTokens = defaultMax + } + + wr := wireRequest{ + Model: modelID, + MaxTokens: maxTokens, + System: foldSystem(req), + Messages: toWireMessages(req.Messages), + Stream: stream, + Tools: toWireTools(req.Tools), + ToolChoice: toWireToolChoice(req.ToolChoice), + Temperature: req.Temperature, + TopP: req.TopP, + StopSequences: req.StopSequences, + } + if req.Schema != nil { + wr.OutputConfig = &wireOutputConfig{Format: &wireOutputFormat{ + Type: "json_schema", + Schema: req.Schema, + }} + } + return wr +} + +// foldSystem joins Request.System with the text of every RoleSystem message +// (System field first, original order, "\n\n" separators). Why: the API +// takes the system prompt as a top-level field and rejects system roles +// inside messages, so canonical RoleSystem messages must fold in here. +func foldSystem(req llm.Request) string { + parts := make([]string, 0, 2) + if req.System != "" { + parts = append(parts, req.System) + } + for _, msg := range req.Messages { + if msg.Role != llm.RoleSystem { + continue + } + if text := msg.Text(); text != "" { + parts = append(parts, text) + } + } + return strings.Join(parts, "\n\n") +} + +func toWireMessages(msgs []llm.Message) []wireMessage { + out := make([]wireMessage, 0, len(msgs)) + for _, msg := range msgs { + switch msg.Role { + case llm.RoleSystem: + // Folded into the top-level system field by foldSystem. + continue + + case llm.RoleTool: + // One user message carrying one tool_result block per result. + blocks := make([]wireBlock, 0, len(msg.ToolResults)) + for _, res := range msg.ToolResults { + blocks = append(blocks, wireBlock{ + Type: "tool_result", + ToolUseID: res.ID, + Content: res.Content, + IsError: res.IsError, + }) + } + out = append(out, wireMessage{Role: "user", Content: blocks}) + + case llm.RoleAssistant: + blocks := toWireBlocks(msg.Parts) + for _, call := range msg.ToolCalls { + args := call.Arguments + if len(args) == 0 { + // The API requires input to be a JSON object. + args = json.RawMessage("{}") + } + blocks = append(blocks, wireBlock{ + Type: "tool_use", + ID: call.ID, + Name: call.Name, + Input: args, + }) + } + out = append(out, wireMessage{Role: "assistant", Content: blocks}) + + default: // llm.RoleUser and anything unrecognized + out = append(out, wireMessage{Role: "user", Content: toWireBlocks(msg.Parts)}) + } + } + return out +} + +func toWireBlocks(parts []llm.Part) []wireBlock { + blocks := make([]wireBlock, 0, len(parts)) + for _, part := range parts { + switch p := part.(type) { + case llm.TextPart: + blocks = append(blocks, wireBlock{Type: "text", Text: p.Text}) + case llm.ImagePart: + blocks = append(blocks, wireBlock{Type: "image", Source: &wireImageSource{ + Type: "base64", + MediaType: p.MIME, + Data: base64.StdEncoding.EncodeToString(p.Data), + }}) + } + } + return blocks +} + +func toWireTools(tools []llm.Tool) []wireTool { + if len(tools) == 0 { + return nil + } + out := make([]wireTool, 0, len(tools)) + for _, t := range tools { + schema := t.Parameters + if len(schema) == 0 { + // Why: input_schema is required by the API; a tool with no + // arguments still needs an (empty) object schema. + schema = json.RawMessage(`{"type":"object","properties":{}}`) + } + out = append(out, wireTool{ + Name: t.Name, + Description: t.Description, + InputSchema: schema, + }) + } + return out +} + +// toWireToolChoice maps the canonical tool-choice policy. "" omits the field +// (API default is auto); any value other than the three keywords names the +// one tool the model must call. +func toWireToolChoice(choice string) *wireToolChoice { + switch choice { + case "": + return nil + case "auto": + return &wireToolChoice{Type: "auto"} + case "required": + return &wireToolChoice{Type: "any"} + case "none": + return &wireToolChoice{Type: "none"} + default: + return &wireToolChoice{Type: "tool", Name: choice} + } +} + +// mapStopReason maps the API stop_reason onto the canonical FinishReason. +func mapStopReason(stop string) llm.FinishReason { + switch stop { + case "end_turn", "stop_sequence": + return llm.FinishStop + case "max_tokens", "model_context_window_exceeded": + return llm.FinishLength + case "tool_use": + return llm.FinishToolCalls + case "refusal": + return llm.FinishContentFilter + default: + // pause_turn and any future provider-specific reasons. + return llm.FinishOther + } +} diff --git a/provider/ollama/ollama.go b/provider/ollama/ollama.go new file mode 100644 index 0000000..c78e29a --- /dev/null +++ b/provider/ollama/ollama.go @@ -0,0 +1,168 @@ +// Package ollama implements majordomo's provider contract over Ollama's +// native chat API (POST {base}/api/chat), targeted at three backends that +// share one wire protocol: +// +// - a local Ollama instance (preset Local: OLLAMA_HOST or +// http://localhost:11434, no auth), +// - Ollama Cloud (preset Cloud: https://ollama.com, bearer key from +// OLLAMA_API_KEY), and +// - foreman, Steve's native-Ollama queue daemon (preset Foreman: explicit +// base URL + bearer token). +// +// Wire surface verified against docs.ollama.com and ollama/ollama +// docs/api.md + api/types.go (June 2026): NDJSON streaming (stream defaults +// true server-side — Generate always sends stream:false explicitly); +// tool_calls carry arguments as a JSON OBJECT (not a string); tool results +// return as {"role":"tool","content",...,"tool_name"}; structured output +// via "format" (a full JSON-schema object); thinking via the bool-or-string +// "think" field; errors as {"error":"message"} with a non-2xx status. +// +// foreman deviation (verified in its source): sync /api/chat does not +// stream — a stream:true request yields ONE buffered application/json +// object. The NDJSON reader here handles that transparently (a single JSON +// line parses as the final chunk). +package ollama + +import ( + "fmt" + "net/http" + "os" + "strings" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// DefaultLocalBaseURL is the default base URL for a locally-running Ollama. +const DefaultLocalBaseURL = "http://localhost:11434" + +// DefaultCloudBaseURL is the base URL for Ollama Cloud. +const DefaultCloudBaseURL = "https://ollama.com" + +// defaultCapabilities is the conservative provider-wide default; individual +// models (e.g. high-resolution vision tags) override via llm.WithCapabilities. +var defaultCapabilities = llm.Capabilities{ + SupportsTools: true, + SupportsStructured: true, + SupportsStreaming: true, + MaxImagesPerReq: 8, + MaxImageBytes: 20 << 20, + MaxImageDimension: 2048, + AllowedImageMIME: []string{"image/jpeg", "image/png"}, +} + +// Provider is a native-Ollama chat client bound to one base URL. +type Provider struct { + name string + baseURL string + token string + client *http.Client + caps llm.Capabilities +} + +// Option configures the provider. +type Option func(*Provider) + +// WithName overrides the registry name (default "ollama"). +func WithName(name string) Option { return func(p *Provider) { p.name = name } } + +// WithBaseURL sets the backend base URL (scheme://host[:port][/path]). +func WithBaseURL(u string) Option { + return func(p *Provider) { p.baseURL = strings.TrimRight(u, "/") } +} + +// WithToken sets the bearer token (Ollama Cloud key / foreman token). +// Empty means no Authorization header (local mode). +func WithToken(token string) Option { return func(p *Provider) { p.token = token } } + +// WithHTTPClient overrides the HTTP client (proxies, test TLS, timeouts — +// note foreman sync chat long-polls; prefer context deadlines over client +// timeouts). +func WithHTTPClient(c *http.Client) Option { return func(p *Provider) { p.client = c } } + +// WithDefaultCapabilities overrides the provider-wide default capabilities. +func WithDefaultCapabilities(caps llm.Capabilities) Option { + return func(p *Provider) { p.caps = caps } +} + +// New creates a generic native-Ollama provider. Most callers want one of +// the presets (Local, Cloud, Foreman) or an LLM_* env DSN instead. +// Construction never fails; a missing base URL surfaces at request time. +func New(opts ...Option) *Provider { + p := &Provider{ + name: "ollama", + client: &http.Client{}, + caps: defaultCapabilities, + } + for _, opt := range opts { + opt(p) + } + return p +} + +// Local returns the local-Ollama preset: name "ollama", base URL from +// OLLAMA_HOST (normalized per Ollama conventions) or localhost:11434. +func Local(opts ...Option) *Provider { + base := DefaultLocalBaseURL + if h := os.Getenv("OLLAMA_HOST"); h != "" { + base = NormalizeHost(h) + } + return New(append([]Option{WithBaseURL(base)}, opts...)...) +} + +// Cloud returns the Ollama Cloud preset: name "ollama-cloud", +// https://ollama.com, bearer key from OLLAMA_API_KEY. +func Cloud(opts ...Option) *Provider { + return New(append([]Option{ + WithName("ollama-cloud"), + WithBaseURL(DefaultCloudBaseURL), + WithToken(os.Getenv("OLLAMA_API_KEY")), + }, opts...)...) +} + +// Foreman returns a foreman preset bound to the given daemon. +func Foreman(baseURL, token string, opts ...Option) *Provider { + return New(append([]Option{ + WithName("foreman"), + WithBaseURL(baseURL), + WithToken(token), + }, opts...)...) +} + +// NormalizeHost turns an OLLAMA_HOST-style value into a base URL: +// "host" → http://host:11434, "host:port" → http://host:port, full URLs +// pass through (trailing slash trimmed). +func NormalizeHost(h string) string { + h = strings.TrimRight(strings.TrimSpace(h), "/") + if strings.Contains(h, "://") { + return h + } + if !strings.Contains(h, ":") { + h += ":11434" + } + return "http://" + h +} + +// Name implements llm.Provider. +func (p *Provider) Name() string { return p.name } + +// BaseURL reports the configured backend base URL (diagnostics). +func (p *Provider) BaseURL() string { return p.baseURL } + +// Model implements llm.Provider; the id passes through verbatim. +func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) { + cfg := llm.ApplyModelOptions(opts) + caps := p.caps + if cfg.Capabilities != nil { + caps = *cfg.Capabilities + } + return &model{provider: p, id: id, caps: caps}, nil +} + +// checkReady reports a usable configuration (a base URL is the only hard +// requirement; auth problems surface as 401s from the backend). +func (p *Provider) checkReady() error { + if p.baseURL == "" { + return fmt.Errorf("ollama provider %q: no base URL configured (set one via the preset, WithBaseURL, or an LLM_* env DSN)", p.name) + } + return nil +} diff --git a/provider/ollama/ollama_test.go b/provider/ollama/ollama_test.go new file mode 100644 index 0000000..197adeb --- /dev/null +++ b/provider/ollama/ollama_test.go @@ -0,0 +1,492 @@ +package ollama + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// capture spins up an httptest server that records the request and replies +// with the given handler. +type captured struct { + auth string + contentType string + path string + body map[string]any + raw []byte +} + +func serve(t *testing.T, status int, respond func(w http.ResponseWriter)) (*Provider, *captured) { + t.Helper() + cap := &captured{} + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cap.auth = r.Header.Get("Authorization") + cap.contentType = r.Header.Get("Content-Type") + cap.path = r.URL.Path + cap.raw, _ = io.ReadAll(r.Body) + _ = json.Unmarshal(cap.raw, &cap.body) + w.WriteHeader(status) + respond(w) + })) + t.Cleanup(ts.Close) + return New(WithBaseURL(ts.URL), WithToken("test-token")), cap +} + +func jsonReply(obj string) func(w http.ResponseWriter) { + return func(w http.ResponseWriter) { _, _ = io.WriteString(w, obj) } +} + +func basicRequest() llm.Request { + return llm.Request{Messages: []llm.Message{llm.UserText("hi")}} +} + +func TestGenerateRoundTrip(t *testing.T) { + p, cap := serve(t, 200, jsonReply(`{ + "model":"qwen3:30b", + "message":{"role":"assistant","content":"hello there"}, + "done":true,"done_reason":"stop", + "prompt_eval_count":12,"eval_count":7 + }`)) + + m, _ := p.Model("qwen3:30b") + temp := 0.2 + resp, err := m.Generate(context.Background(), llm.Request{ + System: "be terse", + Messages: []llm.Message{llm.SystemText("extra sys"), llm.UserText("hi")}, + Temperature: &temp, + MaxTokens: 64, + StopSequences: []string{"END"}, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + + // Wire assertions. + if cap.path != "/api/chat" { + t.Errorf("path = %q", cap.path) + } + if cap.auth != "Bearer test-token" { + t.Errorf("auth = %q", cap.auth) + } + if cap.body["model"] != "qwen3:30b" { + t.Errorf("model = %v", cap.body["model"]) + } + if stream, ok := cap.body["stream"].(bool); !ok || stream { + t.Errorf("stream must be explicit false, got %v", cap.body["stream"]) + } + msgs := cap.body["messages"].([]any) + first := msgs[0].(map[string]any) + if first["role"] != "system" || first["content"] != "be terse\n\nextra sys" { + t.Errorf("system fold = %v", first) + } + second := msgs[1].(map[string]any) + if second["role"] != "user" || second["content"] != "hi" { + t.Errorf("user msg = %v", second) + } + opts := cap.body["options"].(map[string]any) + if opts["temperature"] != 0.2 || opts["num_predict"] != float64(64) { + t.Errorf("options = %v", opts) + } + + // Response assertions. + if resp.Text() != "hello there" { + t.Errorf("text = %q", resp.Text()) + } + if resp.FinishReason != llm.FinishStop { + t.Errorf("finish = %v", resp.FinishReason) + } + if resp.Usage.InputTokens != 12 || resp.Usage.OutputTokens != 7 { + t.Errorf("usage = %+v", resp.Usage) + } + if resp.Model != "ollama/qwen3:30b" { + t.Errorf("resp.Model = %q", resp.Model) + } +} + +func TestImagesEncodeAsBase64(t *testing.T) { + p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"a cat"},"done":true,"done_reason":"stop"}`)) + imgBytes := []byte{0xFF, 0xD8, 0xFF, 0xE0, 1, 2, 3} + + m, _ := p.Model("llava") + _, err := m.Generate(context.Background(), llm.Request{ + Messages: []llm.Message{llm.UserParts(llm.Text("describe"), llm.Image("image/jpeg", imgBytes))}, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + msgs := cap.body["messages"].([]any) + user := msgs[0].(map[string]any) + images := user["images"].([]any) + if len(images) != 1 || images[0] != base64.StdEncoding.EncodeToString(imgBytes) { + t.Errorf("images = %v", images) + } + if strings.Contains(images[0].(string), "data:") { + t.Error("images must be raw base64 without data: prefix") + } +} + +func TestToolsAndToolCallRoundTrip(t *testing.T) { + p, cap := serve(t, 200, jsonReply(`{ + "message":{"role":"assistant","content":"","tool_calls":[ + {"function":{"index":0,"name":"get_weather","arguments":{"city":"Tokyo"}}} + ]}, + "done":true,"done_reason":"stop" + }`)) + + tool := llm.Tool{ + Name: "get_weather", Description: "weather", + Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}`), + } + m, _ := p.Model("qwen3") + resp, err := m.Generate(context.Background(), basicRequest(), llm.WithTools(tool)) + if err != nil { + t.Fatalf("Generate: %v", err) + } + + // Tools serialize with parameters as an object. + tools := cap.body["tools"].([]any) + fn := tools[0].(map[string]any)["function"].(map[string]any) + if fn["name"] != "get_weather" { + t.Errorf("tool fn = %v", fn) + } + if _, ok := fn["parameters"].(map[string]any); !ok { + t.Errorf("parameters must be an object, got %T", fn["parameters"]) + } + + // Tool call comes back with arguments as a JSON object → RawMessage. + if len(resp.ToolCalls) != 1 { + t.Fatalf("tool calls = %v", resp.ToolCalls) + } + tc := resp.ToolCalls[0] + if tc.Name != "get_weather" || tc.ID == "" { + t.Errorf("call = %+v (id must be synthesized)", tc) + } + var args struct { + City string `json:"city"` + } + if err := json.Unmarshal(tc.Arguments, &args); err != nil || args.City != "Tokyo" { + t.Errorf("arguments = %s (%v)", tc.Arguments, err) + } + if resp.FinishReason != llm.FinishToolCalls { + t.Errorf("finish = %v, want tool_calls", resp.FinishReason) + } +} + +func TestToolResultsAndHistoryToolCalls(t *testing.T) { + p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"21C"},"done":true,"done_reason":"stop"}`)) + + m, _ := p.Model("qwen3") + _, err := m.Generate(context.Background(), llm.Request{ + Messages: []llm.Message{ + llm.UserText("weather?"), + {Role: llm.RoleAssistant, ToolCalls: []llm.ToolCall{ + {ID: "call_0", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Tokyo"}`)}, + }}, + llm.ToolResultsMessage( + llm.ToolResult{ID: "call_0", Name: "get_weather", Content: `{"temp":21}`}, + llm.ToolResult{ID: "call_1", Name: "broken_tool", Content: "boom", IsError: true}, + ), + }, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + msgs := cap.body["messages"].([]any) + if len(msgs) != 4 { + t.Fatalf("messages = %d, want 4 (user, assistant, 2 tool results)", len(msgs)) + } + asst := msgs[1].(map[string]any) + calls := asst["tool_calls"].([]any) + args := calls[0].(map[string]any)["function"].(map[string]any)["arguments"] + if _, ok := args.(map[string]any); !ok { + t.Errorf("history tool-call arguments must be an object, got %T", args) + } + tr1 := msgs[2].(map[string]any) + if tr1["role"] != "tool" || tr1["tool_name"] != "get_weather" || tr1["content"] != `{"temp":21}` { + t.Errorf("tool result 1 = %v", tr1) + } + tr2 := msgs[3].(map[string]any) + if tr2["content"] != "ERROR: boom" { + t.Errorf("error result content = %v", tr2["content"]) + } +} + +func TestStructuredOutputFormat(t *testing.T) { + p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"{\"name\":\"Ada\"}"},"done":true,"done_reason":"stop"}`)) + schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"]}`) + + m, _ := p.Model("qwen3") + resp, err := m.Generate(context.Background(), basicRequest(), llm.WithSchema(schema, "person")) + if err != nil { + t.Fatalf("Generate: %v", err) + } + format, ok := cap.body["format"].(map[string]any) + if !ok || format["type"] != "object" { + t.Errorf("format = %v, want the schema object", cap.body["format"]) + } + if resp.Text() != `{"name":"Ada"}` { + t.Errorf("text = %q", resp.Text()) + } +} + +func TestThinkMapping(t *testing.T) { + p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"ok"},"done":true,"done_reason":"stop"}`)) + m, _ := p.Model("gpt-oss:120b") + _, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("high")) + if err != nil { + t.Fatalf("Generate: %v", err) + } + if cap.body["think"] != "high" { + t.Errorf("think = %v", cap.body["think"]) + } + + if _, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("max")); err == nil { + t.Error("invalid reasoning effort should error") + } +} + +func TestToolChoiceNoneDropsTools(t *testing.T) { + p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"ok"},"done":true,"done_reason":"stop"}`)) + m, _ := p.Model("qwen3") + _, err := m.Generate(context.Background(), basicRequest(), + llm.WithTools(llm.Tool{Name: "t"}), llm.WithToolChoice("none")) + if err != nil { + t.Fatalf("Generate: %v", err) + } + if _, present := cap.body["tools"]; present { + t.Error("tool_choice none must omit tools") + } +} + +func TestStreamingNDJSON(t *testing.T) { + p, _ := serve(t, 200, func(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/x-ndjson") + _, _ = io.WriteString(w, `{"message":{"role":"assistant","content":"Hel"},"done":false} +{"message":{"role":"assistant","content":"lo"},"done":false} +{"message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"ping","arguments":{}}}]},"done":false} +{"message":{"role":"assistant","content":""},"done":true,"done_reason":"stop","prompt_eval_count":5,"eval_count":9} +`) + }) + + m, _ := p.Model("qwen3") + s, err := m.Stream(context.Background(), basicRequest()) + if err != nil { + t.Fatalf("Stream: %v", err) + } + defer s.Close() + + var text strings.Builder + var toolCalls []llm.ToolCall + var final *llm.Response + for { + ev, err := s.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("Next: %v", err) + } + text.WriteString(ev.TextDelta) + if ev.ToolCall != nil { + toolCalls = append(toolCalls, *ev.ToolCall) + } + if ev.Response != nil { + final = ev.Response + } + } + if text.String() != "Hello" { + t.Errorf("text = %q", text.String()) + } + if len(toolCalls) != 1 || toolCalls[0].Name != "ping" { + t.Errorf("tool calls = %+v", toolCalls) + } + if final == nil { + t.Fatal("no final response event") + } + if final.Usage.InputTokens != 5 || final.Usage.OutputTokens != 9 { + t.Errorf("final usage = %+v", final.Usage) + } + if final.FinishReason != llm.FinishToolCalls { + t.Errorf("final finish = %v", final.FinishReason) + } + if final.Text() != "Hello" { + t.Errorf("final text = %q", final.Text()) + } +} + +// TestStreamingForemanSingleObject: foreman returns one buffered JSON +// object to a stream:true request; the stream must still deliver the text +// and a final response. +func TestStreamingForemanSingleObject(t *testing.T) { + p, cap := serve(t, 200, func(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"message":{"role":"assistant","content":"queued answer"},"done":true,"done_reason":"stop","prompt_eval_count":3,"eval_count":4}`) + }) + + m, _ := p.Model("qwen3:30b") + s, err := m.Stream(context.Background(), basicRequest()) + if err != nil { + t.Fatalf("Stream: %v", err) + } + defer s.Close() + + if stream, ok := cap.body["stream"].(bool); !ok || !stream { + t.Errorf("stream flag = %v, want true", cap.body["stream"]) + } + + var text strings.Builder + var final *llm.Response + for { + ev, err := s.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("Next: %v", err) + } + text.WriteString(ev.TextDelta) + if ev.Response != nil { + final = ev.Response + } + } + if text.String() != "queued answer" || final == nil || final.Usage.OutputTokens != 4 { + t.Errorf("text=%q final=%+v", text.String(), final) + } +} + +func TestErrorMapping(t *testing.T) { + t.Run("404 is model-not-found", func(t *testing.T) { + p, _ := serve(t, 404, jsonReply(`{"error":"model not found"}`)) + m, _ := p.Model("nope") + _, err := m.Generate(context.Background(), basicRequest()) + if !errors.Is(err, llm.ErrModelNotFound) { + t.Errorf("error = %v, want ErrModelNotFound", err) + } + if llm.Classify(err) != llm.ClassPermanent { + t.Error("404 must classify permanent") + } + }) + + t.Run("503 transient with message", func(t *testing.T) { + p, _ := serve(t, 503, jsonReply(`{"error":"request cancelled while waiting"}`)) + m, _ := p.Model("qwen3") + _, err := m.Generate(context.Background(), basicRequest()) + var apiErr *llm.APIError + if !errors.As(err, &apiErr) || apiErr.Status != 503 || !strings.Contains(apiErr.Message, "cancelled") { + t.Errorf("error = %v", err) + } + if llm.Classify(err) != llm.ClassTransient { + t.Error("503 must classify transient") + } + }) + + t.Run("non-JSON error body", func(t *testing.T) { + p, _ := serve(t, 500, jsonReply(`upstream exploded`)) + m, _ := p.Model("qwen3") + _, err := m.Generate(context.Background(), basicRequest()) + var apiErr *llm.APIError + if !errors.As(err, &apiErr) || !strings.Contains(apiErr.Message, "upstream exploded") { + t.Errorf("error = %v", err) + } + }) +} + +func TestCapabilityEnforcement(t *testing.T) { + p, _ := serve(t, 200, jsonReply(`{"message":{"content":"x"},"done":true}`)) + + t.Run("too many images", func(t *testing.T) { + m, _ := p.Model("llava", llm.WithCapabilities(llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/png"}})) + _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{ + llm.UserParts(llm.Image("image/png", []byte{1}), llm.Image("image/png", []byte{2})), + }}) + if !errors.Is(err, llm.ErrUnsupported) { + t.Errorf("error = %v, want ErrUnsupported", err) + } + }) + + t.Run("images on text-only model", func(t *testing.T) { + m, _ := p.Model("qwen3", llm.WithCapabilities(llm.Capabilities{})) + _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{ + llm.UserParts(llm.Image("image/png", []byte{1})), + }}) + if !errors.Is(err, llm.ErrUnsupported) { + t.Errorf("error = %v, want ErrUnsupported", err) + } + }) + + t.Run("disallowed mime", func(t *testing.T) { + m, _ := p.Model("llava") // default caps: jpeg/png only + _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{ + llm.UserParts(llm.Image("image/tiff", []byte{1})), + }}) + if !errors.Is(err, llm.ErrUnsupported) { + t.Errorf("error = %v, want ErrUnsupported", err) + } + }) +} + +func TestNoBaseURL(t *testing.T) { + p := New(WithBaseURL("")) + m, _ := p.Model("x") + if _, err := m.Generate(context.Background(), basicRequest()); err == nil || + !strings.Contains(err.Error(), "no base URL") { + t.Errorf("error = %v, want a clear no-base-URL message", err) + } +} + +func TestNormalizeHost(t *testing.T) { + for in, want := range map[string]string{ + "myhost": "http://myhost:11434", + "myhost:8080": "http://myhost:8080", + "http://myhost:8080/": "http://myhost:8080", + "https://ollama.com": "https://ollama.com", + " 127.0.0.1:11434 ": "http://127.0.0.1:11434", + } { + if got := NormalizeHost(in); got != want { + t.Errorf("NormalizeHost(%q) = %q, want %q", in, got, want) + } + } +} + +func TestPresets(t *testing.T) { + t.Run("cloud", func(t *testing.T) { + t.Setenv("OLLAMA_API_KEY", "cloud-key") + p := Cloud() + if p.Name() != "ollama-cloud" || p.baseURL != DefaultCloudBaseURL || p.token != "cloud-key" { + t.Errorf("cloud preset = %+v", p) + } + }) + t.Run("local respects OLLAMA_HOST", func(t *testing.T) { + t.Setenv("OLLAMA_HOST", "box.lan:9999") + p := Local() + if p.Name() != "ollama" || p.baseURL != "http://box.lan:9999" || p.token != "" { + t.Errorf("local preset = %+v", p) + } + }) + t.Run("foreman", func(t *testing.T) { + p := Foreman("http://foreman-m1:8080", "tok") + if p.Name() != "foreman" || p.baseURL != "http://foreman-m1:8080" || p.token != "tok" { + t.Errorf("foreman preset = %+v", p) + } + }) +} + +func TestLocalNoAuthHeader(t *testing.T) { + p, cap := serve(t, 200, jsonReply(`{"message":{"content":"x"},"done":true}`)) + p.token = "" // simulate local mode on the test server + m, _ := p.Model("llama3") + if _, err := m.Generate(context.Background(), basicRequest()); err != nil { + t.Fatalf("Generate: %v", err) + } + if cap.auth != "" { + t.Errorf("auth header = %q, want none in local mode", cap.auth) + } +} diff --git a/provider/ollama/stream.go b/provider/ollama/stream.go new file mode 100644 index 0000000..944c8db --- /dev/null +++ b/provider/ollama/stream.go @@ -0,0 +1,140 @@ +package ollama + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "strconv" + "sync" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// Stream implements llm.Model over Ollama's NDJSON streaming. It also +// transparently handles foreman's non-streaming degradation (a single +// buffered JSON object): one JSON line parses as the final chunk. +func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) { + req = req.Apply(opts...) + if err := m.enforceCapabilities(req); err != nil { + return nil, err + } + wireReq, err := m.buildRequest(req, true) + if err != nil { + return nil, err + } + resp, err := m.do(ctx, wireReq) + if err != nil { + return nil, err + } + + sc := bufio.NewScanner(resp.Body) + // Single NDJSON lines can far exceed the 64KB default (thinking dumps, + // tool payloads, foreman's whole-response-as-one-line degradation). + sc.Buffer(make([]byte, 64<<10), 16<<20) + + return &stream{model: m, body: resp.Body, scanner: sc}, nil +} + +type stream struct { + model *model + body io.Closer + scanner *bufio.Scanner + + mu sync.Mutex + closed bool + finished bool + toolCalls []llm.ToolCall + text []byte + pending []llm.StreamEvent + usage llm.Usage + doneReason string +} + +func (s *stream) Next() (llm.StreamEvent, error) { + s.mu.Lock() + defer s.mu.Unlock() + + for { + if len(s.pending) > 0 { + ev := s.pending[0] + s.pending = s.pending[1:] + return ev, nil + } + if s.finished { + return llm.StreamEvent{}, io.EOF + } + if !s.scanner.Scan() { + if err := s.scanner.Err(); err != nil { + return llm.StreamEvent{}, fmt.Errorf("ollama %s: read stream: %w", s.model.qualified(), err) + } + // EOF without a done chunk: synthesize the final response from + // what we accumulated rather than losing it. + s.queueFinal() + continue + } + line := s.scanner.Bytes() + if len(line) == 0 { + continue + } + + var chunk chatResponse + if err := json.Unmarshal(line, &chunk); err != nil { + return llm.StreamEvent{}, fmt.Errorf("ollama %s: decode stream chunk: %w", s.model.qualified(), err) + } + + if chunk.Message.Content != "" { + s.text = append(s.text, chunk.Message.Content...) + s.pending = append(s.pending, llm.StreamEvent{TextDelta: chunk.Message.Content}) + } + // Tool calls arrive complete per chunk (no partial-argument deltas + // in the native protocol). + base := len(s.toolCalls) + for i, tc := range chunk.Message.ToolCalls { + id := tc.ID + if id == "" { + id = "call_" + strconv.Itoa(base+i) + } + args := tc.Function.Arguments + if len(args) == 0 { + args = json.RawMessage("{}") + } + call := llm.ToolCall{ID: id, Name: tc.Function.Name, Arguments: args} + s.toolCalls = append(s.toolCalls, call) + s.pending = append(s.pending, llm.StreamEvent{ToolCall: &s.toolCalls[len(s.toolCalls)-1]}) + } + if chunk.Done { + s.usage = llm.Usage{InputTokens: chunk.PromptEvalCount, OutputTokens: chunk.EvalCount} + s.doneReason = chunk.DoneReason + s.queueFinal() + } + } +} + +// queueFinal appends the final Response event and marks the stream done. +func (s *stream) queueFinal() { + resp := &llm.Response{ + Model: s.model.qualified(), + Usage: s.usage, + FinishReason: finishReason(s.doneReason, len(s.toolCalls) > 0), + } + if len(s.text) > 0 { + resp.Parts = append(resp.Parts, llm.Text(string(s.text))) + } + if len(s.toolCalls) > 0 { + resp.ToolCalls = append([]llm.ToolCall(nil), s.toolCalls...) + } + s.pending = append(s.pending, llm.StreamEvent{Response: resp}) + s.finished = true +} + +func (s *stream) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return nil + } + s.closed = true + return s.body.Close() +} diff --git a/provider/ollama/wire.go b/provider/ollama/wire.go new file mode 100644 index 0000000..72ad8f0 --- /dev/null +++ b/provider/ollama/wire.go @@ -0,0 +1,343 @@ +package ollama + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// ---- wire types (field names per ollama api/types.go) ---- + +type chatRequest struct { + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + Tools []toolDef `json:"tools,omitempty"` + Format json.RawMessage `json:"format,omitempty"` + Options map[string]any `json:"options,omitempty"` + // Stream has no omitempty on purpose: the server default is true, so + // Generate must send an explicit false. + Stream bool `json:"stream"` + // Think is bool-or-string on the wire ("low"/"medium"/"high" or a bool). + Think json.RawMessage `json:"think,omitempty"` +} + +type chatMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Images []string `json:"images,omitempty"` // raw base64, no data: prefix + ToolCalls []toolCall `json:"tool_calls,omitempty"` + ToolName string `json:"tool_name,omitempty"` // on role:"tool" results +} + +type toolDef struct { + Type string `json:"type"` + Function toolDefFunc `json:"function"` +} + +type toolDefFunc struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` +} + +type toolCall struct { + ID string `json:"id,omitempty"` + Function toolCallFunc `json:"function"` +} + +type toolCallFunc struct { + Index int `json:"index,omitempty"` + Name string `json:"name"` + // Arguments is a JSON OBJECT on the wire (unlike OpenAI's string). + Arguments json.RawMessage `json:"arguments"` +} + +type chatResponse struct { + Model string `json:"model"` + Message respMessage `json:"message"` + Done bool `json:"done"` + DoneReason string `json:"done_reason"` + PromptEvalCount int `json:"prompt_eval_count"` + EvalCount int `json:"eval_count"` +} + +type respMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Thinking string `json:"thinking"` + ToolCalls []toolCall `json:"tool_calls"` +} + +type errorBody struct { + Error string `json:"error"` +} + +// ---- model ---- + +type model struct { + provider *Provider + id string + caps llm.Capabilities +} + +func (m *model) Capabilities() llm.Capabilities { return m.caps } + +func (m *model) qualified() string { return m.provider.name + "/" + m.id } + +// enforceCapabilities is the backstop check (the media layer normalizes +// before requests get here; see ADR-0009). +func (m *model) enforceCapabilities(req llm.Request) error { + count := 0 + for _, msg := range req.Messages { + for _, part := range msg.Parts { + img, ok := part.(llm.ImagePart) + if !ok { + continue + } + count++ + if !m.caps.SupportsImages() { + return fmt.Errorf("%w: %s does not accept image input", llm.ErrUnsupported, m.qualified()) + } + if !m.caps.MIMEAllowed(img.MIME) { + return fmt.Errorf("%w: %s does not accept %s images", llm.ErrUnsupported, m.qualified(), img.MIME) + } + if m.caps.MaxImageBytes > 0 && len(img.Data) > m.caps.MaxImageBytes { + return fmt.Errorf("%w: image of %d bytes exceeds %s limit of %d", + llm.ErrUnsupported, len(img.Data), m.qualified(), m.caps.MaxImageBytes) + } + } + } + if count > 0 && m.caps.MaxImagesPerReq > 0 && count > m.caps.MaxImagesPerReq { + return fmt.Errorf("%w: %d images exceed %s limit of %d", + llm.ErrUnsupported, count, m.qualified(), m.caps.MaxImagesPerReq) + } + return nil +} + +// buildRequest maps the canonical request onto the wire shape. +func (m *model) buildRequest(req llm.Request, stream bool) (*chatRequest, error) { + out := &chatRequest{Model: m.id, Stream: stream} + + // System prompt: dedicated field first, then folded RoleSystem messages. + var sys []string + if req.System != "" { + sys = append(sys, req.System) + } + for _, msg := range req.Messages { + if msg.Role == llm.RoleSystem { + if t := msg.Text(); t != "" { + sys = append(sys, t) + } + } + } + if len(sys) > 0 { + out.Messages = append(out.Messages, chatMessage{ + Role: "system", Content: strings.Join(sys, "\n\n"), + }) + } + + for _, msg := range req.Messages { + switch msg.Role { + case llm.RoleSystem: + // Already folded above. + case llm.RoleTool: + for _, res := range msg.ToolResults { + content := res.Content + if res.IsError { + content = "ERROR: " + content + } + out.Messages = append(out.Messages, chatMessage{ + Role: "tool", Content: content, ToolName: res.Name, + }) + } + default: + cm := chatMessage{Role: string(msg.Role), Content: msg.Text()} + for _, part := range msg.Parts { + if img, ok := part.(llm.ImagePart); ok { + cm.Images = append(cm.Images, base64.StdEncoding.EncodeToString(img.Data)) + } + } + for _, tc := range msg.ToolCalls { + args := tc.Arguments + if len(args) == 0 { + args = json.RawMessage("{}") + } + cm.ToolCalls = append(cm.ToolCalls, toolCall{ + ID: tc.ID, + Function: toolCallFunc{Name: tc.Name, Arguments: args}, + }) + } + out.Messages = append(out.Messages, cm) + } + } + + // Tools. Ollama has no tool_choice: "none" maps to omitting the tools; + // "required"/named choices have no wire equivalent and are best-effort + // ignored (documented in the README support matrix). + if req.ToolChoice != "none" { + for _, t := range req.Tools { + params := t.Parameters + if len(params) == 0 { + params = json.RawMessage(`{"type":"object","properties":{}}`) + } + out.Tools = append(out.Tools, toolDef{ + Type: "function", + Function: toolDefFunc{Name: t.Name, Description: t.Description, Parameters: params}, + }) + } + } + + if len(req.Schema) > 0 { + out.Format = req.Schema + } + + opts := make(map[string]any) + if req.Temperature != nil { + opts["temperature"] = *req.Temperature + } + if req.TopP != nil { + opts["top_p"] = *req.TopP + } + if req.MaxTokens > 0 { + opts["num_predict"] = req.MaxTokens + } + if len(req.StopSequences) > 0 { + opts["stop"] = req.StopSequences + } + if len(opts) > 0 { + out.Options = opts + } + + switch req.ReasoningEffort { + case "": + case "low", "medium", "high": + out.Think = json.RawMessage(strconv.Quote(req.ReasoningEffort)) + default: + return nil, fmt.Errorf("ollama: invalid reasoning effort %q (want low/medium/high)", req.ReasoningEffort) + } + + return out, nil +} + +// do POSTs /api/chat and returns the response body on 2xx, or a classified +// error. +func (m *model) do(ctx context.Context, wireReq *chatRequest) (*http.Response, error) { + p := m.provider + if err := p.checkReady(); err != nil { + return nil, err + } + body, err := json.Marshal(wireReq) + if err != nil { + return nil, fmt.Errorf("ollama: encode request: %w", err) + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/chat", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("ollama: build request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + if p.token != "" { + httpReq.Header.Set("Authorization", "Bearer "+p.token) + } + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("ollama %s: do request: %w", m.qualified(), err) + } + if resp.StatusCode/100 != 2 { + defer resp.Body.Close() + raw, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10)) + var eb errorBody + _ = json.Unmarshal(raw, &eb) + msg := eb.Error + if msg == "" { + msg = strings.TrimSpace(string(raw)) + } + return nil, &llm.APIError{ + Provider: p.name, Model: m.id, + Status: resp.StatusCode, Message: msg, + } + } + return resp, nil +} + +// Generate implements llm.Model. +func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) { + req = req.Apply(opts...) + if err := m.enforceCapabilities(req); err != nil { + return nil, err + } + wireReq, err := m.buildRequest(req, false) + if err != nil { + return nil, err + } + resp, err := m.do(ctx, wireReq) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var cr chatResponse + if err := json.NewDecoder(resp.Body).Decode(&cr); err != nil { + return nil, fmt.Errorf("ollama %s: decode response: %w", m.qualified(), err) + } + return m.toResponse(&cr), nil +} + +// toResponse converts a final wire chunk into the canonical response. +func (m *model) toResponse(cr *chatResponse) *llm.Response { + out := &llm.Response{ + Model: m.qualified(), + Usage: llm.Usage{InputTokens: cr.PromptEvalCount, OutputTokens: cr.EvalCount}, + Raw: cr, + } + if cr.Message.Content != "" { + out.Parts = append(out.Parts, llm.Text(cr.Message.Content)) + } + out.ToolCalls = convertToolCalls(cr.Message.ToolCalls) + out.FinishReason = finishReason(cr.DoneReason, len(out.ToolCalls) > 0) + return out +} + +// convertToolCalls maps wire tool calls, synthesizing ids where the model +// omitted them (ids are optional in Ollama's shape but required by our +// agent loop to match results to calls). +func convertToolCalls(calls []toolCall) []llm.ToolCall { + out := make([]llm.ToolCall, 0, len(calls)) + for i, tc := range calls { + id := tc.ID + if id == "" { + id = "call_" + strconv.Itoa(i) + } + args := tc.Function.Arguments + if len(args) == 0 { + args = json.RawMessage("{}") + } + out = append(out, llm.ToolCall{ID: id, Name: tc.Function.Name, Arguments: args}) + } + if len(out) == 0 { + return nil + } + return out +} + +func finishReason(doneReason string, hasToolCalls bool) llm.FinishReason { + if hasToolCalls { + return llm.FinishToolCalls + } + switch doneReason { + case "stop", "": + return llm.FinishStop + case "length": + return llm.FinishLength + default: + return llm.FinishOther + } +} diff --git a/provider/openai/model.go b/provider/openai/model.go new file mode 100644 index 0000000..2744a25 --- /dev/null +++ b/provider/openai/model.go @@ -0,0 +1,222 @@ +package openai + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// model is one provider-bound target. +type model struct { + p *Provider + id string + caps llm.Capabilities +} + +// Capabilities implements llm.Model. +func (m *model) Capabilities() llm.Capabilities { return m.caps } + +// Generate implements llm.Model. +func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) { + req = req.Apply(opts...) + if err := checkRequest(m.caps, req); err != nil { + return nil, err + } + httpResp, err := m.do(ctx, req, false) + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + if httpResp.StatusCode/100 != 2 { + return nil, m.apiError(httpResp) + } + var wire chatResponse + if err := json.NewDecoder(httpResp.Body).Decode(&wire); err != nil { + return nil, fmt.Errorf("openai: decode response: %w", err) + } + return m.toResponse(&wire), nil +} + +// Stream implements llm.Model. +func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) { + req = req.Apply(opts...) + if !m.caps.SupportsStreaming { + return nil, fmt.Errorf("%w: streaming not supported by %s/%s", llm.ErrUnsupported, m.p.name, m.id) + } + if err := checkRequest(m.caps, req); err != nil { + return nil, err + } + httpResp, err := m.do(ctx, req, true) + if err != nil { + return nil, err + } + if httpResp.StatusCode/100 != 2 { + defer httpResp.Body.Close() + return nil, m.apiError(httpResp) + } + sc := bufio.NewScanner(httpResp.Body) + // Why: a single SSE data line carries a whole JSON chunk; tool-call + // argument fragments can make lines far larger than Scanner's 64 KiB + // default cap. + sc.Buffer(make([]byte, 0, 64*1024), 16<<20) + return &stream{m: m, body: httpResp.Body, sc: sc}, nil +} + +// do builds and performs the HTTP request. Transport failures are wrapped +// raw (never as *llm.APIError) so llm.Classify still sees net.Error, +// syscall errnos, and context errors underneath. +func (m *model) do(ctx context.Context, req llm.Request, stream bool) (*http.Response, error) { + if m.p.apiKey == "" { + // Why a synthetic 401: the constructor never fails, so a missing + // key must surface at request time as the auth failure it is — + // permanent under llm.Classify, like a real 401. + return nil, &llm.APIError{ + Provider: m.p.name, + Model: m.id, + Status: http.StatusUnauthorized, + Code: "missing_api_key", + Message: "no API key configured: set OPENAI_API_KEY or use WithAPIKey", + } + } + body, err := json.Marshal(m.buildRequest(req, stream)) + if err != nil { + return nil, fmt.Errorf("openai: encode request: %w", err) + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, m.p.baseURL+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("openai: build request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+m.p.apiKey) + if stream { + httpReq.Header.Set("Accept", "text/event-stream") + } + httpResp, err := m.p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("openai: do request: %w", err) + } + return httpResp, nil +} + +// apiError converts a non-2xx response into *llm.APIError, pulling code and +// message from the {"error":{...}} body when it parses. +func (m *model) apiError(httpResp *http.Response) error { + apiErr := &llm.APIError{Provider: m.p.name, Model: m.id, Status: httpResp.StatusCode} + body, _ := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) + var env errorEnvelope + if err := json.Unmarshal(body, &env); err == nil && + (env.Error.Message != "" || env.Error.Type != "" || env.Error.Code != "") { + apiErr.Message = env.Error.Message + apiErr.Code = env.Error.Code + if apiErr.Code == "" { + apiErr.Code = env.Error.Type + } + } else { + // Why: compat servers emit all sorts of error bodies; a raw snippet + // beats silence when the canonical envelope is absent. + apiErr.Message = strings.TrimSpace(string(body)) + } + return apiErr +} + +// toResponse maps the wire response onto the canonical llm.Response. +func (m *model) toResponse(wire *chatResponse) *llm.Response { + resp := &llm.Response{Model: m.p.name + "/" + m.id, Raw: wire} + if wire.Usage != nil { + resp.Usage = llm.Usage{ + InputTokens: wire.Usage.PromptTokens, + OutputTokens: wire.Usage.CompletionTokens, + } + } + if len(wire.Choices) == 0 { + resp.FinishReason = llm.FinishOther + return resp + } + choice := wire.Choices[0] + if choice.Message.Content != "" { + resp.Parts = append(resp.Parts, llm.TextPart{Text: choice.Message.Content}) + } + for i, tc := range choice.Message.ToolCalls { + id := tc.ID + if id == "" { + // Why: ToolResult.ID must echo ToolCall.ID, so calls from compat + // servers that omit ids get synthesized ones. + id = fmt.Sprintf("call_%d", i) + } + resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{ + ID: id, + Name: tc.Function.Name, + Arguments: json.RawMessage(tc.Function.Arguments), + }) + } + resp.FinishReason = mapFinish(choice.FinishReason, len(resp.ToolCalls) > 0) + return resp +} + +// mapFinish maps a wire finish_reason to the canonical enum. Tool-call +// presence wins over the reported reason: a forced (named tool_choice) call +// can finish with "stop" while still carrying tool_calls. +func mapFinish(reason string, hasToolCalls bool) llm.FinishReason { + if hasToolCalls { + return llm.FinishToolCalls + } + switch reason { + case "stop": + return llm.FinishStop + case "length": + return llm.FinishLength + case "tool_calls": + return llm.FinishToolCalls + case "content_filter": + return llm.FinishContentFilter + default: + return llm.FinishOther + } +} + +// checkRequest enforces the model's effective capabilities. Why enforcement +// rather than normalization: a separate media layer resizes/transcodes +// images BEFORE requests reach the provider; this check is the honest +// backstop that refuses, with llm.ErrUnsupported, what the target +// declaredly cannot serve (chains advance past it penalty-free). +func checkRequest(caps llm.Capabilities, req llm.Request) error { + if len(req.Tools) > 0 && !caps.SupportsTools { + return fmt.Errorf("%w: tools not supported", llm.ErrUnsupported) + } + if len(req.Schema) > 0 && !caps.SupportsStructured { + return fmt.Errorf("%w: structured output not supported", llm.ErrUnsupported) + } + images := 0 + for _, msg := range req.Messages { + for _, part := range msg.Parts { + img, ok := part.(llm.ImagePart) + if !ok { + continue + } + images++ + if !caps.SupportsImages() { + return fmt.Errorf("%w: image input not supported", llm.ErrUnsupported) + } + if !caps.MIMEAllowed(img.MIME) { + return fmt.Errorf("%w: image MIME type %q not allowed (allowed: %s)", + llm.ErrUnsupported, img.MIME, strings.Join(caps.AllowedImageMIME, ", ")) + } + if caps.MaxImageBytes > 0 && len(img.Data) > caps.MaxImageBytes { + return fmt.Errorf("%w: image is %d bytes, limit is %d", + llm.ErrUnsupported, len(img.Data), caps.MaxImageBytes) + } + } + } + if images > caps.MaxImagesPerReq { + return fmt.Errorf("%w: request carries %d images, limit is %d", + llm.ErrUnsupported, images, caps.MaxImagesPerReq) + } + return nil +} diff --git a/provider/openai/openai.go b/provider/openai/openai.go new file mode 100644 index 0000000..f8956dd --- /dev/null +++ b/provider/openai/openai.go @@ -0,0 +1,133 @@ +// Package openai implements llm.Provider for the OpenAI Chat Completions +// API and, via WithBaseURL/WithName, any OpenAI-compatible endpoint +// (vLLM, Groq, Together, LM Studio, Ollama's /v1 shim, ...). +// +// Targeted API surface (verified against developers.openai.com, June 2026): +// POST {base}/chat/completions with +// - messages: plain-string content for text-only turns, part arrays with +// base64 data-URL image_url entries for multimodal turns, assistant +// tool_calls history, and {"role":"tool","tool_call_id",...} results; +// - tools as {"type":"function","function":{...}} with tool_choice +// "auto"/"none"/"required" or a named-function object; +// - response_format {"type":"json_schema",...} structured output; +// - max_completion_tokens (or legacy max_tokens via WithLegacyMaxTokens +// for compat servers), temperature, top_p, stop, reasoning_effort; +// - data-only SSE streaming with stream_options.include_usage, the +// "data: [DONE]" sentinel, and tool-call deltas accumulated by index. +// +// Newer response fields (refusal, annotations, usage *_details, delta +// obfuscation) are tolerated and ignored so both api.openai.com and older +// compat servers decode cleanly. +package openai + +import ( + "net/http" + "os" + "strings" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +const defaultBaseURL = "https://api.openai.com/v1" + +// Provider is an llm.Provider backed by an OpenAI Chat Completions endpoint. +type Provider struct { + name string + apiKey string + baseURL string + client *http.Client + caps llm.Capabilities + legacyMaxTokens bool +} + +// Option configures the provider at construction. +type Option func(*Provider) + +// WithAPIKey sets the API key. When absent, New reads OPENAI_API_KEY from +// the environment at construction time. +func WithAPIKey(key string) Option { + return func(p *Provider) { p.apiKey = key } +} + +// WithBaseURL points the client at a different endpoint (compat servers). +// The path "/chat/completions" is appended; a trailing slash is trimmed. +func WithBaseURL(u string) Option { + return func(p *Provider) { p.baseURL = u } +} + +// WithHTTPClient substitutes the HTTP client (timeouts, proxies, tests). +func WithHTTPClient(c *http.Client) Option { + return func(p *Provider) { + if c != nil { + p.client = c + } + } +} + +// WithName overrides the registry name ("openai" by default). Why: the same +// client serves many OpenAI-compatible endpoints, and each needs a distinct +// name in "provider/model" specs and error reporting. +func WithName(name string) Option { + return func(p *Provider) { p.name = name } +} + +// WithDefaultCapabilities replaces the provider-default capabilities. +// Per-model overrides via llm.WithCapabilities still take precedence. +func WithDefaultCapabilities(caps llm.Capabilities) Option { + return func(p *Provider) { p.caps = caps } +} + +// WithLegacyMaxTokens sends Request.MaxTokens as "max_tokens" instead of +// "max_completion_tokens". Why: OpenAI deprecated max_tokens, but many +// third-party compat servers still only honor the legacy field. +func WithLegacyMaxTokens() Option { + return func(p *Provider) { p.legacyMaxTokens = true } +} + +// defaultCapabilities reflects OpenAI's current vision-capable chat models. +// Why these limits: the published per-request caps (1500 images, 512 MB) +// are far beyond what compat servers accept; 100 images / 20 MB each is a +// conservative envelope, and the MIME list is the documented set (PNG, +// JPEG, WEBP, non-animated GIF). +func defaultCapabilities() llm.Capabilities { + return llm.Capabilities{ + SupportsTools: true, + SupportsStructured: true, + SupportsStreaming: true, + MaxImagesPerReq: 100, + MaxImageBytes: 20 << 20, + AllowedImageMIME: []string{"image/jpeg", "image/png", "image/webp", "image/gif"}, + } +} + +// New creates a Provider. It never fails: a missing API key surfaces as a +// 401-style *llm.APIError at request time, not at construction. +func New(opts ...Option) *Provider { + p := &Provider{ + name: "openai", + apiKey: os.Getenv("OPENAI_API_KEY"), + baseURL: defaultBaseURL, + client: http.DefaultClient, + caps: defaultCapabilities(), + } + for _, opt := range opts { + opt(p) + } + p.baseURL = strings.TrimRight(p.baseURL, "/") + return p +} + +// Name implements llm.Provider. +func (p *Provider) Name() string { return p.name } + +// Model implements llm.Provider. The id is passed through verbatim — no +// catalog validation; unknown models fail at request time with the +// backend's own error. +func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) { + cfg := llm.ApplyModelOptions(opts) + caps := p.caps + if cfg.Capabilities != nil { + caps = *cfg.Capabilities + } + return &model{p: p, id: id, caps: caps}, nil +} diff --git a/provider/openai/openai_test.go b/provider/openai/openai_test.go new file mode 100644 index 0000000..09843a0 --- /dev/null +++ b/provider/openai/openai_test.go @@ -0,0 +1,614 @@ +package openai + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +var ( + _ llm.Provider = (*Provider)(nil) + _ llm.Model = (*model)(nil) + _ llm.Stream = (*stream)(nil) +) + +const textResponse = `{ + "id": "chatcmpl-1", "object": "chat.completion", "created": 1741570283, "model": "gpt-test", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "hello", "refusal": null, "annotations": []}, + "logprobs": null, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 19, "completion_tokens": 10, "total_tokens": 29, + "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, + "completion_tokens_details": {"reasoning_tokens": 0} + }, + "service_tier": "default", "system_fingerprint": "fp_x" +}` + +// recorded captures the last request a test server received. +type recorded struct { + body map[string]any + header http.Header + path string + hits int +} + +// newServer starts a test server that records the request and replies with +// a fixed status and body. +func newServer(t *testing.T, status int, respBody string) (*httptest.Server, *recorded) { + t.Helper() + rec := &recorded{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rec.hits++ + rec.header = r.Header.Clone() + rec.path = r.URL.Path + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read request body: %v", err) + } + if len(raw) > 0 { + if err := json.Unmarshal(raw, &rec.body); err != nil { + t.Errorf("request body is not JSON: %v\n%s", err, raw) + } + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + io.WriteString(w, respBody) + })) + t.Cleanup(srv.Close) + return srv, rec +} + +func testModel(t *testing.T, srv *httptest.Server, popts []Option, mopts ...llm.ModelOption) llm.Model { + t.Helper() + opts := append([]Option{WithAPIKey("test-key"), WithBaseURL(srv.URL)}, popts...) + m, err := New(opts...).Model("gpt-test", mopts...) + if err != nil { + t.Fatalf("Model: %v", err) + } + return m +} + +func fptr(f float64) *float64 { return &f } + +func TestGenerateRequestShape(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + m := testModel(t, srv, nil) + + req := llm.Request{ + System: "base system", + Messages: []llm.Message{ + llm.SystemText("folded system"), + llm.UserParts(llm.Text("look:"), llm.Image("image/png", []byte{1, 2, 3})), + { + Role: llm.RoleAssistant, + Parts: []llm.Part{llm.Text("checking")}, + ToolCalls: []llm.ToolCall{ + {ID: "call_1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Boston"}`)}, + }, + }, + llm.ToolResultsMessage( + llm.ToolResult{ID: "call_1", Name: "get_weather", Content: "72F"}, + llm.ToolResult{ID: "call_2", Name: "get_weather", Content: "boom", IsError: true}, + ), + llm.UserText("thanks"), + }, + Tools: []llm.Tool{{ + Name: "get_weather", + Description: "Get current weather", + Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`), + }}, + ToolChoice: "auto", + Temperature: fptr(0.5), + TopP: fptr(0.9), + MaxTokens: 256, + StopSequences: []string{"END"}, + ReasoningEffort: "high", + Schema: json.RawMessage(`{"type":"object","properties":{"ok":{"type":"boolean"}}}`), + SchemaName: "verdict", + } + if _, err := m.Generate(context.Background(), req); err != nil { + t.Fatalf("Generate: %v", err) + } + + want := map[string]any{ + "model": "gpt-test", + "messages": []any{ + map[string]any{"role": "system", "content": "base system\n\nfolded system"}, + map[string]any{"role": "user", "content": []any{ + map[string]any{"type": "text", "text": "look:"}, + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,AQID"}}, + }}, + map[string]any{"role": "assistant", "content": "checking", "tool_calls": []any{ + map[string]any{"id": "call_1", "type": "function", "function": map[string]any{ + "name": "get_weather", "arguments": `{"city":"Boston"}`, + }}, + }}, + map[string]any{"role": "tool", "content": "72F", "tool_call_id": "call_1"}, + map[string]any{"role": "tool", "content": "ERROR: boom", "tool_call_id": "call_2"}, + map[string]any{"role": "user", "content": "thanks"}, + }, + "tools": []any{ + map[string]any{"type": "function", "function": map[string]any{ + "name": "get_weather", + "description": "Get current weather", + "parameters": map[string]any{"type": "object", "properties": map[string]any{"city": map[string]any{"type": "string"}}}, + }}, + }, + "tool_choice": "auto", + "temperature": 0.5, + "top_p": 0.9, + "max_completion_tokens": float64(256), + "stop": []any{"END"}, + "reasoning_effort": "high", + "response_format": map[string]any{"type": "json_schema", "json_schema": map[string]any{ + "name": "verdict", + "schema": map[string]any{"type": "object", "properties": map[string]any{"ok": map[string]any{"type": "boolean"}}}, + }}, + } + if !reflect.DeepEqual(rec.body, want) { + got, _ := json.MarshalIndent(rec.body, "", " ") + exp, _ := json.MarshalIndent(want, "", " ") + t.Errorf("request body mismatch\ngot:\n%s\nwant:\n%s", got, exp) + } +} + +func TestToolChoiceForms(t *testing.T) { + tests := []struct { + choice string + want any // nil = key absent + }{ + {"", nil}, + {"auto", "auto"}, + {"none", "none"}, + {"required", "required"}, + {"get_weather", map[string]any{"type": "function", "function": map[string]any{"name": "get_weather"}}}, + } + for _, tt := range tests { + t.Run("choice="+tt.choice, func(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + m := testModel(t, srv, nil) + req := llm.Request{ + Messages: []llm.Message{llm.UserText("hi")}, + Tools: []llm.Tool{{Name: "get_weather"}}, + ToolChoice: tt.choice, + } + if _, err := m.Generate(context.Background(), req); err != nil { + t.Fatalf("Generate: %v", err) + } + got, present := rec.body["tool_choice"] + if tt.want == nil { + if present { + t.Errorf("tool_choice present, want omitted: %v", got) + } + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("tool_choice = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMaxTokensField(t *testing.T) { + t.Run("default uses max_completion_tokens", func(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + m := testModel(t, srv, nil) + req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64} + if _, err := m.Generate(context.Background(), req); err != nil { + t.Fatalf("Generate: %v", err) + } + if got := rec.body["max_completion_tokens"]; got != float64(64) { + t.Errorf("max_completion_tokens = %v, want 64", got) + } + if _, present := rec.body["max_tokens"]; present { + t.Error("max_tokens present, want omitted") + } + }) + t.Run("WithLegacyMaxTokens uses max_tokens", func(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + m := testModel(t, srv, []Option{WithLegacyMaxTokens()}) + req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64} + if _, err := m.Generate(context.Background(), req); err != nil { + t.Fatalf("Generate: %v", err) + } + if got := rec.body["max_tokens"]; got != float64(64) { + t.Errorf("max_tokens = %v, want 64", got) + } + if _, present := rec.body["max_completion_tokens"]; present { + t.Error("max_completion_tokens present, want omitted") + } + }) + t.Run("zero omits both", func(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + m := testModel(t, srv, nil) + req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}} + if _, err := m.Generate(context.Background(), req); err != nil { + t.Fatalf("Generate: %v", err) + } + if _, present := rec.body["max_tokens"]; present { + t.Error("max_tokens present, want omitted") + } + if _, present := rec.body["max_completion_tokens"]; present { + t.Error("max_completion_tokens present, want omitted") + } + }) +} + +func TestSchemaNameDefault(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + m := testModel(t, srv, nil) + req := llm.Request{ + Messages: []llm.Message{llm.UserText("hi")}, + Schema: json.RawMessage(`{"type":"object"}`), + } + if _, err := m.Generate(context.Background(), req); err != nil { + t.Fatalf("Generate: %v", err) + } + rf, ok := rec.body["response_format"].(map[string]any) + if !ok { + t.Fatalf("response_format missing: %v", rec.body) + } + js, ok := rf["json_schema"].(map[string]any) + if !ok { + t.Fatalf("json_schema missing: %v", rf) + } + if js["name"] != "response" { + t.Errorf("schema name = %v, want %q", js["name"], "response") + } +} + +func TestGenerateTextResponse(t *testing.T) { + srv, _ := newServer(t, http.StatusOK, textResponse) + m := testModel(t, srv, nil) + resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Generate: %v", err) + } + if got := resp.Text(); got != "hello" { + t.Errorf("Text = %q, want %q", got, "hello") + } + if resp.FinishReason != llm.FinishStop { + t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishStop) + } + if resp.Usage != (llm.Usage{InputTokens: 19, OutputTokens: 10}) { + t.Errorf("Usage = %+v, want {19 10}", resp.Usage) + } + if resp.Model != "openai/gpt-test" { + t.Errorf("Model = %q, want %q", resp.Model, "openai/gpt-test") + } + if len(resp.ToolCalls) != 0 { + t.Errorf("ToolCalls = %v, want none", resp.ToolCalls) + } + if resp.Raw == nil { + t.Error("Raw is nil, want wire response") + } +} + +func TestGenerateToolCallResponse(t *testing.T) { + const body = `{ + "id": "chatcmpl-2", "object": "chat.completion", "created": 1, "model": "gpt-test", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": null, "tool_calls": [ + {"id": "call_9", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Boston\"}"}}, + {"id": "", "type": "function", "function": {"name": "get_time", "arguments": "{}"}} + ]}, + "finish_reason": "stop" + }], + "usage": {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7} + }` + srv, _ := newServer(t, http.StatusOK, body) + m := testModel(t, srv, nil) + resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Generate: %v", err) + } + if len(resp.ToolCalls) != 2 { + t.Fatalf("ToolCalls = %d, want 2", len(resp.ToolCalls)) + } + tc := resp.ToolCalls[0] + if tc.ID != "call_9" || tc.Name != "get_weather" || string(tc.Arguments) != `{"city":"Boston"}` { + t.Errorf("ToolCalls[0] = %+v", tc) + } + if resp.ToolCalls[1].ID != "call_1" { + t.Errorf("synthesized ID = %q, want %q", resp.ToolCalls[1].ID, "call_1") + } + // finish_reason "stop" with tool_calls present: presence wins. + if resp.FinishReason != llm.FinishToolCalls { + t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishToolCalls) + } + if len(resp.Parts) != 0 { + t.Errorf("Parts = %v, want none", resp.Parts) + } +} + +func TestFinishReasonMapping(t *testing.T) { + tests := []struct { + wire string + want llm.FinishReason + }{ + {"stop", llm.FinishStop}, + {"length", llm.FinishLength}, + {"tool_calls", llm.FinishToolCalls}, + {"content_filter", llm.FinishContentFilter}, + {"function_call", llm.FinishOther}, + {"weird_new_reason", llm.FinishOther}, + } + for _, tt := range tests { + t.Run(tt.wire, func(t *testing.T) { + body := `{"choices":[{"index":0,"message":{"role":"assistant","content":"x"},"finish_reason":"` + tt.wire + `"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` + srv, _ := newServer(t, http.StatusOK, body) + m := testModel(t, srv, nil) + resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Generate: %v", err) + } + if resp.FinishReason != tt.want { + t.Errorf("FinishReason = %v, want %v", resp.FinishReason, tt.want) + } + }) + } +} + +func TestAPIErrorMapping(t *testing.T) { + t.Run("429 rate limit is transient", func(t *testing.T) { + const body = `{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}` + srv, _ := newServer(t, http.StatusTooManyRequests, body) + m := testModel(t, srv, nil) + _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("err = %v (%T), want *llm.APIError", err, err) + } + if apiErr.Status != http.StatusTooManyRequests { + t.Errorf("Status = %d, want 429", apiErr.Status) + } + if apiErr.Code != "rate_limit_exceeded" { + t.Errorf("Code = %q, want %q", apiErr.Code, "rate_limit_exceeded") + } + if apiErr.Message != "Rate limit reached" { + t.Errorf("Message = %q", apiErr.Message) + } + if apiErr.Provider != "openai" || apiErr.Model != "gpt-test" { + t.Errorf("Provider/Model = %q/%q", apiErr.Provider, apiErr.Model) + } + if got := llm.Classify(err); got != llm.ClassTransient { + t.Errorf("Classify = %v, want transient", got) + } + }) + t.Run("401 code null falls back to type, permanent", func(t *testing.T) { + const body = `{"error":{"message":"Incorrect API key provided","type":"authentication_error","param":null,"code":null}}` + srv, _ := newServer(t, http.StatusUnauthorized, body) + m := testModel(t, srv, nil) + _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("err = %v (%T), want *llm.APIError", err, err) + } + if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" { + t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code) + } + if got := llm.Classify(err); got != llm.ClassPermanent { + t.Errorf("Classify = %v, want permanent", got) + } + }) + t.Run("non-JSON body becomes message", func(t *testing.T) { + srv, _ := newServer(t, http.StatusServiceUnavailable, "upstream exploded\n") + m := testModel(t, srv, nil) + _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("err = %v (%T), want *llm.APIError", err, err) + } + if apiErr.Status != http.StatusServiceUnavailable || apiErr.Message != "upstream exploded" { + t.Errorf("Status/Message = %d/%q", apiErr.Status, apiErr.Message) + } + }) +} + +func TestMissingAPIKey(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "") + srv, rec := newServer(t, http.StatusOK, textResponse) + m, err := New(WithBaseURL(srv.URL)).Model("gpt-test") + if err != nil { + t.Fatalf("Model: %v", err) + } + _, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("err = %v (%T), want *llm.APIError", err, err) + } + if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "missing_api_key" { + t.Errorf("Status/Code = %d/%q, want 401/missing_api_key", apiErr.Status, apiErr.Code) + } + if rec.hits != 0 { + t.Errorf("server hit %d times, want 0", rec.hits) + } +} + +func TestEnvAPIKeyReadAtConstruction(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + t.Setenv("OPENAI_API_KEY", "env-secret") + p := New(WithBaseURL(srv.URL)) + t.Setenv("OPENAI_API_KEY", "changed-later") // must not affect p + m, err := p.Model("gpt-test") + if err != nil { + t.Fatalf("Model: %v", err) + } + if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil { + t.Fatalf("Generate: %v", err) + } + if got := rec.header.Get("Authorization"); got != "Bearer env-secret" { + t.Errorf("Authorization = %q, want %q", got, "Bearer env-secret") + } +} + +func TestAuthAndContentTypeHeaders(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + m := testModel(t, srv, nil) + if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil { + t.Fatalf("Generate: %v", err) + } + if got := rec.header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization = %q, want %q", got, "Bearer test-key") + } + if got := rec.header.Get("Content-Type"); got != "application/json" { + t.Errorf("Content-Type = %q, want application/json", got) + } + if rec.path != "/chat/completions" { + t.Errorf("path = %q, want /chat/completions", rec.path) + } +} + +func TestCompatEndpointNameAndBaseURL(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + p := New(WithName("groq"), WithAPIKey("k"), WithBaseURL(srv.URL+"/openai/v1/")) + if p.Name() != "groq" { + t.Errorf("Name = %q, want groq", p.Name()) + } + m, err := p.Model("llama-3.3-70b") + if err != nil { + t.Fatalf("Model: %v", err) + } + resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Generate: %v", err) + } + if rec.path != "/openai/v1/chat/completions" { + t.Errorf("path = %q, want /openai/v1/chat/completions (trailing slash trimmed)", rec.path) + } + if resp.Model != "groq/llama-3.3-70b" { + t.Errorf("Model = %q, want groq/llama-3.3-70b", resp.Model) + } + if rec.body["model"] != "llama-3.3-70b" { + t.Errorf("wire model = %v, want llama-3.3-70b (verbatim)", rec.body["model"]) + } +} + +func TestCapabilityEnforcement(t *testing.T) { + img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) } + tests := []struct { + name string + caps *llm.Capabilities // nil = provider defaults + msg llm.Message + }{ + { + name: "images unsupported", + caps: &llm.Capabilities{SupportsTools: true, SupportsStreaming: true}, + msg: llm.UserParts(img("image/png", 4)), + }, + { + name: "too many images", + caps: &llm.Capabilities{MaxImagesPerReq: 1}, + msg: llm.UserParts(img("image/png", 4), img("image/png", 4)), + }, + { + name: "disallowed MIME under defaults", + msg: llm.UserParts(img("image/bmp", 4)), + }, + { + name: "image too large", + caps: &llm.Capabilities{MaxImagesPerReq: 4, MaxImageBytes: 2}, + msg: llm.UserParts(img("image/png", 3)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + var mopts []llm.ModelOption + if tt.caps != nil { + mopts = append(mopts, llm.WithCapabilities(*tt.caps)) + } + m := testModel(t, srv, nil, mopts...) + _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{tt.msg}}) + if !errors.Is(err, llm.ErrUnsupported) { + t.Fatalf("err = %v, want ErrUnsupported", err) + } + if got := llm.Classify(err); got != llm.ClassPermanent { + t.Errorf("Classify = %v, want permanent", got) + } + if rec.hits != 0 { + t.Errorf("server hit %d times, want 0 (must refuse before sending)", rec.hits) + } + }) + } + + t.Run("streaming unsupported", func(t *testing.T) { + srv, rec := newServer(t, http.StatusOK, textResponse) + m := testModel(t, srv, nil, llm.WithCapabilities(llm.Capabilities{SupportsTools: true})) + _, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if !errors.Is(err, llm.ErrUnsupported) { + t.Fatalf("err = %v, want ErrUnsupported", err) + } + if rec.hits != 0 { + t.Errorf("server hit %d times, want 0", rec.hits) + } + }) +} + +func TestModelCapabilitiesOverride(t *testing.T) { + p := New(WithAPIKey("k")) + def, err := p.Model("a") + if err != nil { + t.Fatalf("Model: %v", err) + } + if caps := def.Capabilities(); !caps.SupportsTools || caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 20<<20 { + t.Errorf("default caps = %+v", caps) + } + custom := llm.Capabilities{SupportsStreaming: true, ContextWindow: 8192} + ovr, err := p.Model("b", llm.WithCapabilities(custom)) + if err != nil { + t.Fatalf("Model: %v", err) + } + if got := ovr.Capabilities(); !reflect.DeepEqual(got, custom) { + t.Errorf("override caps = %+v, want %+v", got, custom) + } +} + +func TestTransportErrorIsNotAPIError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + url := srv.URL + srv.Close() // guarantee connection refused + p := New(WithAPIKey("k"), WithBaseURL(url)) + m, err := p.Model("gpt-test") + if err != nil { + t.Fatalf("Model: %v", err) + } + _, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err == nil { + t.Fatal("Generate succeeded against closed server") + } + if _, ok := errors.AsType[*llm.APIError](err); ok { + t.Errorf("transport error wrapped in APIError: %v", err) + } + if !strings.Contains(err.Error(), "openai: do request") { + t.Errorf("err = %v, want openai: do request context", err) + } + if got := llm.Classify(err); got != llm.ClassTransient { + t.Errorf("Classify = %v, want transient (net error must stay visible)", got) + } +} + +func TestDecodeErrorWrapped(t *testing.T) { + srv, _ := newServer(t, http.StatusOK, "{not json") + m := testModel(t, srv, nil) + _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err == nil || !strings.Contains(err.Error(), "openai: decode response") { + t.Errorf("err = %v, want decode response context", err) + } + if _, ok := errors.AsType[*llm.APIError](err); ok { + t.Errorf("decode error wrapped in APIError: %v", err) + } +} diff --git a/provider/openai/stream.go b/provider/openai/stream.go new file mode 100644 index 0000000..a6a7e76 --- /dev/null +++ b/provider/openai/stream.go @@ -0,0 +1,183 @@ +package openai + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "strings" + "sync" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// stream consumes the data-only SSE stream of chat.completion.chunk events. +// +// Delivery contract: TextDelta events as content fragments arrive; ToolCall +// events only once fully assembled (fragments are buffered internally and +// flushed at stream end — simplest correct handling of interleaved parallel +// calls); exactly one final Response event; then io.EOF. +type stream struct { + m *model + body io.ReadCloser + sc *bufio.Scanner + + closeOnce sync.Once + closeErr error + + queue []llm.StreamEvent + done bool // finalize ran; drain queue then io.EOF + + text strings.Builder + calls []*toolCallAcc // first-appearance order + byIndex map[int]*toolCallAcc + finish string + usage llm.Usage +} + +// toolCallAcc accumulates one tool call's fragments. The id and name arrive +// on the first fragment for an index; arguments arrive as string pieces to +// concatenate. +type toolCallAcc struct { + id string + name string + args strings.Builder +} + +// Next implements llm.Stream. +func (s *stream) Next() (llm.StreamEvent, error) { + for { + if len(s.queue) > 0 { + ev := s.queue[0] + s.queue = s.queue[1:] + return ev, nil + } + if s.done { + return llm.StreamEvent{}, io.EOF + } + if !s.sc.Scan() { + if err := s.sc.Err(); err != nil { + return llm.StreamEvent{}, fmt.Errorf("openai: read stream: %w", err) + } + // Why: some compat servers close the body without a [DONE] + // sentinel; a clean EOF still finalizes with what arrived. + s.finalize() + continue + } + line := strings.TrimSpace(s.sc.Text()) + if !strings.HasPrefix(line, "data:") { + continue // SSE comments, event:/id: fields, blank separators + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" { + continue + } + if payload == "[DONE]" { + s.finalize() + continue + } + if err := s.handleChunk([]byte(payload)); err != nil { + return llm.StreamEvent{}, err + } + } +} + +// handleChunk folds one chat.completion.chunk into the stream state, +// queueing any events it produces. +func (s *stream) handleChunk(data []byte) error { + var chunk streamChunk + if err := json.Unmarshal(data, &chunk); err != nil { + return fmt.Errorf("openai: decode stream chunk: %w", err) + } + if chunk.Error != nil { + // Mid-stream error event on an otherwise-200 stream. Status stays 0: + // there is no failing HTTP status to report. + apiErr := &llm.APIError{ + Provider: s.m.p.name, + Model: s.m.id, + Code: chunk.Error.Code, + Message: chunk.Error.Message, + } + if apiErr.Code == "" { + apiErr.Code = chunk.Error.Type + } + return apiErr + } + if chunk.Usage != nil { + s.usage = llm.Usage{ + InputTokens: chunk.Usage.PromptTokens, + OutputTokens: chunk.Usage.CompletionTokens, + } + } + // Why the guard: the include_usage chunk arrives with an EMPTY choices + // array; indexing choices[0] unconditionally would panic on it. + if len(chunk.Choices) == 0 { + return nil + } + choice := chunk.Choices[0] + if choice.FinishReason != "" { + s.finish = choice.FinishReason + } + if choice.Delta.Content != "" { + s.text.WriteString(choice.Delta.Content) + s.queue = append(s.queue, llm.StreamEvent{TextDelta: choice.Delta.Content}) + } + for _, tc := range choice.Delta.ToolCalls { + acc := s.byIndex[tc.Index] + if acc == nil { + if s.byIndex == nil { + s.byIndex = make(map[int]*toolCallAcc) + } + acc = &toolCallAcc{} + s.byIndex[tc.Index] = acc + s.calls = append(s.calls, acc) + } + if tc.ID != "" { + acc.id = tc.ID + } + if tc.Function.Name != "" { + acc.name = tc.Function.Name + } + acc.args.WriteString(tc.Function.Arguments) + } + return nil +} + +// finalize assembles the buffered tool calls and the final Response, queues +// them (ToolCall events first, Response last), and marks the stream done. +func (s *stream) finalize() { + if s.done { + return + } + s.done = true + resp := &llm.Response{Model: s.m.p.name + "/" + s.m.id, Usage: s.usage} + if s.text.Len() > 0 { + resp.Parts = []llm.Part{llm.TextPart{Text: s.text.String()}} + } + for i, acc := range s.calls { + id := acc.id + if id == "" { + // Why: ToolResult.ID must echo ToolCall.ID; synthesize for + // compat servers that stream calls without ids. + id = fmt.Sprintf("call_%d", i) + } + resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{ + ID: id, + Name: acc.name, + Arguments: json.RawMessage(acc.args.String()), + }) + } + resp.FinishReason = mapFinish(s.finish, len(resp.ToolCalls) > 0) + for i := range resp.ToolCalls { + tc := resp.ToolCalls[i] // copy so the event doesn't alias the slice + s.queue = append(s.queue, llm.StreamEvent{ToolCall: &tc}) + } + s.queue = append(s.queue, llm.StreamEvent{Response: resp}) +} + +// Close implements llm.Stream. Closing the body unblocks any in-flight read +// and aborts the HTTP stream; safe to call at any time, including twice. +func (s *stream) Close() error { + s.closeOnce.Do(func() { s.closeErr = s.body.Close() }) + return s.closeErr +} diff --git a/provider/openai/stream_test.go b/provider/openai/stream_test.go new file mode 100644 index 0000000..23944e2 --- /dev/null +++ b/provider/openai/stream_test.go @@ -0,0 +1,267 @@ +package openai + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// sseServer streams each payload as one "data: " SSE event and +// records the request like newServer. +func sseServer(t *testing.T, payloads ...string) (*httptest.Server, *recorded) { + t.Helper() + rec := &recorded{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rec.hits++ + rec.header = r.Header.Clone() + rec.path = r.URL.Path + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read request body: %v", err) + } + if len(raw) > 0 { + if err := json.Unmarshal(raw, &rec.body); err != nil { + t.Errorf("request body is not JSON: %v\n%s", err, raw) + } + } + w.Header().Set("Content-Type", "text/event-stream") + for _, p := range payloads { + io.WriteString(w, "data: "+p+"\n\n") + } + })) + t.Cleanup(srv.Close) + return srv, rec +} + +// collect drains a stream to io.EOF, failing the test on any other error. +func collect(t *testing.T, s llm.Stream) []llm.StreamEvent { + t.Helper() + var events []llm.StreamEvent + for { + ev, err := s.Next() + if err == io.EOF { + return events + } + if err != nil { + t.Fatalf("Next: %v", err) + } + events = append(events, ev) + } +} + +func TestStreamText(t *testing.T) { + srv, rec := sseServer(t, + `{"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-test","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}],"obfuscation":"xK9q"}`, + `{"choices":[{"index":0,"delta":{"content":"lo"},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + `{"choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, + `[DONE]`, + ) + m := testModel(t, srv, nil) + s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + defer s.Close() + events := collect(t, s) + + // Request shape: stream flag, usage opt-in, SSE accept header. + if rec.body["stream"] != true { + t.Errorf("stream = %v, want true", rec.body["stream"]) + } + so, _ := rec.body["stream_options"].(map[string]any) + if so == nil || so["include_usage"] != true { + t.Errorf("stream_options = %v, want include_usage true", rec.body["stream_options"]) + } + if got := rec.header.Get("Accept"); got != "text/event-stream" { + t.Errorf("Accept = %q, want text/event-stream", got) + } + + if len(events) != 3 { + t.Fatalf("got %d events, want 3: %+v", len(events), events) + } + if events[0].TextDelta != "Hel" || events[1].TextDelta != "lo" { + t.Errorf("deltas = %q, %q, want Hel, lo", events[0].TextDelta, events[1].TextDelta) + } + final := events[2].Response + if final == nil { + t.Fatal("last event has no Response") + } + if got := final.Text(); got != "Hello" { + t.Errorf("final text = %q, want Hello", got) + } + if final.FinishReason != llm.FinishStop { + t.Errorf("FinishReason = %v, want stop", final.FinishReason) + } + if final.Usage != (llm.Usage{InputTokens: 5, OutputTokens: 2}) { + t.Errorf("Usage = %+v, want {5 2}", final.Usage) + } + if final.Model != "openai/gpt-test" { + t.Errorf("Model = %q, want openai/gpt-test", final.Model) + } + + // Next after EOF keeps returning EOF; Close is idempotent. + if _, err := s.Next(); err != io.EOF { + t.Errorf("Next after EOF = %v, want io.EOF", err) + } + if err := s.Close(); err != nil { + t.Errorf("first Close: %v", err) + } + if err := s.Close(); err != nil { + t.Errorf("second Close: %v", err) + } +} + +func TestStreamParallelToolCalls(t *testing.T) { + // Two interleaved calls with distinct indexes; id/name only on the first + // fragment of each; arguments split across fragments. + srv, _ := sseServer(t, + `{"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_a","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"city\":"}}]},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_b","type":"function","function":{"name":"get_time","arguments":"{\"tz\":"}}]},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Boston\"}"}}]},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"\"EST\"}"}}]},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + `{"choices":[],"usage":{"prompt_tokens":11,"completion_tokens":9,"total_tokens":20}}`, + `[DONE]`, + ) + m := testModel(t, srv, nil) + s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + defer s.Close() + events := collect(t, s) + + if len(events) != 3 { + t.Fatalf("got %d events, want 3 (two tool calls + response): %+v", len(events), events) + } + a, b := events[0].ToolCall, events[1].ToolCall + if a == nil || b == nil { + t.Fatalf("events 0/1 are not tool calls: %+v", events) + } + if a.ID != "call_a" || a.Name != "get_weather" || string(a.Arguments) != `{"city":"Boston"}` { + t.Errorf("first call = %+v", a) + } + if b.ID != "call_b" || b.Name != "get_time" || string(b.Arguments) != `{"tz":"EST"}` { + t.Errorf("second call = %+v", b) + } + final := events[2].Response + if final == nil { + t.Fatal("last event has no Response") + } + if len(final.ToolCalls) != 2 { + t.Fatalf("final ToolCalls = %d, want 2", len(final.ToolCalls)) + } + if final.ToolCalls[0].ID != "call_a" || final.ToolCalls[1].ID != "call_b" { + t.Errorf("final ToolCalls order = %q, %q", final.ToolCalls[0].ID, final.ToolCalls[1].ID) + } + if final.FinishReason != llm.FinishToolCalls { + t.Errorf("FinishReason = %v, want tool_calls", final.FinishReason) + } + if final.Usage != (llm.Usage{InputTokens: 11, OutputTokens: 9}) { + t.Errorf("Usage = %+v, want {11 9}", final.Usage) + } +} + +func TestStreamMidStreamError(t *testing.T) { + srv, _ := sseServer(t, + `{"choices":[{"index":0,"delta":{"content":"par"},"finish_reason":null}]}`, + `{"error":{"message":"The server had an error while processing your request","type":"server_error","param":null,"code":null}}`, + ) + m := testModel(t, srv, nil) + s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + defer s.Close() + + ev, err := s.Next() + if err != nil || ev.TextDelta != "par" { + t.Fatalf("first event = %+v, %v; want TextDelta par", ev, err) + } + _, err = s.Next() + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("err = %v (%T), want *llm.APIError", err, err) + } + if apiErr.Code != "server_error" { + t.Errorf("Code = %q, want server_error", apiErr.Code) + } + if apiErr.Message != "The server had an error while processing your request" { + t.Errorf("Message = %q", apiErr.Message) + } + if apiErr.Status != 0 { + t.Errorf("Status = %d, want 0 (the HTTP stream was 200)", apiErr.Status) + } +} + +func TestStreamHTTPError(t *testing.T) { + srv, _ := newServer(t, http.StatusTooManyRequests, + `{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}`) + m := testModel(t, srv, nil) + _, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + apiErr, ok := errors.AsType[*llm.APIError](err) + if !ok { + t.Fatalf("err = %v (%T), want *llm.APIError from Stream itself", err, err) + } + if apiErr.Status != http.StatusTooManyRequests || apiErr.Code != "rate_limit_exceeded" { + t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code) + } + if got := llm.Classify(err); got != llm.ClassTransient { + t.Errorf("Classify = %v, want transient", got) + } +} + +func TestStreamWithoutDoneSentinel(t *testing.T) { + // Why: some compat servers close the connection without "data: [DONE]"; + // a clean EOF must still produce the final Response. + srv, _ := sseServer(t, + `{"choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + ) + m := testModel(t, srv, nil) + s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + defer s.Close() + events := collect(t, s) + if len(events) != 2 { + t.Fatalf("got %d events, want 2: %+v", len(events), events) + } + final := events[1].Response + if final == nil || final.Text() != "ok" || final.FinishReason != llm.FinishStop { + t.Errorf("final = %+v", final) + } +} + +func TestStreamCloseEarly(t *testing.T) { + srv, _ := sseServer(t, + `{"choices":[{"index":0,"delta":{"content":"a"},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{"content":"b"},"finish_reason":null}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + `[DONE]`, + ) + m := testModel(t, srv, nil) + s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + if _, err := s.Next(); err != nil { + t.Fatalf("Next: %v", err) + } + if err := s.Close(); err != nil { + t.Errorf("Close mid-stream: %v", err) + } + if err := s.Close(); err != nil { + t.Errorf("Close again: %v", err) + } +} diff --git a/provider/openai/wire.go b/provider/openai/wire.go new file mode 100644 index 0000000..5553ccc --- /dev/null +++ b/provider/openai/wire.go @@ -0,0 +1,321 @@ +package openai + +import ( + "encoding/base64" + "encoding/json" + "strings" + + "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// --- request wire shapes --- + +type chatRequest struct { + Model string `json:"model"` + Messages []wireMessage `json:"messages"` + Tools []wireTool `json:"tools,omitempty"` + // ToolChoice is "auto"/"none"/"required" (string) or a named-function + // object; any avoids two fields for one wire key. + ToolChoice any `json:"tool_choice,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stop []string `json:"stop,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + ResponseFormat *wireRespFormat `json:"response_format,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *wireStreamOptions `json:"stream_options,omitempty"` +} + +type wireMessage struct { + Role string `json:"role"` + // Content is a string for text-only turns, a part array for multimodal + // turns, or nil (wire null) for assistant turns that only call tools. + Content any `json:"content"` + ToolCalls []wireToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type wireTextPart struct { + Type string `json:"type"` + Text string `json:"text"` +} + +type wireImagePart struct { + Type string `json:"type"` + ImageURL wireImageURL `json:"image_url"` +} + +type wireImageURL struct { + URL string `json:"url"` +} + +type wireToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function wireFunctionCall `json:"function"` +} + +type wireFunctionCall struct { + Name string `json:"name"` + // Arguments is a JSON-encoded STRING per the wire format, not an object. + Arguments string `json:"arguments"` +} + +type wireTool struct { + Type string `json:"type"` + Function wireToolFunction `json:"function"` +} + +type wireToolFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` +} + +type wireNamedToolChoice struct { + Type string `json:"type"` + Function wireToolName `json:"function"` +} + +type wireToolName struct { + Name string `json:"name"` +} + +type wireRespFormat struct { + Type string `json:"type"` + JSONSchema *wireJSONSchema `json:"json_schema,omitempty"` +} + +// wireJSONSchema omits the strict flag on purpose: strict mode imposes +// schema rewrites (every property required, additionalProperties:false at +// every level) that belong to the caller, not the transport. +type wireJSONSchema struct { + Name string `json:"name"` + Schema json.RawMessage `json:"schema"` +} + +type wireStreamOptions struct { + IncludeUsage bool `json:"include_usage"` +} + +// --- response wire shapes (loose: unknown fields ignored) --- + +type chatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []chatChoice `json:"choices"` + Usage *wireUsage `json:"usage"` +} + +type chatChoice struct { + Index int `json:"index"` + Message wireRespMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type wireRespMessage struct { + Role string `json:"role"` + Content string `json:"content"` // null decodes to "" + Refusal string `json:"refusal"` // tolerated, unused + ToolCalls []wireToolCall `json:"tool_calls"` +} + +type wireUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type errorEnvelope struct { + Error wireError `json:"error"` +} + +type wireError struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` // null decodes to "" +} + +// --- streaming wire shapes --- + +type streamChunk struct { + Choices []streamChoice `json:"choices"` + Usage *wireUsage `json:"usage"` + Error *wireError `json:"error"` // mid-stream error event +} + +type streamChoice struct { + Index int `json:"index"` + Delta streamDelta `json:"delta"` + FinishReason string `json:"finish_reason"` // null decodes to "" +} + +type streamDelta struct { + Content string `json:"content"` // null decodes to "" + ToolCalls []streamToolCallDelta `json:"tool_calls"` +} + +// streamToolCallDelta is one tool-call fragment. The id and name appear only +// on a call's first fragment; later fragments carry just index + an +// arguments substring. Accumulation keys on Index, never ID. +type streamToolCallDelta struct { + Index int `json:"index"` + ID string `json:"id"` + Function wireFunctionCall `json:"function"` +} + +// --- mapping: llm.Request -> chatRequest --- + +// buildRequest translates the canonical request to the wire shape. The +// capability check has already passed by the time this runs. +func (m *model) buildRequest(req llm.Request, stream bool) *chatRequest { + out := &chatRequest{ + Model: m.id, + Temperature: req.Temperature, + TopP: req.TopP, + Stop: req.StopSequences, + ReasoningEffort: req.ReasoningEffort, + } + + // Fold Request.System and every RoleSystem message into one leading + // system message, System field first. Why: the canonical contract allows + // system content in both places; OpenAI wants one system mechanism. + var sys []string + if req.System != "" { + sys = append(sys, req.System) + } + for _, msg := range req.Messages { + if msg.Role == llm.RoleSystem { + if t := msg.Text(); t != "" { + sys = append(sys, t) + } + } + } + if joined := strings.Join(sys, "\n\n"); joined != "" { + out.Messages = append(out.Messages, wireMessage{Role: "system", Content: joined}) + } + + for _, msg := range req.Messages { + switch msg.Role { + case llm.RoleSystem: + // Folded above; excluded from the normal message list. + case llm.RoleUser: + out.Messages = append(out.Messages, wireMessage{Role: "user", Content: contentValue(msg.Parts)}) + case llm.RoleAssistant: + wm := wireMessage{Role: "assistant"} + if text := msg.Text(); text != "" { + wm.Content = text + } + for _, tc := range msg.ToolCalls { + args := string(tc.Arguments) + if args == "" { + // Why: arguments must be a valid JSON document string; + // an empty string is not one. + args = "{}" + } + wm.ToolCalls = append(wm.ToolCalls, wireToolCall{ + ID: tc.ID, + Type: "function", + Function: wireFunctionCall{Name: tc.Name, Arguments: args}, + }) + } + out.Messages = append(out.Messages, wm) + case llm.RoleTool: + // One wire message per result: the API pairs each tool output + // with its call via tool_call_id, one message each. + for _, tr := range msg.ToolResults { + content := tr.Content + if tr.IsError { + content = "ERROR: " + content + } + out.Messages = append(out.Messages, wireMessage{ + Role: "tool", + Content: content, + ToolCallID: tr.ID, + }) + } + } + } + + for _, t := range req.Tools { + out.Tools = append(out.Tools, wireTool{ + Type: "function", + Function: wireToolFunction{Name: t.Name, Description: t.Description, Parameters: t.Parameters}, + }) + } + + switch req.ToolChoice { + case "": + // Omit: provider default ("auto" when tools are present). + case "auto", "none", "required": + out.ToolChoice = req.ToolChoice + default: + // Any other value names the one tool the model must call. + out.ToolChoice = wireNamedToolChoice{Type: "function", Function: wireToolName{Name: req.ToolChoice}} + } + + if req.MaxTokens > 0 { + if m.p.legacyMaxTokens { + out.MaxTokens = req.MaxTokens + } else { + out.MaxCompletionTokens = req.MaxTokens + } + } + + if len(req.Schema) > 0 { + name := req.SchemaName + if name == "" { + name = "response" + } + out.ResponseFormat = &wireRespFormat{ + Type: "json_schema", + JSONSchema: &wireJSONSchema{Name: name, Schema: req.Schema}, + } + } + + if stream { + out.Stream = true + // Why: without include_usage the stream never reports token counts; + // the usage arrives in one extra chunk with an empty choices array. + out.StreamOptions = &wireStreamOptions{IncludeUsage: true} + } + + return out +} + +// contentValue renders message parts as the wire content value: a plain +// string when text-only (maximum compat), a part array when images are +// present. +func contentValue(parts []llm.Part) any { + multimodal := false + for _, p := range parts { + if _, ok := p.(llm.ImagePart); ok { + multimodal = true + break + } + } + if !multimodal { + var b strings.Builder + for _, p := range parts { + if t, ok := p.(llm.TextPart); ok { + b.WriteString(t.Text) + } + } + return b.String() + } + out := make([]any, 0, len(parts)) + for _, p := range parts { + switch v := p.(type) { + case llm.TextPart: + out = append(out, wireTextPart{Type: "text", Text: v.Text}) + case llm.ImagePart: + url := "data:" + v.MIME + ";base64," + base64.StdEncoding.EncodeToString(v.Data) + out = append(out, wireImagePart{Type: "image_url", ImageURL: wireImageURL{URL: url}}) + } + } + return out +} diff --git a/registry.go b/registry.go index 430982f..2992b01 100644 --- a/registry.go +++ b/registry.go @@ -2,6 +2,7 @@ package majordomo import ( "fmt" + "net/http" "os" "strings" "sync" @@ -73,11 +74,12 @@ func (c ChainConfig) classify(err error) llm.ErrorClass { } type registryConfig struct { - health health.Config - chain ChainConfig - envLookup func(string) string - environ func() []string - skipEnv bool + health health.Config + chain ChainConfig + envLookup func(string) string + environ func() []string + skipEnv bool + httpClient *http.Client } // RegistryOption configures New. @@ -115,6 +117,14 @@ func WithoutEnvProviders() RegistryOption { return func(rc *registryConfig) { rc.skipEnv = true } } +// WithHTTPClient sets the HTTP client used by built-in providers and +// env-DSN scheme factories created by this registry (proxies, custom TLS, +// test servers). Providers registered explicitly via RegisterProvider keep +// whatever client they were built with. +func WithHTTPClient(c *http.Client) RegistryOption { + return func(rc *registryConfig) { rc.httpClient = c } +} + // New creates a Registry with all built-in providers and scheme factories // registered, then loads LLM_* env-DSN providers from the process // environment (unless WithoutEnvProviders is given). Malformed LLM_* entries @@ -139,7 +149,7 @@ func New(opts ...RegistryOption) *Registry { envLookup: cfg.envLookup, } - registerBuiltins(r) + registerBuiltins(r, cfg.httpClient) if !cfg.skipEnv { env := make(map[string]string)