feat: OpenAI, Anthropic, and native-Ollama providers + media pipeline
Phase 3: - provider/openai: Chat Completions for OpenAI + compat endpoints (SSE streaming with by-index tool-call assembly, response_format json_schema, legacy max_tokens option, reasoning_effort) - provider/anthropic: Messages API (tool_use/tool_result, GA structured output via output_config.format, full SSE event parser, 529 transient) - provider/ollama: one native /api/chat client behind the ollama, ollama-cloud, and foreman built-ins (presets; NDJSON streaming tolerant of foreman's buffered single-object responses; object tool arguments; format-schema structured output; think mapping) - media/: capability normalization (sniff, downscale, transcode, byte ladder, ErrUnsupported), wired into the chain executor per target with penalty-free advance past incapable elements - registry: real provider + scheme wiring, WithHTTPClient option, required env-foreman TLS chat round-trip test - ADR-0009 multimodal strategy, ADR-0010 tools/structured mapping; README matrix + CLAUDE.md synced Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -27,9 +27,15 @@ majordomo Registry, Parse, env-DSN loading, chain executor, re-exports
|
||||
llm/ canonical contract: Message/Part/Request/Response/Option,
|
||||
Tool/Toolbox, Capabilities, Stream, Model, Provider, errors
|
||||
health/ clock-injected health tracker (bench/backoff)
|
||||
media/ image normalization to target capabilities (Phase 3)
|
||||
media/ image normalization to target capabilities (sniff real
|
||||
format, downscale, transcode, byte ladder; ErrUnsupported
|
||||
for what can't fit) — chains normalize PER TARGET
|
||||
provider/fake/ scriptable in-memory provider for hermetic tests
|
||||
provider/{openai,anthropic,ollama,google}/ (Phases 3-4)
|
||||
provider/openai/ Chat Completions client (+ all OpenAI-compat targets)
|
||||
provider/anthropic/ Messages API client (+ Anthropic-compat targets)
|
||||
provider/ollama/ one native /api/chat client serving the ollama,
|
||||
ollama-cloud, and foreman built-ins via presets
|
||||
provider/google/ Gemini on the official genai SDK (Phase 4)
|
||||
agent/ Agent run loop (Phase 5)
|
||||
skill/ Skill interface + composition (Phase 6)
|
||||
examples/ one runnable example per hard requirement (Phase 7-8)
|
||||
|
||||
@@ -100,12 +100,25 @@ Chains are health-tracked per target:
|
||||
|
||||
| Provider | Spec name | Key env var | Default endpoint |
|
||||
|----------|-----------|-------------|------------------|
|
||||
| OpenAI (+compatible) | `openai` | `OPENAI_API_KEY` | api.openai.com *(pending)* |
|
||||
| Anthropic (+compatible) | `anthropic` | `ANTHROPIC_API_KEY` | api.anthropic.com *(pending)* |
|
||||
| OpenAI (+compatible) | `openai` | `OPENAI_API_KEY` | https://api.openai.com/v1 |
|
||||
| Anthropic (+compatible) | `anthropic` | `ANTHROPIC_API_KEY` | https://api.anthropic.com |
|
||||
| Google (Gemini) | `google` | `GOOGLE_API_KEY` / `GEMINI_API_KEY` | Gen AI API *(pending)* |
|
||||
| Ollama Cloud | `ollama-cloud` | `OLLAMA_API_KEY` | https://ollama.com *(pending)* |
|
||||
| Ollama (local) | `ollama` | — | `OLLAMA_HOST` or http://localhost:11434 *(pending)* |
|
||||
| foreman | `foreman` | — (token via DSN) | requires DSN/base URL *(pending)* |
|
||||
| Ollama Cloud | `ollama-cloud` | `OLLAMA_API_KEY` | https://ollama.com |
|
||||
| Ollama (local) | `ollama` | — | `OLLAMA_HOST` or http://localhost:11434 |
|
||||
| foreman | `foreman` | — (token via DSN) | requires an LLM_* DSN or `ollama.Foreman(url, token)` |
|
||||
|
||||
OpenAI-compatible / Anthropic-compatible endpoints: construct the provider
|
||||
with a name and base URL and register it —
|
||||
|
||||
```go
|
||||
reg.RegisterProvider(openai.New(
|
||||
openai.WithName("groq"),
|
||||
openai.WithBaseURL("https://api.groq.com/openai/v1"),
|
||||
openai.WithAPIKey(key),
|
||||
// openai.WithLegacyMaxTokens(), // for servers that only honor max_tokens
|
||||
))
|
||||
// now "groq/llama-3.3-70b" works in Parse, chains, and aliases
|
||||
```
|
||||
|
||||
### `LLM_*` env-DSN provider definitions
|
||||
|
||||
@@ -139,11 +152,15 @@ Implement the two-method `Provider` interface and register it:
|
||||
reg.RegisterProvider(myProvider) // now "myprovider/model-x" parses, chains, aliases
|
||||
```
|
||||
|
||||
## Multimodality *(pending — Phase 3)*
|
||||
## Multimodality
|
||||
|
||||
Attach images without knowing the target's limits; majordomo normalizes
|
||||
(downscale, re-encode, count/size limits) against the resolved target's
|
||||
declared capabilities and rejects clearly what cannot fit.
|
||||
Attach images without knowing the target's limits. Before each attempt the
|
||||
request is normalized against the **actual serving target's** declared
|
||||
capabilities: the real format is sniffed from the bytes, oversize images
|
||||
are downscaled (aspect preserved), disallowed formats are re-encoded, and
|
||||
byte budgets are enforced by a quality ladder. What cannot be made to fit
|
||||
is rejected with a clear `ErrUnsupported` error — and in a chain, the
|
||||
request simply advances to the next (e.g. vision-capable) element.
|
||||
|
||||
```go
|
||||
resp, err := m.Generate(ctx, majordomo.Request{
|
||||
@@ -154,7 +171,7 @@ resp, err := m.Generate(ctx, majordomo.Request{
|
||||
})
|
||||
```
|
||||
|
||||
## Tool calls *(canonical API ready; provider wiring pending — Phase 3)*
|
||||
## Tool calls
|
||||
|
||||
```go
|
||||
weather := majordomo.Tool{
|
||||
@@ -171,14 +188,20 @@ resp, _ := m.Generate(ctx, req, majordomo.WithTools(weather))
|
||||
// resp.ToolCalls → execute → append ToolResultsMessage → continue
|
||||
```
|
||||
|
||||
## Structured output *(canonical API ready; provider wiring pending — Phase 3)*
|
||||
Each provider maps this one shape to its native function-calling format
|
||||
(OpenAI tools/tool_calls, Anthropic tool_use/tool_result, Ollama tools with
|
||||
object arguments). Tool-call ids are synthesized when a backend omits them;
|
||||
streaming buffers tool-call arguments until they parse.
|
||||
|
||||
## Structured output
|
||||
|
||||
```go
|
||||
resp, _ := m.Generate(ctx, req, majordomo.WithSchema(schemaJSON, "answer"))
|
||||
```
|
||||
|
||||
A generic `Generate[T]` helper (schema from your struct, unmarshal into it)
|
||||
lands with the agent phase.
|
||||
Maps to OpenAI `response_format: json_schema`, Anthropic
|
||||
`output_config.format`, and Ollama `format`. A generic `Generate[T]` helper
|
||||
(schema from your struct, unmarshal into it) lands with the agent phase.
|
||||
|
||||
## Agents & skills *(pending — Phases 5–6)*
|
||||
|
||||
@@ -189,17 +212,25 @@ skills = reusable instruction+tool bundles attachable to any agent.
|
||||
|
||||
| Provider | Resolve/Parse | Chat | Streaming | Tools | Structured | Images | Env DSN |
|
||||
|----------------------|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|
||||
| OpenAI (+compatible) | ✅ | pending | pending | pending | pending | pending | ✅ |
|
||||
| Anthropic (+compat) | ✅ | pending | pending | pending | pending | pending | ✅ |
|
||||
| OpenAI (+compatible) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Anthropic (+compat) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Google (Gemini) | ✅ | pending | pending | pending | pending | pending | ✅ |
|
||||
| Ollama Cloud | ✅ | pending | pending | pending | pending | pending | ✅ |
|
||||
| Ollama (local) | ✅ | pending | pending | pending | pending | pending | ✅ |
|
||||
| foreman | ✅ | pending | pending | pending | pending | pending | ✅ |
|
||||
| Ollama Cloud | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Ollama (local) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| foreman | ✅ | ✅ | ✅¹ | ✅ | ✅ | ✅ | ✅ |
|
||||
| fake (testing) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | — |
|
||||
|
||||
¹ foreman's daemon currently buffers sync chat responses (no token-by-token
|
||||
streaming); majordomo's stream API works against it and delivers the
|
||||
response as a single delta plus final event.
|
||||
|
||||
Notes: Ollama has no native tool_choice — `"none"` drops the tools;
|
||||
`"required"`/named choices are best-effort ignored there.
|
||||
|
||||
Cross-cutting: Parse grammar ✅ · aliases/tiers ✅ · failover chains ✅ ·
|
||||
health tracking/backoff ✅ · LLM_* env DSNs ✅ · media pipeline pending ·
|
||||
agent loop pending · skills pending · `Generate[T]` pending.
|
||||
health tracking/backoff ✅ · LLM_* env DSNs ✅ · media pipeline ✅
|
||||
(per-target normalization in chains) · agent loop pending · skills pending
|
||||
· `Generate[T]` pending.
|
||||
|
||||
## Development
|
||||
|
||||
|
||||
+73
-14
@@ -3,8 +3,12 @@ package majordomo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/anthropic"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/ollama"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/openai"
|
||||
)
|
||||
|
||||
// Built-in provider names. Real client implementations land per-phase
|
||||
@@ -20,25 +24,80 @@ const (
|
||||
)
|
||||
|
||||
// registerBuiltins installs the built-in providers and env-DSN scheme
|
||||
// factories into a fresh registry.
|
||||
func registerBuiltins(r *Registry) {
|
||||
stub := func(kind string) SchemeFactory {
|
||||
// factories into a fresh registry. httpClient, when non-nil, is used by
|
||||
// every provider and factory the registry itself constructs.
|
||||
func registerBuiltins(r *Registry, httpClient *http.Client) {
|
||||
ollamaOpts := func(extra ...ollama.Option) []ollama.Option {
|
||||
if httpClient != nil {
|
||||
extra = append(extra, ollama.WithHTTPClient(httpClient))
|
||||
}
|
||||
return extra
|
||||
}
|
||||
|
||||
// Native-Ollama family: three names over one client with presets.
|
||||
r.providers[ProviderOllama] = ollama.Local(ollamaOpts()...)
|
||||
r.providers[ProviderOllamaCloud] = ollama.Cloud(ollamaOpts()...)
|
||||
// foreman has no default URL; the no-DSN registration resolves but
|
||||
// errors on use with a clear message (use an LLM_* DSN or
|
||||
// ollama.Foreman(...) + RegisterProvider).
|
||||
r.providers[ProviderForeman] = ollama.New(ollamaOpts(ollama.WithName(ProviderForeman))...)
|
||||
|
||||
ollamaScheme := func(name string, dsn DSN) (llm.Provider, error) {
|
||||
return ollama.New(ollamaOpts(
|
||||
ollama.WithName(name),
|
||||
ollama.WithBaseURL(dsn.BaseURL()),
|
||||
ollama.WithToken(dsn.Token),
|
||||
)...), nil
|
||||
}
|
||||
r.schemes[ProviderOllama] = ollamaScheme
|
||||
r.schemes[ProviderOllamaCloud] = ollamaScheme
|
||||
r.schemes[ProviderForeman] = ollamaScheme
|
||||
|
||||
// OpenAI and OpenAI-compatible endpoints.
|
||||
openaiOpts := func(extra ...openai.Option) []openai.Option {
|
||||
if httpClient != nil {
|
||||
extra = append(extra, openai.WithHTTPClient(httpClient))
|
||||
}
|
||||
return extra
|
||||
}
|
||||
r.providers[ProviderOpenAI] = openai.New(openaiOpts()...)
|
||||
r.schemes[ProviderOpenAI] = func(name string, dsn DSN) (llm.Provider, error) {
|
||||
return openai.New(openaiOpts(
|
||||
openai.WithName(name),
|
||||
openai.WithBaseURL(dsn.BaseURL()),
|
||||
openai.WithAPIKey(dsn.Token),
|
||||
)...), nil
|
||||
}
|
||||
|
||||
// Anthropic and Anthropic-compatible endpoints.
|
||||
anthropicOpts := func(extra ...anthropic.Option) []anthropic.Option {
|
||||
if httpClient != nil {
|
||||
extra = append(extra, anthropic.WithHTTPClient(httpClient))
|
||||
}
|
||||
return extra
|
||||
}
|
||||
r.providers[ProviderAnthropic] = anthropic.New(anthropicOpts()...)
|
||||
r.schemes[ProviderAnthropic] = func(name string, dsn DSN) (llm.Provider, error) {
|
||||
return anthropic.New(anthropicOpts(
|
||||
anthropic.WithName(name),
|
||||
anthropic.WithBaseURL(dsn.BaseURL()),
|
||||
anthropic.WithAPIKey(dsn.Token),
|
||||
)...), nil
|
||||
}
|
||||
|
||||
// Google lands in its own phase; stub until then.
|
||||
r.providers[ProviderGoogle] = &stubProvider{name: ProviderGoogle, kind: ProviderGoogle}
|
||||
r.schemes[ProviderGoogle] = stubScheme(ProviderGoogle)
|
||||
// "gemini" is an alternate scheme for the Google provider.
|
||||
r.schemes["gemini"] = stubScheme(ProviderGoogle)
|
||||
}
|
||||
|
||||
func stubScheme(kind string) SchemeFactory {
|
||||
return func(name string, dsn DSN) (llm.Provider, error) {
|
||||
return &stubProvider{name: name, kind: kind, baseURL: dsn.BaseURL(), token: dsn.Token}, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, kind := range []string{
|
||||
ProviderOpenAI, ProviderAnthropic, ProviderGoogle,
|
||||
ProviderOllama, ProviderOllamaCloud, ProviderForeman,
|
||||
} {
|
||||
r.providers[kind] = &stubProvider{name: kind, kind: kind}
|
||||
r.schemes[kind] = stub(kind)
|
||||
}
|
||||
// "gemini" is an alternate scheme for the Google provider.
|
||||
r.schemes["gemini"] = stub(ProviderGoogle)
|
||||
}
|
||||
|
||||
// stubProvider stands in for a provider implementation that lands in a
|
||||
// later phase. It resolves and carries its connection details (so Parse,
|
||||
// chains, and env loading are fully functional) but errors on use.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/health"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/media"
|
||||
)
|
||||
|
||||
// ErrChainExhausted reports that every element of a failover chain failed
|
||||
@@ -65,8 +66,8 @@ func (c *chain) Capabilities() llm.Capabilities {
|
||||
// Generate tries each target per the chain semantics above.
|
||||
func (c *chain) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) {
|
||||
req = req.Apply(opts...)
|
||||
return chainDo(ctx, c, func(ctx context.Context, t chainTarget) (*llm.Response, error) {
|
||||
return t.model.Generate(ctx, req)
|
||||
return chainDo(ctx, c, req, func(ctx context.Context, t chainTarget, nreq llm.Request) (*llm.Response, error) {
|
||||
return t.model.Generate(ctx, nreq)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -76,14 +77,17 @@ func (c *chain) Generate(ctx context.Context, req llm.Request, opts ...llm.Optio
|
||||
// (replaying half-delivered output would duplicate content).
|
||||
func (c *chain) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) {
|
||||
req = req.Apply(opts...)
|
||||
return chainDo(ctx, c, func(ctx context.Context, t chainTarget) (llm.Stream, error) {
|
||||
return t.model.Stream(ctx, req)
|
||||
return chainDo(ctx, c, req, func(ctx context.Context, t chainTarget, nreq llm.Request) (llm.Stream, error) {
|
||||
return t.model.Stream(ctx, nreq)
|
||||
})
|
||||
}
|
||||
|
||||
// chainDo runs the head-to-tail failover algorithm around an attempt
|
||||
// function, generic over the result type (response vs stream).
|
||||
func chainDo[T any](ctx context.Context, c *chain, attempt func(context.Context, chainTarget) (T, error)) (T, error) {
|
||||
// function, generic over the result type (response vs stream). Before each
|
||||
// target is tried, the request's media is normalized against THAT target's
|
||||
// capabilities (ADR-0008/0009) — a request that cannot be made to fit one
|
||||
// target advances to the next without a health penalty.
|
||||
func chainDo[T any](ctx context.Context, c *chain, req llm.Request, attempt func(context.Context, chainTarget, llm.Request) (T, error)) (T, error) {
|
||||
var zero T
|
||||
var failures []error
|
||||
|
||||
@@ -94,12 +98,20 @@ func chainDo[T any](ctx context.Context, c *chain, attempt func(context.Context,
|
||||
continue
|
||||
}
|
||||
|
||||
nreq, err := media.Normalize(req, t.model.Capabilities())
|
||||
if err != nil {
|
||||
// Always ErrUnsupported-wrapped: this target cannot take the
|
||||
// request by declaration. Advance, no health penalty.
|
||||
failures = append(failures, fmt.Errorf("%s: %w", t.key, err))
|
||||
continue
|
||||
}
|
||||
|
||||
retries := c.cfg.retries()
|
||||
for attemptN := 0; ; attemptN++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return zero, err
|
||||
}
|
||||
result, err := attempt(ctx, t)
|
||||
result, err := attempt(ctx, t, nreq)
|
||||
if err == nil {
|
||||
c.tracker.ReportSuccess(t.key)
|
||||
return result, nil
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
# ADR-0009: Multimodal strategy — normalize per target, enforce at the provider
|
||||
|
||||
**Status:** Accepted — 2026-06-10
|
||||
|
||||
## Context
|
||||
|
||||
Every provider (and some models) imposes different image rules: max
|
||||
dimensions/bytes, allowed MIME types, max images per request. A caller must
|
||||
be able to attach an image without knowing the eventual target — especially
|
||||
with failover chains, where the serving target isn't known until runtime.
|
||||
|
||||
## Decision
|
||||
|
||||
Two cooperating layers:
|
||||
|
||||
1. **`media.Normalize(req, caps)`** — the transformation point. The chain
|
||||
executor calls it **per target, per attempt**, against the actual
|
||||
target's capabilities, before the provider sees the request:
|
||||
- The real format is **sniffed from magic bytes** and wins over the
|
||||
declared MIME (callers lie; jpeg/png/gif/webp recognized).
|
||||
- Already-fitting images pass through untouched (fast path: zero copies).
|
||||
- Oversize dimensions downscale (aspect-preserving) with a hand-rolled
|
||||
box-filter — stdlib has no scaler and `x/image` stays out per
|
||||
ADR-0007; box-average quality is ample for vision input.
|
||||
- Disallowed MIME re-encodes: original format if allowed, else JPEG
|
||||
(q85), else PNG, else the first allowed encodable type.
|
||||
- Byte budgets enforce via a quality ladder (jpeg 85→65→45→30) then
|
||||
dimension halving; ~6 attempts before giving up.
|
||||
- WebP cannot be decoded by stdlib: it passes through when it fits and
|
||||
is allowed; any needed transform is a clear error.
|
||||
- Everything that cannot be made to fit errors **wrapping
|
||||
`llm.ErrUnsupported`** — never silently dropped.
|
||||
2. **Provider backstop** — each provider cheaply enforces its effective
|
||||
capabilities at request time (image count/MIME/bytes, plus
|
||||
tools/structured/streaming support flags) and rejects with
|
||||
`ErrUnsupported`. This keeps providers honest for expert callers who
|
||||
build models directly without the registry.
|
||||
|
||||
Chain semantics: a normalization failure for one target **advances** to the
|
||||
next element with no health penalty (the target isn't sick, it's just
|
||||
incapable) — so `fp/text-only,fp/vision` serves an image request from the
|
||||
vision element automatically.
|
||||
|
||||
Canonical image content stays **bytes + MIME** (ADR-0002); no URL fetching.
|
||||
|
||||
## Consequences
|
||||
|
||||
- A 100×50 PNG sent at a 32px-cap target arrives as a 32×16 PNG; the same
|
||||
request served by an 8000px target arrives untouched.
|
||||
- Conditional provider rules (e.g. Anthropic's 2000px cap above 20 images)
|
||||
are approximated by the flat declared caps — conservative and simple.
|
||||
|
||||
## Alternatives considered
|
||||
|
||||
- Normalize once against chain-intersection caps: over-restricts every
|
||||
request for the sake of rarely-used fallbacks. Rejected (ADR-0008).
|
||||
- `x/image/draw` scalers: a dependency for one function. Rejected.
|
||||
@@ -0,0 +1,49 @@
|
||||
# ADR-0010: Tools and structured output — one canonical shape, native mappings
|
||||
|
||||
**Status:** Accepted — 2026-06-10
|
||||
|
||||
## Context
|
||||
|
||||
Tool calling and schema-constrained output exist on every target but with
|
||||
different wire shapes (verified against current docs, June 2026; shapes
|
||||
recorded in each provider's package doc). The canonical API must hide all
|
||||
of it.
|
||||
|
||||
## Decision
|
||||
|
||||
Canonical: `Tool{Name, Description, Parameters (JSON Schema), Handler}`;
|
||||
`Response.ToolCalls[]{ID, Name, Arguments json.RawMessage}`; results return
|
||||
as `ToolResultsMessage(ToolResult{ID, Name, Content, IsError})`. Structured
|
||||
output via `WithSchema(schema, name)`. Per-provider mapping:
|
||||
|
||||
| Concern | OpenAI(+compat) | Anthropic(+compat) | Ollama/foreman | Google (Phase 4) |
|
||||
|---|---|---|---|---|
|
||||
| Tool def | `tools[].function{name,description,parameters}` | `tools[]{name,description,input_schema}` | `tools[].function` | `FunctionDeclaration.ParametersJsonSchema` |
|
||||
| Call args | JSON **string** → RawMessage | `tool_use.input` object | `arguments` **object** | `FunctionCall.Args` map |
|
||||
| Results | one `role:tool` msg per result (`tool_call_id`) | one **user** msg, `tool_result` blocks (`is_error` native) | `role:tool` + `tool_name` | `FunctionResponse` parts |
|
||||
| IsError | `"ERROR: "` content prefix | `is_error: true` | `"ERROR: "` prefix | response payload field |
|
||||
| Forced choice | `tool_choice` string / named object | `{"type":"any"/"tool"/"none"}` | none → drop tools; others best-effort ignored | `FunctionCallingConfig` modes |
|
||||
| Structured | `response_format json_schema` (no strict flag) | `output_config.format json_schema` (GA mechanism) | `format: <schema>` | `ResponseJsonSchema` + JSON MIME |
|
||||
|
||||
Cross-cutting decisions:
|
||||
|
||||
- **Missing call ids are synthesized** (`call_<n>`) — Ollama and some
|
||||
compat servers omit them; the agent loop needs ids to match results.
|
||||
- **Streaming buffers tool-call arguments to completion** (ADR-0002):
|
||||
OpenAI fragments accumulate by index, Anthropic `input_json_delta`
|
||||
fragments accumulate per block; consumers only ever see parseable calls.
|
||||
- **No strict-mode flag is sent** to OpenAI: strict mode imposes schema
|
||||
constraints (every property required, additionalProperties:false) that
|
||||
caller-supplied schemas may not satisfy. The `Generate[T]` reflector
|
||||
(Phase 5) emits strict-compatible schemas anyway.
|
||||
- `SchemaName` feeds providers that need a name (OpenAI; default
|
||||
"response"); others ignore it.
|
||||
- Tool handlers never panic the loop: `Toolbox.Execute`/`ExecuteTool`
|
||||
recover panics and JSON-encode results (ADR to agent loop, Phase 5).
|
||||
|
||||
## Consequences
|
||||
|
||||
- One test matrix per provider asserts the exact wire JSON both directions;
|
||||
drift is caught by httptest fixtures, not in production.
|
||||
- Ollama's missing tool_choice means "required" cannot be enforced there —
|
||||
documented in the README matrix rather than emulated.
|
||||
@@ -12,3 +12,5 @@ One decision per file, append-only; supersede rather than rewrite.
|
||||
| [0006](0006-health-and-backoff.md) | Model health tracking and backoff | Accepted |
|
||||
| [0007](0007-dependency-policy.md) | Dependency policy — stdlib-first, hand-rolled REST clients | Accepted |
|
||||
| [0008](0008-chain-semantics.md) | Failover-chain execution semantics | Accepted |
|
||||
| [0009](0009-multimodal-strategy.md) | Multimodal strategy — normalize per target, enforce at provider | Accepted |
|
||||
| [0010](0010-tools-structured-output-mapping.md) | Tools and structured output — canonical shape, native mappings | Accepted |
|
||||
|
||||
+64
-9
@@ -1,10 +1,17 @@
|
||||
package majordomo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/provider/ollama"
|
||||
)
|
||||
|
||||
func TestParseDSN(t *testing.T) {
|
||||
@@ -71,19 +78,16 @@ func TestLoadEnvForeman(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("provider %q not registered", name)
|
||||
}
|
||||
sp, ok := p.(*stubProvider)
|
||||
op, ok := p.(*ollama.Provider)
|
||||
if !ok {
|
||||
t.Fatalf("provider %q is %T, want *stubProvider (phase 1)", name, p)
|
||||
t.Fatalf("provider %q is %T, want *ollama.Provider (foreman scheme)", name, p)
|
||||
}
|
||||
if sp.kind != ProviderForeman {
|
||||
t.Errorf("provider %q kind = %q, want foreman", name, sp.kind)
|
||||
if op.Name() != name {
|
||||
t.Errorf("provider name = %q, want %q", op.Name(), name)
|
||||
}
|
||||
wantURL := "https://foreman-" + name + ".orgrimmar.dudenhoeffer.casa"
|
||||
if sp.baseURL != wantURL {
|
||||
t.Errorf("provider %q baseURL = %q, want %q", name, sp.baseURL, wantURL)
|
||||
}
|
||||
if sp.token != "test-token-change-me" {
|
||||
t.Errorf("provider %q token = %q, want the DSN userinfo", name, sp.token)
|
||||
if op.BaseURL() != wantURL {
|
||||
t.Errorf("provider %q baseURL = %q, want %q", name, op.BaseURL(), wantURL)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,3 +197,54 @@ func TestNewLoadsProcessEnv(t *testing.T) {
|
||||
t.Error("New() should eagerly load LLM_ENVTEST from the process environment")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnvForemanChatRoundTrip is the required end-to-end case: an LLM_*
|
||||
// foreman DSN resolves through Parse and serves a real chat over the wire
|
||||
// (TLS test server, since env DSNs always dial https), with the DSN token
|
||||
// arriving as the bearer credential.
|
||||
func TestEnvForemanChatRoundTrip(t *testing.T) {
|
||||
var gotAuth, gotPath string
|
||||
var gotModel string
|
||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotPath = r.URL.Path
|
||||
var body struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&body)
|
||||
gotModel = body.Model
|
||||
_, _ = io.WriteString(w, `{"message":{"role":"assistant","content":"hi from foreman"},"done":true,"done_reason":"stop","prompt_eval_count":2,"eval_count":3}`)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
host := strings.TrimPrefix(ts.URL, "https://")
|
||||
r := New(
|
||||
WithoutEnvProviders(),
|
||||
WithEnvLookup(func(string) string { return "" }),
|
||||
WithHTTPClient(ts.Client()),
|
||||
)
|
||||
if err := r.LoadEnv(map[string]string{"LLM_FM": "foreman://round-trip-token@" + host}); err != nil {
|
||||
t.Fatalf("LoadEnv: %v", err)
|
||||
}
|
||||
|
||||
m, err := r.Parse("fm/qwen3:30b")
|
||||
if err != nil {
|
||||
t.Fatalf("Parse: %v", err)
|
||||
}
|
||||
resp, err := m.Generate(context.Background(), Request{Messages: []Message{UserText("hello")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if resp.Text() != "hi from foreman" {
|
||||
t.Errorf("text = %q", resp.Text())
|
||||
}
|
||||
if resp.Model != "fm/qwen3:30b" {
|
||||
t.Errorf("resp.Model = %q, want fm/qwen3:30b", resp.Model)
|
||||
}
|
||||
if gotAuth != "Bearer round-trip-token" {
|
||||
t.Errorf("auth = %q, want the DSN token as bearer", gotAuth)
|
||||
}
|
||||
if gotPath != "/api/chat" || gotModel != "qwen3:30b" {
|
||||
t.Errorf("path=%q model=%q", gotPath, gotModel)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package majordomo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -325,6 +329,95 @@ func TestSingleTargetGetsChainSemantics(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// pngImage encodes a width×height PNG for media tests.
|
||||
func pngImage(t *testing.T, width, height int) []byte {
|
||||
t.Helper()
|
||||
img := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
for y := range height {
|
||||
for x := range width {
|
||||
img.Set(x, y, color.RGBA{R: uint8(x), G: uint8(y), B: 128, A: 255})
|
||||
}
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
t.Fatalf("encode png: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// TestChainNormalizesMediaPerTarget: the request's image is downscaled to
|
||||
// the capabilities of the target that actually serves it.
|
||||
func TestChainNormalizesMediaPerTarget(t *testing.T) {
|
||||
r := newTestRegistry(t)
|
||||
fp := fake.New("fp",
|
||||
fake.WithModelCapabilities("small-vision", llm.Capabilities{
|
||||
MaxImagesPerReq: 2,
|
||||
MaxImageDimension: 32,
|
||||
AllowedImageMIME: []string{"image/png"},
|
||||
}),
|
||||
)
|
||||
r.RegisterProvider(fp)
|
||||
|
||||
m, _ := r.Parse("fp/small-vision")
|
||||
_, err := m.Generate(context.Background(), Request{Messages: []Message{
|
||||
UserParts(Text("describe"), Image("image/png", pngImage(t, 100, 50))),
|
||||
}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
|
||||
calls := fp.Calls()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls = %d", len(calls))
|
||||
}
|
||||
var img llm.ImagePart
|
||||
for _, part := range calls[0].Request.Messages[0].Parts {
|
||||
if ip, ok := part.(llm.ImagePart); ok {
|
||||
img = ip
|
||||
}
|
||||
}
|
||||
if img.Data == nil {
|
||||
t.Fatal("no image reached the provider")
|
||||
}
|
||||
cfg, err := png.DecodeConfig(bytes.NewReader(img.Data))
|
||||
if err != nil {
|
||||
t.Fatalf("decode delivered image: %v", err)
|
||||
}
|
||||
if cfg.Width != 32 || cfg.Height != 16 {
|
||||
t.Errorf("delivered image = %dx%d, want 32x16 (downscaled to target cap)", cfg.Width, cfg.Height)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChainAdvancesPastImagelessTarget: a text-only head can't take an
|
||||
// image request; the chain advances to a vision-capable element with no
|
||||
// health penalty.
|
||||
func TestChainAdvancesPastImagelessTarget(t *testing.T) {
|
||||
r := newTestRegistry(t)
|
||||
fp := fake.New("fp",
|
||||
fake.WithModelCapabilities("text-only", llm.Capabilities{SupportsTools: true}),
|
||||
fake.WithModelCapabilities("vision", llm.Capabilities{MaxImagesPerReq: 4}),
|
||||
)
|
||||
r.RegisterProvider(fp)
|
||||
fp.Enqueue("vision", fake.Reply("a tasteful png"))
|
||||
|
||||
m, _ := r.Parse("fp/text-only,fp/vision")
|
||||
resp, err := m.Generate(context.Background(), Request{Messages: []Message{
|
||||
UserParts(Text("what is this?"), Image("image/png", pngImage(t, 8, 8))),
|
||||
}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if resp.Text() != "a tasteful png" {
|
||||
t.Errorf("text = %q", resp.Text())
|
||||
}
|
||||
if got := fp.CallCount("text-only"); got != 0 {
|
||||
t.Errorf("text-only target saw %d calls, want 0 (normalization rejects pre-send)", got)
|
||||
}
|
||||
if !r.Health().Available("fp/text-only") {
|
||||
t.Error("media rejection must not penalize health")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHTTP529ClassifiedTransient: Anthropic's "overloaded" status fails
|
||||
// over like any other transient error.
|
||||
func TestHTTP529FailsOver(t *testing.T) {
|
||||
|
||||
+293
@@ -0,0 +1,293 @@
|
||||
// Package media fits request images to a target's declared capabilities.
|
||||
//
|
||||
// Normalize sniffs each image's real format from magic bytes (declared MIME
|
||||
// types lie), corrects the part's MIME, and passes through anything that
|
||||
// already satisfies the target's llm.Capabilities. Images that do not fit
|
||||
// are decoded, downscaled (never upscaled), and re-encoded into an allowed
|
||||
// format and byte budget. Anything that cannot honestly be made to fit —
|
||||
// undecodable formats, impossible byte budgets, too many images, images for
|
||||
// a text-only target — fails with an error wrapping llm.ErrUnsupported so a
|
||||
// failover chain can advance to a more capable target without a health
|
||||
// penalty.
|
||||
//
|
||||
// Why a separate package: every provider would otherwise duplicate the same
|
||||
// decode/scale/encode pipeline. Providers keep only a cheap capability
|
||||
// enforcement backstop; this package performs the actual transformation,
|
||||
// once, against whichever target a chain is currently attempting.
|
||||
package media
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/gif"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// Normalize returns a copy of req whose images fit caps, transforming
|
||||
// (downscale, re-encode) where needed. The input request is never mutated.
|
||||
//
|
||||
// Fast paths: a request with no image parts, or whose images already satisfy
|
||||
// caps, is returned unchanged with all underlying slices shared. When any
|
||||
// image transforms, the Messages slice and the Parts slices of affected
|
||||
// messages are copied (copy-on-write); untouched parts stay shared.
|
||||
//
|
||||
// Images that cannot be made to fit return an error wrapping
|
||||
// llm.ErrUnsupported.
|
||||
func Normalize(req llm.Request, caps llm.Capabilities) (llm.Request, error) {
|
||||
total := 0
|
||||
for i := range req.Messages {
|
||||
for _, p := range req.Messages[i].Parts {
|
||||
if _, ok := p.(llm.ImagePart); ok {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
if total == 0 {
|
||||
return req, nil
|
||||
}
|
||||
if !caps.SupportsImages() {
|
||||
return llm.Request{}, fmt.Errorf("media: %w: target does not accept image input (request carries %d image(s))", llm.ErrUnsupported, total)
|
||||
}
|
||||
// Why error instead of dropping the overflow: silently removing an image
|
||||
// changes the question the caller asked; the honest move is to refuse and
|
||||
// let a chain try a roomier target.
|
||||
if total > caps.MaxImagesPerReq {
|
||||
return llm.Request{}, fmt.Errorf("media: %w: request carries %d images, target allows at most %d per request", llm.ErrUnsupported, total, caps.MaxImagesPerReq)
|
||||
}
|
||||
|
||||
out := req
|
||||
copiedMessages := false
|
||||
for mi := range req.Messages {
|
||||
copiedParts := false
|
||||
for pi, part := range req.Messages[mi].Parts {
|
||||
ip, ok := part.(llm.ImagePart)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
norm, changed, err := normalizeImage(ip, caps)
|
||||
if err != nil {
|
||||
return llm.Request{}, fmt.Errorf("media: message %d, part %d: %w", mi, pi, err)
|
||||
}
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
if !copiedMessages {
|
||||
out.Messages = make([]llm.Message, len(req.Messages))
|
||||
copy(out.Messages, req.Messages)
|
||||
copiedMessages = true
|
||||
}
|
||||
if !copiedParts {
|
||||
parts := make([]llm.Part, len(req.Messages[mi].Parts))
|
||||
copy(parts, req.Messages[mi].Parts)
|
||||
out.Messages[mi].Parts = parts
|
||||
copiedParts = true
|
||||
}
|
||||
out.Messages[mi].Parts[pi] = norm
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Info reports an image part's sniffed format ("jpeg", "png", "gif", or
|
||||
// "webp") and pixel dimensions. It is a cheap metadata read — the pixels are
|
||||
// never decoded. webp is recognized by signature but not decodable with the
|
||||
// standard library, so it reports format "webp" with zero dimensions and a
|
||||
// nil error.
|
||||
func Info(p llm.ImagePart) (format string, width, height int, err error) {
|
||||
format = sniff(p.Data)
|
||||
switch format {
|
||||
case "":
|
||||
return "", 0, 0, fmt.Errorf("media: image bytes match no known format (jpeg, png, gif, webp)")
|
||||
case "webp":
|
||||
return "webp", 0, 0, nil
|
||||
}
|
||||
cfg, _, err := image.DecodeConfig(bytes.NewReader(p.Data))
|
||||
if err != nil {
|
||||
return format, 0, 0, fmt.Errorf("media: decode %s config: %w", format, err)
|
||||
}
|
||||
return format, cfg.Width, cfg.Height, nil
|
||||
}
|
||||
|
||||
// normalizeImage fits one image part to caps. It returns the (possibly
|
||||
// transformed) part and whether it differs from the input. A corrected MIME
|
||||
// with untouched bytes still counts as changed so Normalize copy-on-writes
|
||||
// the containing slices.
|
||||
func normalizeImage(p llm.ImagePart, caps llm.Capabilities) (llm.ImagePart, bool, error) {
|
||||
// Why sniff instead of trusting p.MIME: callers routinely mislabel image
|
||||
// bytes, and providers reject mismatches; the bytes are the truth.
|
||||
format := sniff(p.Data)
|
||||
if format == "" {
|
||||
return p, false, fmt.Errorf("%w: image bytes (declared %q) match no known format (jpeg, png, gif, webp)", llm.ErrUnsupported, p.MIME)
|
||||
}
|
||||
realMIME := "image/" + format
|
||||
changed := false
|
||||
if p.MIME != realMIME {
|
||||
p.MIME = realMIME
|
||||
changed = true
|
||||
}
|
||||
|
||||
mimeOK := caps.MIMEAllowed(realMIME)
|
||||
fitsBytes := caps.MaxImageBytes == 0 || len(p.Data) <= caps.MaxImageBytes
|
||||
fitsDims := true
|
||||
if caps.MaxImageDimension > 0 && format != "webp" {
|
||||
// Cheap header-only dimension read; a failure forces the transform
|
||||
// path, which surfaces the real decode error.
|
||||
cfg, _, err := image.DecodeConfig(bytes.NewReader(p.Data))
|
||||
if err != nil {
|
||||
fitsDims = false
|
||||
} else {
|
||||
fitsDims = cfg.Width <= caps.MaxImageDimension && cfg.Height <= caps.MaxImageDimension
|
||||
}
|
||||
}
|
||||
// Why webp skips the dimension check: the stdlib cannot read webp
|
||||
// headers, so dimensions are unverifiable; if MIME and bytes fit we pass
|
||||
// it through rather than reject a possibly-fine image.
|
||||
if mimeOK && fitsBytes && fitsDims {
|
||||
return p, changed, nil
|
||||
}
|
||||
|
||||
// Transformation required from here on, which needs a real decode.
|
||||
if format == "webp" {
|
||||
return p, false, fmt.Errorf("%w: image is webp (%d bytes), which the Go standard library cannot decode; provide jpeg, png, or gif instead", llm.ErrUnsupported, len(p.Data))
|
||||
}
|
||||
img, _, err := image.Decode(bytes.NewReader(p.Data))
|
||||
if err != nil {
|
||||
return p, false, fmt.Errorf("%w: cannot decode %s image for transformation: %v", llm.ErrUnsupported, format, err)
|
||||
}
|
||||
|
||||
if caps.MaxImageDimension > 0 {
|
||||
b := img.Bounds()
|
||||
if b.Dx() > caps.MaxImageDimension || b.Dy() > caps.MaxImageDimension {
|
||||
nw, nh := fitDims(b.Dx(), b.Dy(), caps.MaxImageDimension)
|
||||
img = downscale(img, nw, nh)
|
||||
}
|
||||
}
|
||||
|
||||
target, err := targetMIME(realMIME, caps)
|
||||
if err != nil {
|
||||
return p, false, err
|
||||
}
|
||||
data, err := encodeFit(img, target, caps.MaxImageBytes)
|
||||
if err != nil {
|
||||
return p, false, err
|
||||
}
|
||||
return llm.ImagePart{MIME: target, Data: data}, true, nil
|
||||
}
|
||||
|
||||
// sniff identifies an image format from its magic bytes, returning "jpeg",
|
||||
// "png", "gif", "webp", or "" when nothing matches.
|
||||
func sniff(data []byte) string {
|
||||
switch {
|
||||
case len(data) >= 3 && data[0] == 0xFF && data[1] == 0xD8 && data[2] == 0xFF:
|
||||
return "jpeg"
|
||||
case len(data) >= 4 && data[0] == 0x89 && data[1] == 'P' && data[2] == 'N' && data[3] == 'G':
|
||||
return "png"
|
||||
case len(data) >= 4 && string(data[:4]) == "GIF8":
|
||||
return "gif"
|
||||
case len(data) >= 12 && string(data[:4]) == "RIFF" && string(data[8:12]) == "WEBP":
|
||||
return "webp"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// encodableMIME reports whether the stdlib can encode the given image type.
|
||||
func encodableMIME(mime string) bool {
|
||||
switch mime {
|
||||
case "image/jpeg", "image/png", "image/gif":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// targetMIME picks the re-encode format: the original when allowed, else
|
||||
// jpeg, else png, else the first allowed encodable type (gif). When nothing
|
||||
// allowed is encodable (e.g. only webp), it errors with llm.ErrUnsupported.
|
||||
func targetMIME(original string, caps llm.Capabilities) (string, error) {
|
||||
if encodableMIME(original) && caps.MIMEAllowed(original) {
|
||||
return original, nil
|
||||
}
|
||||
// Why jpeg before png: vision inputs are photographs more often than
|
||||
// screenshots, and jpeg's quality knob is the only size lever we have
|
||||
// for the byte-budget loop.
|
||||
for _, m := range []string{"image/jpeg", "image/png"} {
|
||||
if caps.MIMEAllowed(m) {
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
// An empty allow-list permits everything and was caught above, so the
|
||||
// list is non-empty here: take its first encodable entry.
|
||||
for _, m := range caps.AllowedImageMIME {
|
||||
if encodableMIME(m) {
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("%w: none of the allowed image types %v can be encoded with the Go standard library", llm.ErrUnsupported, caps.AllowedImageMIME)
|
||||
}
|
||||
|
||||
// encodeFit encodes img as mime within maxBytes (0 = no limit), trading
|
||||
// jpeg quality first and then resolution for size. The ladder is fixed
|
||||
// (jpeg: q85, q65, q45, q30, then half and quarter dimensions at q65;
|
||||
// png/gif: full, half, quarter dimensions) — at most six attempts, since an
|
||||
// image that survives a 16x pixel reduction over budget will not be saved
|
||||
// by further fiddling.
|
||||
func encodeFit(img image.Image, mime string, maxBytes int) ([]byte, error) {
|
||||
type attempt struct {
|
||||
div int // divide both dimensions by this
|
||||
quality int // jpeg quality; ignored for png/gif
|
||||
}
|
||||
var ladder []attempt
|
||||
if mime == "image/jpeg" {
|
||||
ladder = []attempt{{1, 85}, {1, 65}, {1, 45}, {1, 30}, {2, 65}, {4, 65}}
|
||||
} else {
|
||||
ladder = []attempt{{1, 0}, {2, 0}, {4, 0}}
|
||||
}
|
||||
|
||||
scaled := map[int]image.Image{1: img}
|
||||
smallest := -1
|
||||
for _, a := range ladder {
|
||||
cur, ok := scaled[a.div]
|
||||
if !ok {
|
||||
b := img.Bounds()
|
||||
nw, nh := max(b.Dx()/a.div, 1), max(b.Dy()/a.div, 1)
|
||||
cur = downscale(img, nw, nh)
|
||||
scaled[a.div] = cur
|
||||
}
|
||||
data, err := encodeImage(cur, mime, a.quality)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode %s: %w", mime, err)
|
||||
}
|
||||
if maxBytes == 0 || len(data) <= maxBytes {
|
||||
return data, nil
|
||||
}
|
||||
if smallest == -1 || len(data) < smallest {
|
||||
smallest = len(data)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("%w: image cannot be reduced to the %d-byte limit; smallest achievable %s encoding is %d bytes", llm.ErrUnsupported, maxBytes, mime, smallest)
|
||||
}
|
||||
|
||||
// encodeImage encodes img into the given MIME type. quality applies to jpeg
|
||||
// only.
|
||||
func encodeImage(img image.Image, mime string, quality int) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
var err error
|
||||
switch mime {
|
||||
case "image/jpeg":
|
||||
err = jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality})
|
||||
case "image/png":
|
||||
err = png.Encode(&buf, img)
|
||||
case "image/gif":
|
||||
err = gif.Encode(&buf, img, nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("no stdlib encoder for %q", mime)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
@@ -0,0 +1,513 @@
|
||||
package media
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/gif"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"math/rand/v2"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// --- test image builders -------------------------------------------------
|
||||
|
||||
// gradient builds a smooth w x h RGBA image (compresses well).
|
||||
func gradient(w, h int) *image.RGBA {
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for y := 0; y < h; y++ {
|
||||
for x := 0; x < w; x++ {
|
||||
img.SetRGBA(x, y, color.RGBA{
|
||||
R: uint8(x * 255 / max(w-1, 1)),
|
||||
G: uint8(y * 255 / max(h-1, 1)),
|
||||
B: 128,
|
||||
A: 255,
|
||||
})
|
||||
}
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
// noisy builds a w x h image of deterministic random pixels (compresses
|
||||
// terribly — ideal for exercising the byte-budget ladder).
|
||||
func noisy(w, h int) *image.RGBA {
|
||||
rng := rand.New(rand.NewPCG(1, 2))
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for i := range img.Pix {
|
||||
img.Pix[i] = uint8(rng.UintN(256))
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
func encPNG(t *testing.T, img image.Image) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
t.Fatalf("png encode: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func encJPEG(t *testing.T, img image.Image) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 90}); err != nil {
|
||||
t.Fatalf("jpeg encode: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func encGIF(t *testing.T, img image.Image) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
if err := gif.Encode(&buf, img, nil); err != nil {
|
||||
t.Fatalf("gif encode: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// webpBlob is a minimal byte sequence carrying the RIFF/WEBP signature.
|
||||
// The stdlib cannot decode webp, so sniffing is all that ever reads it.
|
||||
func webpBlob() []byte {
|
||||
b := []byte("RIFF")
|
||||
b = append(b, 0x1a, 0x00, 0x00, 0x00)
|
||||
b = append(b, "WEBPVP8 "...)
|
||||
b = append(b, make([]byte, 18)...)
|
||||
return b
|
||||
}
|
||||
|
||||
func imgReq(parts ...llm.Part) llm.Request {
|
||||
return llm.Request{Messages: []llm.Message{llm.UserParts(parts...)}}
|
||||
}
|
||||
|
||||
// firstImage returns the first image part in the request.
|
||||
func firstImage(t *testing.T, req llm.Request) llm.ImagePart {
|
||||
t.Helper()
|
||||
for _, m := range req.Messages {
|
||||
for _, p := range m.Parts {
|
||||
if ip, ok := p.(llm.ImagePart); ok {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Fatal("no image part in request")
|
||||
return llm.ImagePart{}
|
||||
}
|
||||
|
||||
// --- fast paths -----------------------------------------------------------
|
||||
|
||||
func TestNormalizeFastPathNoImages(t *testing.T) {
|
||||
req := llm.Request{Messages: []llm.Message{llm.UserText("hello")}}
|
||||
got, err := Normalize(req, llm.Capabilities{}) // even a no-image target
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
if &got.Messages[0] != &req.Messages[0] {
|
||||
t.Error("messages slice was copied on the no-image fast path")
|
||||
}
|
||||
if &got.Messages[0].Parts[0] != &req.Messages[0].Parts[0] {
|
||||
t.Error("parts slice was copied on the no-image fast path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeFastPathFittingImages(t *testing.T) {
|
||||
data := encPNG(t, gradient(20, 10))
|
||||
req := imgReq(llm.Text("look:"), llm.Image("image/png", data))
|
||||
caps := llm.Capabilities{
|
||||
MaxImagesPerReq: 4,
|
||||
MaxImageBytes: len(data) + 100,
|
||||
MaxImageDimension: 64,
|
||||
AllowedImageMIME: []string{"image/png"},
|
||||
}
|
||||
got, err := Normalize(req, caps)
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
if &got.Messages[0] != &req.Messages[0] {
|
||||
t.Error("messages slice was copied although every image already fits")
|
||||
}
|
||||
if &got.Messages[0].Parts[1] != &req.Messages[0].Parts[1] {
|
||||
t.Error("parts slice was copied although every image already fits")
|
||||
}
|
||||
}
|
||||
|
||||
// --- rejection paths ------------------------------------------------------
|
||||
|
||||
func TestNormalizeImagesUnsupported(t *testing.T) {
|
||||
req := imgReq(llm.Image("image/png", encPNG(t, gradient(4, 4))))
|
||||
_, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 0})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "does not accept image input") {
|
||||
t.Errorf("err message %q lacks explanation", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeTooManyImages(t *testing.T) {
|
||||
img := llm.Image("image/png", encPNG(t, gradient(4, 4)))
|
||||
req := llm.Request{Messages: []llm.Message{
|
||||
llm.UserParts(img, img),
|
||||
llm.UserParts(img),
|
||||
}}
|
||||
_, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 2})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "3 images") || !strings.Contains(err.Error(), "at most 2") {
|
||||
t.Errorf("err message %q lacks the counts", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGarbageBytes(t *testing.T) {
|
||||
req := imgReq(llm.Image("image/png", []byte("certainly not an image")))
|
||||
_, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 1})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no known format") {
|
||||
t.Errorf("err message %q lacks a clear explanation", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- MIME sniffing & correction --------------------------------------------
|
||||
|
||||
func TestNormalizeMIMECorrection(t *testing.T) {
|
||||
data := encPNG(t, gradient(8, 8))
|
||||
req := imgReq(llm.Image("image/jpeg", data)) // caller lies: bytes are png
|
||||
got, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 1})
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
ip := firstImage(t, got)
|
||||
if ip.MIME != "image/png" {
|
||||
t.Errorf("MIME = %q, want sniff-corrected %q", ip.MIME, "image/png")
|
||||
}
|
||||
if !bytes.Equal(ip.Data, data) {
|
||||
t.Error("image bytes changed although only the MIME needed correcting")
|
||||
}
|
||||
if orig := firstImage(t, req); orig.MIME != "image/jpeg" {
|
||||
t.Errorf("input request mutated: MIME now %q", orig.MIME)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCopyOnWrite(t *testing.T) {
|
||||
data := encPNG(t, gradient(8, 8))
|
||||
req := llm.Request{Messages: []llm.Message{
|
||||
llm.UserText("untouched message"),
|
||||
llm.UserParts(llm.Text("untouched part"), llm.Image("image/jpeg", data)),
|
||||
}}
|
||||
got, err := Normalize(req, llm.Capabilities{MaxImagesPerReq: 1})
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
if &got.Messages[0] == &req.Messages[0] {
|
||||
t.Error("messages slice shared although a part changed (mutation hazard)")
|
||||
}
|
||||
if &got.Messages[0].Parts[0] != &req.Messages[0].Parts[0] {
|
||||
t.Error("parts slice of the untouched message was copied")
|
||||
}
|
||||
if &got.Messages[1].Parts[0] == &req.Messages[1].Parts[0] {
|
||||
t.Error("parts slice of the changed message is still shared (mutation hazard)")
|
||||
}
|
||||
}
|
||||
|
||||
// --- dimension capping ------------------------------------------------------
|
||||
|
||||
func TestNormalizeDownscale(t *testing.T) {
|
||||
req := imgReq(llm.Image("image/png", encPNG(t, gradient(200, 100))))
|
||||
caps := llm.Capabilities{MaxImagesPerReq: 1, MaxImageDimension: 50}
|
||||
got, err := Normalize(req, caps)
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
format, w, h, err := Info(firstImage(t, got))
|
||||
if err != nil {
|
||||
t.Fatalf("Info: %v", err)
|
||||
}
|
||||
if format != "png" {
|
||||
t.Errorf("format = %q, want original format %q preserved", format, "png")
|
||||
}
|
||||
if w != 50 || h != 25 {
|
||||
t.Errorf("dimensions = %dx%d, want 50x25 (aspect preserved)", w, h)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeDownscalePortrait(t *testing.T) {
|
||||
req := imgReq(llm.Image("image/png", encPNG(t, gradient(100, 200))))
|
||||
caps := llm.Capabilities{MaxImagesPerReq: 1, MaxImageDimension: 50}
|
||||
got, err := Normalize(req, caps)
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
_, w, h, err := Info(firstImage(t, got))
|
||||
if err != nil {
|
||||
t.Fatalf("Info: %v", err)
|
||||
}
|
||||
if w != 25 || h != 50 {
|
||||
t.Errorf("dimensions = %dx%d, want 25x50 (aspect preserved)", w, h)
|
||||
}
|
||||
}
|
||||
|
||||
// --- transcoding -------------------------------------------------------------
|
||||
|
||||
func TestNormalizeTranscode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
mime string
|
||||
allowed []string
|
||||
want string
|
||||
}{
|
||||
{"png to jpeg", encPNG(t, gradient(16, 16)), "image/png", []string{"image/jpeg"}, "jpeg"},
|
||||
{"jpeg to png", encJPEG(t, gradient(16, 16)), "image/jpeg", []string{"image/png"}, "png"},
|
||||
{"gif to png", encGIF(t, gradient(16, 16)), "image/gif", []string{"image/png"}, "png"},
|
||||
{"png to gif fallback", encPNG(t, gradient(16, 16)), "image/png", []string{"image/gif"}, "gif"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := imgReq(llm.Image(tt.mime, tt.data))
|
||||
caps := llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: tt.allowed}
|
||||
got, err := Normalize(req, caps)
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
ip := firstImage(t, got)
|
||||
if ip.MIME != "image/"+tt.want {
|
||||
t.Errorf("MIME = %q, want %q", ip.MIME, "image/"+tt.want)
|
||||
}
|
||||
format, w, h, err := Info(ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Info: %v", err)
|
||||
}
|
||||
if format != tt.want {
|
||||
t.Errorf("sniffed format = %q, want %q", format, tt.want)
|
||||
}
|
||||
if w != 16 || h != 16 {
|
||||
t.Errorf("dimensions = %dx%d, want 16x16 (no resize needed)", w, h)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeNoEncodableAllowedType(t *testing.T) {
|
||||
// png needs transcoding but the only allowed type is webp, which the
|
||||
// stdlib cannot encode.
|
||||
req := imgReq(llm.Image("image/png", encPNG(t, gradient(8, 8))))
|
||||
caps := llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/webp"}}
|
||||
_, err := Normalize(req, caps)
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "image/webp") {
|
||||
t.Errorf("err message %q does not name the unencodable allowed types", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- byte budget ---------------------------------------------------------------
|
||||
|
||||
func TestNormalizeByteBudgetFits(t *testing.T) {
|
||||
// Random noise defeats q85 jpeg at full size; the ladder must walk down
|
||||
// quality and then resolution until the encoding fits.
|
||||
req := imgReq(llm.Image("image/png", encPNG(t, noisy(256, 256))))
|
||||
caps := llm.Capabilities{
|
||||
MaxImagesPerReq: 1,
|
||||
AllowedImageMIME: []string{"image/jpeg"},
|
||||
MaxImageBytes: 8 * 1024,
|
||||
}
|
||||
got, err := Normalize(req, caps)
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
ip := firstImage(t, got)
|
||||
if len(ip.Data) > caps.MaxImageBytes {
|
||||
t.Errorf("len(Data) = %d, exceeds budget %d", len(ip.Data), caps.MaxImageBytes)
|
||||
}
|
||||
if format, _, _, err := Info(ip); err != nil || format != "jpeg" {
|
||||
t.Errorf("Info = %q, %v; want jpeg, nil", format, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeByteBudgetImpossible(t *testing.T) {
|
||||
req := imgReq(llm.Image("image/png", encPNG(t, noisy(256, 256))))
|
||||
caps := llm.Capabilities{
|
||||
MaxImagesPerReq: 1,
|
||||
AllowedImageMIME: []string{"image/jpeg"},
|
||||
MaxImageBytes: 10, // no image fits in 10 bytes
|
||||
}
|
||||
_, err := Normalize(req, caps)
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "10-byte limit") {
|
||||
t.Errorf("err message %q lacks the budget", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "smallest achievable") {
|
||||
t.Errorf("err message %q lacks the achieved size", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- webp ---------------------------------------------------------------------
|
||||
|
||||
func TestNormalizeWebPPassThrough(t *testing.T) {
|
||||
data := webpBlob()
|
||||
req := imgReq(llm.Image("image/webp", data))
|
||||
caps := llm.Capabilities{
|
||||
MaxImagesPerReq: 1,
|
||||
MaxImageBytes: 1024,
|
||||
MaxImageDimension: 50, // unverifiable for webp; must not force a transform
|
||||
AllowedImageMIME: []string{"image/webp"},
|
||||
}
|
||||
got, err := Normalize(req, caps)
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
if &got.Messages[0] != &req.Messages[0] {
|
||||
t.Error("request copied although the webp image passes through")
|
||||
}
|
||||
if ip := firstImage(t, got); !bytes.Equal(ip.Data, data) {
|
||||
t.Error("webp bytes changed on pass-through")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeWebPNeedsTransform(t *testing.T) {
|
||||
req := imgReq(llm.Image("image/webp", webpBlob()))
|
||||
caps := llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/jpeg"}}
|
||||
_, err := Normalize(req, caps)
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "webp") {
|
||||
t.Errorf("err message %q does not name the format", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "jpeg, png, or gif") {
|
||||
t.Errorf("err message %q does not say what to provide instead", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- input immutability ----------------------------------------------------------
|
||||
|
||||
func TestNormalizeInputNotMutated(t *testing.T) {
|
||||
data := encPNG(t, gradient(200, 100))
|
||||
snapshot := bytes.Clone(data)
|
||||
req := llm.Request{
|
||||
System: "sys",
|
||||
Messages: []llm.Message{
|
||||
llm.UserParts(llm.Text("scale me"), llm.Image("image/jpeg", data)),
|
||||
},
|
||||
}
|
||||
caps := llm.Capabilities{
|
||||
MaxImagesPerReq: 1,
|
||||
MaxImageDimension: 50,
|
||||
AllowedImageMIME: []string{"image/jpeg"},
|
||||
}
|
||||
got, err := Normalize(req, caps)
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
orig := firstImage(t, req)
|
||||
if orig.MIME != "image/jpeg" {
|
||||
t.Errorf("input MIME mutated to %q", orig.MIME)
|
||||
}
|
||||
if !bytes.Equal(orig.Data, snapshot) {
|
||||
t.Error("input image bytes mutated")
|
||||
}
|
||||
if txt := req.Messages[0].Parts[0].(llm.TextPart); txt.Text != "scale me" {
|
||||
t.Errorf("input text part mutated: %q", txt.Text)
|
||||
}
|
||||
if ip := firstImage(t, got); bytes.Equal(ip.Data, snapshot) {
|
||||
t.Error("output image was expected to transform but is byte-identical")
|
||||
}
|
||||
}
|
||||
|
||||
// --- alpha handling ----------------------------------------------------------------
|
||||
|
||||
func TestNormalizeAlphaPNGToJPEG(t *testing.T) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, 32, 32))
|
||||
for y := 0; y < 32; y++ {
|
||||
for x := 0; x < 32; x++ {
|
||||
img.SetRGBA(x, y, color.RGBA{R: 200, G: 60, B: 30, A: uint8(x * 8)})
|
||||
}
|
||||
}
|
||||
req := imgReq(llm.Image("image/png", encPNG(t, img)))
|
||||
caps := llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/jpeg"}}
|
||||
got, err := Normalize(req, caps)
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
ip := firstImage(t, got)
|
||||
decoded, err := jpeg.Decode(bytes.NewReader(ip.Data))
|
||||
if err != nil {
|
||||
t.Fatalf("decoding transcoded jpeg: %v", err)
|
||||
}
|
||||
if b := decoded.Bounds(); b.Dx() != 32 || b.Dy() != 32 {
|
||||
t.Errorf("decoded dimensions = %dx%d, want 32x32", b.Dx(), b.Dy())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Info ----------------------------------------------------------------------------
|
||||
|
||||
func TestInfo(t *testing.T) {
|
||||
pngData := encPNG(t, gradient(10, 7))
|
||||
jpegData := encJPEG(t, gradient(5, 9))
|
||||
gifData := encGIF(t, gradient(6, 4))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
part llm.ImagePart
|
||||
format string
|
||||
w, h int
|
||||
wantErr bool
|
||||
}{
|
||||
{"png", llm.ImagePart{MIME: "image/png", Data: pngData}, "png", 10, 7, false},
|
||||
{"jpeg", llm.ImagePart{MIME: "image/jpeg", Data: jpegData}, "jpeg", 5, 9, false},
|
||||
{"gif", llm.ImagePart{MIME: "image/gif", Data: gifData}, "gif", 6, 4, false},
|
||||
{"mislabeled png", llm.ImagePart{MIME: "image/jpeg", Data: pngData}, "png", 10, 7, false},
|
||||
{"webp", llm.ImagePart{MIME: "image/webp", Data: webpBlob()}, "webp", 0, 0, false},
|
||||
{"garbage", llm.ImagePart{MIME: "image/png", Data: []byte("nope")}, "", 0, 0, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
format, w, h, err := Info(tt.part)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("Info: expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Info: %v", err)
|
||||
}
|
||||
if format != tt.format || w != tt.w || h != tt.h {
|
||||
t.Errorf("Info = %q, %d, %d; want %q, %d, %d", format, w, h, tt.format, tt.w, tt.h)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- byte-cap pass-through interplay ----------------------------------------------------
|
||||
|
||||
func TestNormalizeOversizeBytesTriggersTransform(t *testing.T) {
|
||||
// A fitting MIME and dimension but an over-budget payload must re-encode,
|
||||
// not pass through.
|
||||
data := encPNG(t, noisy(64, 64))
|
||||
req := imgReq(llm.Image("image/png", data))
|
||||
caps := llm.Capabilities{
|
||||
MaxImagesPerReq: 1,
|
||||
MaxImageBytes: len(data) / 2,
|
||||
AllowedImageMIME: []string{"image/png", "image/jpeg"},
|
||||
}
|
||||
got, err := Normalize(req, caps)
|
||||
if err != nil {
|
||||
t.Fatalf("Normalize: %v", err)
|
||||
}
|
||||
ip := firstImage(t, got)
|
||||
if len(ip.Data) > caps.MaxImageBytes {
|
||||
t.Errorf("len(Data) = %d, exceeds budget %d", len(ip.Data), caps.MaxImageBytes)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package media
|
||||
|
||||
import "image"
|
||||
|
||||
// fitDims scales (w, h) so the longer side equals limit, preserving aspect
|
||||
// ratio with round-half-up on the shorter side, floored at 1 pixel.
|
||||
func fitDims(w, h, limit int) (int, int) {
|
||||
if w >= h {
|
||||
return limit, max((h*limit+w/2)/w, 1)
|
||||
}
|
||||
return max((w*limit+h/2)/h, 1), limit
|
||||
}
|
||||
|
||||
// downscale resizes src to dw x dh using area averaging (a box filter): each
|
||||
// destination pixel is the mean of its corresponding source region.
|
||||
//
|
||||
// Why hand-rolled: the stdlib has no scaler and ADR-0007 bars
|
||||
// golang.org/x/image without a new ADR. Area averaging is dependency-free,
|
||||
// alias-resistant when shrinking (every source pixel contributes exactly
|
||||
// once), and entirely adequate quality for vision-model input. It is only
|
||||
// ever called to shrink — Normalize never upscales.
|
||||
func downscale(src image.Image, dw, dh int) *image.RGBA {
|
||||
b := src.Bounds()
|
||||
sw, sh := b.Dx(), b.Dy()
|
||||
dst := image.NewRGBA(image.Rect(0, 0, dw, dh))
|
||||
for dy := 0; dy < dh; dy++ {
|
||||
// Integer box edges: destination pixel dy covers source rows
|
||||
// [dy*sh/dh, (dy+1)*sh/dh), widened to at least one row.
|
||||
sy0 := dy * sh / dh
|
||||
sy1 := max((dy+1)*sh/dh, sy0+1)
|
||||
for dx := 0; dx < dw; dx++ {
|
||||
sx0 := dx * sw / dw
|
||||
sx1 := max((dx+1)*sw/dw, sx0+1)
|
||||
var r, g, bl, a uint64
|
||||
for sy := sy0; sy < sy1; sy++ {
|
||||
for sx := sx0; sx < sx1; sx++ {
|
||||
pr, pg, pb, pa := src.At(b.Min.X+sx, b.Min.Y+sy).RGBA()
|
||||
r += uint64(pr)
|
||||
g += uint64(pg)
|
||||
bl += uint64(pb)
|
||||
a += uint64(pa)
|
||||
}
|
||||
}
|
||||
n := uint64((sy1 - sy0) * (sx1 - sx0))
|
||||
i := dst.PixOffset(dx, dy)
|
||||
// RGBA() returns 16-bit channels; average, then drop to 8 bits.
|
||||
dst.Pix[i+0] = uint8(r / n >> 8)
|
||||
dst.Pix[i+1] = uint8(g / n >> 8)
|
||||
dst.Pix[i+2] = uint8(bl / n >> 8)
|
||||
dst.Pix[i+3] = uint8(a / n >> 8)
|
||||
}
|
||||
}
|
||||
return dst
|
||||
}
|
||||
+40
@@ -1,5 +1,45 @@
|
||||
# progress
|
||||
|
||||
## 2026-06-10 — Phase 3: REST providers (OpenAI, Anthropic, Ollama×3) + media
|
||||
|
||||
**Landed:**
|
||||
- `provider/openai`: Chat Completions client for OpenAI and every
|
||||
OpenAI-compatible endpoint (tools with string-arguments mapping, strict
|
||||
SSE streaming incl. by-index tool-call assembly and the empty-choices
|
||||
usage chunk, response_format json_schema, max_completion_tokens with a
|
||||
WithLegacyMaxTokens compat option, reasoning_effort).
|
||||
- `provider/anthropic`: Messages API client (anthropic-version 2023-06-01,
|
||||
required-max_tokens defaulting, tool_use/tool_result blocks with native
|
||||
is_error, GA structured output via output_config.format, full SSE event
|
||||
parser with input_json_delta buffering, 529-overloaded classified
|
||||
transient, usage sums cache tokens).
|
||||
- `provider/ollama`: ONE native /api/chat client serving ollama (local,
|
||||
OLLAMA_HOST normalization), ollama-cloud (https://ollama.com + bearer
|
||||
OLLAMA_API_KEY), and foreman (base URL + bearer; tolerates its
|
||||
buffered-single-object "streaming"). Object tool arguments, tool_name
|
||||
results, format-schema structured output, think-level mapping, NDJSON
|
||||
streaming with 16MB lines.
|
||||
- `media/`: normalization pipeline per ADR-0009 (magic-byte sniffing,
|
||||
box-filter downscale, transcode preference ladder, byte-budget quality
|
||||
ladder, webp passthrough-or-reject, copy-on-write, everything-unfittable
|
||||
wraps ErrUnsupported).
|
||||
- Chain executor now normalizes media PER TARGET before each attempt and
|
||||
advances penalty-free past targets that can't take the request (proven:
|
||||
text-only head + vision fallback; per-target downscale assertions).
|
||||
- Registry: real providers + scheme factories wired for openai, anthropic,
|
||||
ollama, ollama-cloud, foreman (google still stubbed, Phase 4);
|
||||
WithHTTPClient registry option; required env-foreman TLS chat round-trip
|
||||
test (LLM_FM=foreman://token@host → Parse("fm/qwen3:30b") → bearer
|
||||
arrives, chat answers).
|
||||
- ADR-0009 (multimodal), ADR-0010 (tools/structured mapping); README
|
||||
matrix flipped to ✅ for the four landed provider families; ~70 new
|
||||
hermetic tests across the three provider packages + media.
|
||||
- Run note: openai/anthropic/media were built by three parallel
|
||||
subagents against the frozen llm contract; ollama/foreman, chain wiring,
|
||||
and registry integration done in the main line. All gates green.
|
||||
|
||||
**Next:** Phase 4 — Google provider on google.golang.org/genai.
|
||||
|
||||
## 2026-06-10 — Phase 2: health + failover chain, proven
|
||||
|
||||
**Landed:** the full deterministic failover test matrix over the fake
|
||||
|
||||
@@ -0,0 +1,319 @@
|
||||
// Package anthropic implements llm.Provider for the Anthropic Messages API
|
||||
// and Anthropic-compatible endpoints.
|
||||
//
|
||||
// API surface targeted: POST {base}/v1/messages with headers x-api-key,
|
||||
// anthropic-version: 2023-06-01, and content-type: application/json, per the
|
||||
// platform.claude.com Messages API reference as of June 2026. Streaming uses
|
||||
// the documented SSE event sequence (message_start, content_block_start,
|
||||
// content_block_delta, content_block_stop, message_delta, message_stop).
|
||||
// Structured output uses the GA output_config.format mechanism with
|
||||
// {"type":"json_schema"}; the result arrives as JSON text in the first text
|
||||
// content block.
|
||||
//
|
||||
// Why a hand-rolled client (no SDK): ADR-0007 — majordomo is stdlib-first,
|
||||
// and the canonical llm contract needs only a narrow slice of the API.
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultName = "anthropic"
|
||||
defaultBaseURL = "https://api.anthropic.com"
|
||||
|
||||
// apiVersion is the anthropic-version header value. 2023-06-01 remains
|
||||
// the current (and only) stable version string as of June 2026.
|
||||
apiVersion = "2023-06-01"
|
||||
|
||||
// defaultMaxTokens is used when Request.MaxTokens is 0, because the
|
||||
// Messages API requires max_tokens on every request.
|
||||
defaultMaxTokens = 4096
|
||||
)
|
||||
|
||||
// defaultCapabilities reflects the documented first-party API image limits:
|
||||
// 100 images per request (200K-context models), 10 MB per image, 8000 px per
|
||||
// side, and the four supported media types.
|
||||
func defaultCapabilities() llm.Capabilities {
|
||||
return llm.Capabilities{
|
||||
SupportsTools: true,
|
||||
SupportsStructured: true,
|
||||
SupportsStreaming: true,
|
||||
MaxImagesPerReq: 100,
|
||||
MaxImageBytes: 10 << 20,
|
||||
MaxImageDimension: 8000,
|
||||
AllowedImageMIME: []string{
|
||||
"image/jpeg", "image/png", "image/gif", "image/webp",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Provider is an llm.Provider backed by the Anthropic Messages API.
|
||||
type Provider struct {
|
||||
name string
|
||||
apiKey string
|
||||
baseURL string
|
||||
client *http.Client
|
||||
caps llm.Capabilities
|
||||
maxTokens int
|
||||
}
|
||||
|
||||
// Option configures the provider at construction.
|
||||
type Option func(*Provider)
|
||||
|
||||
// WithAPIKey sets the API key explicitly, bypassing the ANTHROPIC_API_KEY
|
||||
// environment default.
|
||||
func WithAPIKey(key string) Option {
|
||||
return func(p *Provider) { p.apiKey = key }
|
||||
}
|
||||
|
||||
// WithBaseURL points the provider at an Anthropic-compatible endpoint. A
|
||||
// trailing slash is trimmed; "/v1/messages" is appended per request.
|
||||
func WithBaseURL(u string) Option {
|
||||
return func(p *Provider) { p.baseURL = strings.TrimRight(u, "/") }
|
||||
}
|
||||
|
||||
// WithHTTPClient replaces the HTTP client (timeouts, proxies, test doubles).
|
||||
func WithHTTPClient(c *http.Client) Option {
|
||||
return func(p *Provider) { p.client = c }
|
||||
}
|
||||
|
||||
// WithName overrides the registry name. Why: an Anthropic-compatible
|
||||
// endpoint registered under its own name must surface that name in
|
||||
// Response.Model and errors, not "anthropic".
|
||||
func WithName(name string) Option {
|
||||
return func(p *Provider) { p.name = name }
|
||||
}
|
||||
|
||||
// WithDefaultCapabilities replaces the provider-default capabilities.
|
||||
func WithDefaultCapabilities(caps llm.Capabilities) Option {
|
||||
return func(p *Provider) { p.caps = caps }
|
||||
}
|
||||
|
||||
// WithDefaultMaxTokens overrides the max_tokens value used when
|
||||
// Request.MaxTokens is 0. Why: the Messages API rejects requests without
|
||||
// max_tokens, so the provider must always send something.
|
||||
func WithDefaultMaxTokens(n int) Option {
|
||||
return func(p *Provider) { p.maxTokens = n }
|
||||
}
|
||||
|
||||
// New creates an Anthropic provider. It never fails: a missing API key
|
||||
// (no WithAPIKey and no ANTHROPIC_API_KEY in the environment) surfaces as a
|
||||
// 401-style *llm.APIError at request time, not at construction.
|
||||
func New(opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
name: defaultName,
|
||||
baseURL: defaultBaseURL,
|
||||
client: http.DefaultClient,
|
||||
caps: defaultCapabilities(),
|
||||
maxTokens: defaultMaxTokens,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
if p.apiKey == "" {
|
||||
p.apiKey = os.Getenv("ANTHROPIC_API_KEY")
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Name implements llm.Provider.
|
||||
func (p *Provider) Name() string { return p.name }
|
||||
|
||||
// Model implements llm.Provider. The id is passed through verbatim — it is
|
||||
// never validated against a catalog.
|
||||
func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) {
|
||||
cfg := llm.ApplyModelOptions(opts)
|
||||
caps := p.caps
|
||||
if cfg.Capabilities != nil {
|
||||
caps = *cfg.Capabilities
|
||||
}
|
||||
return &model{provider: p, id: id, caps: caps}, nil
|
||||
}
|
||||
|
||||
type model struct {
|
||||
provider *Provider
|
||||
id string
|
||||
caps llm.Capabilities
|
||||
}
|
||||
|
||||
// Capabilities implements llm.Model.
|
||||
func (m *model) Capabilities() llm.Capabilities { return m.caps }
|
||||
|
||||
// fullName is the "provider/model" identifier used in Response.Model.
|
||||
func (m *model) fullName() string { return m.provider.name + "/" + m.id }
|
||||
|
||||
// Generate implements llm.Model.
|
||||
func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) {
|
||||
req = req.Apply(opts...)
|
||||
if err := m.enforceCapabilities(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpResp, err := m.do(ctx, req, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
if httpResp.StatusCode/100 != 2 {
|
||||
return nil, m.apiError(httpResp)
|
||||
}
|
||||
var wr wireResponse
|
||||
if err := json.NewDecoder(httpResp.Body).Decode(&wr); err != nil {
|
||||
return nil, fmt.Errorf("%s: decode response: %w", m.provider.name, err)
|
||||
}
|
||||
return m.toResponse(&wr), nil
|
||||
}
|
||||
|
||||
// Stream implements llm.Model. A non-2xx status is returned as an error from
|
||||
// Stream itself, before any events are delivered.
|
||||
func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) {
|
||||
req = req.Apply(opts...)
|
||||
if err := m.enforceCapabilities(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpResp, err := m.do(ctx, req, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if httpResp.StatusCode/100 != 2 {
|
||||
defer httpResp.Body.Close()
|
||||
return nil, m.apiError(httpResp)
|
||||
}
|
||||
return newStream(m, httpResp.Body), nil
|
||||
}
|
||||
|
||||
// enforceCapabilities is the honest backstop behind the media layer: it
|
||||
// rejects (rather than silently mutates) requests the target cannot serve.
|
||||
// Why: a separate media layer resizes/transcodes images BEFORE requests
|
||||
// reach the provider, so anything still out of bounds here is a real error.
|
||||
func (m *model) enforceCapabilities(req llm.Request) error {
|
||||
images := 0
|
||||
for _, msg := range req.Messages {
|
||||
for _, part := range msg.Parts {
|
||||
img, ok := part.(llm.ImagePart)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
images++
|
||||
if !m.caps.SupportsImages() {
|
||||
return fmt.Errorf("%w: %s does not accept image input", llm.ErrUnsupported, m.fullName())
|
||||
}
|
||||
if !m.caps.MIMEAllowed(img.MIME) {
|
||||
return fmt.Errorf("%w: %s does not accept image MIME %q", llm.ErrUnsupported, m.fullName(), img.MIME)
|
||||
}
|
||||
if m.caps.MaxImageBytes > 0 && len(img.Data) > m.caps.MaxImageBytes {
|
||||
return fmt.Errorf("%w: image of %d bytes exceeds %s limit of %d bytes",
|
||||
llm.ErrUnsupported, len(img.Data), m.fullName(), m.caps.MaxImageBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
if m.caps.MaxImagesPerReq > 0 && images > m.caps.MaxImagesPerReq {
|
||||
return fmt.Errorf("%w: request carries %d images, %s allows at most %d",
|
||||
llm.ErrUnsupported, images, m.fullName(), m.caps.MaxImagesPerReq)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// do builds and executes one Messages API call. Transport errors are wrapped
|
||||
// with context but NOT converted to *llm.APIError, so llm.Classify still
|
||||
// sees the underlying net.Error / syscall errno.
|
||||
func (m *model) do(ctx context.Context, req llm.Request, streaming bool) (*http.Response, error) {
|
||||
p := m.provider
|
||||
if p.apiKey == "" {
|
||||
// Why request-time, not construction-time: New never fails by
|
||||
// convention, and a 401-shaped APIError classifies permanent so
|
||||
// chains fail fast past a misconfigured target.
|
||||
return nil, &llm.APIError{
|
||||
Provider: p.name,
|
||||
Model: m.id,
|
||||
Status: http.StatusUnauthorized,
|
||||
Code: "authentication_error",
|
||||
Message: "no API key configured: set ANTHROPIC_API_KEY or use WithAPIKey",
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(buildWireRequest(m.id, req, p.maxTokens, streaming))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: encode request: %w", p.name, err)
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/messages", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: build request: %w", p.name, err)
|
||||
}
|
||||
httpReq.Header.Set("x-api-key", p.apiKey)
|
||||
httpReq.Header.Set("anthropic-version", apiVersion)
|
||||
httpReq.Header.Set("content-type", "application/json")
|
||||
if streaming {
|
||||
httpReq.Header.Set("accept", "text/event-stream")
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: do request: %w", p.name, err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// apiError converts a non-2xx response into *llm.APIError, filling Code and
|
||||
// Message from the documented {"type":"error","error":{...}} body when it
|
||||
// parses, and falling back to the raw body text when it does not.
|
||||
func (m *model) apiError(resp *http.Response) error {
|
||||
apiErr := &llm.APIError{
|
||||
Provider: m.provider.name,
|
||||
Model: m.id,
|
||||
Status: resp.StatusCode,
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return apiErr
|
||||
}
|
||||
var we wireErrorEnvelope
|
||||
if json.Unmarshal(body, &we) == nil && we.Error.Type != "" {
|
||||
apiErr.Code = we.Error.Type
|
||||
apiErr.Message = we.Error.Message
|
||||
} else {
|
||||
apiErr.Message = strings.TrimSpace(string(body))
|
||||
}
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// toResponse maps a wire response onto the canonical llm.Response. Thinking
|
||||
// and other unrecognized block types are tolerated and skipped — they are
|
||||
// not part of the canonical content vocabulary.
|
||||
func (m *model) toResponse(wr *wireResponse) *llm.Response {
|
||||
resp := &llm.Response{
|
||||
FinishReason: mapStopReason(wr.StopReason),
|
||||
Usage: wr.Usage.toUsage(),
|
||||
Model: m.fullName(),
|
||||
Raw: wr,
|
||||
}
|
||||
for _, block := range wr.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
resp.Parts = append(resp.Parts, llm.TextPart{Text: block.Text})
|
||||
case "tool_use":
|
||||
args := block.Input
|
||||
if len(args) == 0 {
|
||||
args = json.RawMessage("{}")
|
||||
}
|
||||
resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{
|
||||
ID: block.ID,
|
||||
Name: block.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
default:
|
||||
// thinking, redacted_thinking, server-tool blocks, and any
|
||||
// future types are skipped, not surfaced as parts.
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@@ -0,0 +1,774 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// okBody is a minimal successful Messages API response.
|
||||
const okBody = `{
|
||||
"id": "msg_01",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"content": [{"type": "text", "text": "ok"}],
|
||||
"stop_reason": "end_turn",
|
||||
"usage": {"input_tokens": 3, "output_tokens": 5}
|
||||
}`
|
||||
|
||||
// capture records the last request the test server received.
|
||||
type capture struct {
|
||||
mu sync.Mutex
|
||||
hits int
|
||||
method string
|
||||
path string
|
||||
header http.Header
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (c *capture) handler(status int, respBody string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
c.mu.Lock()
|
||||
c.hits++
|
||||
c.method = r.Method
|
||||
c.path = r.URL.Path
|
||||
c.header = r.Header.Clone()
|
||||
c.body = body
|
||||
c.mu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_, _ = w.Write([]byte(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
// bodyMap decodes the captured request body for key-presence assertions.
|
||||
func (c *capture) bodyMap(t *testing.T) map[string]any {
|
||||
t.Helper()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(c.body, &m); err != nil {
|
||||
t.Fatalf("decode captured body: %v\nbody: %s", err, c.body)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// newTestProvider spins up an httptest server and a provider pointed at it.
|
||||
func newTestProvider(t *testing.T, h http.Handler, opts ...Option) *Provider {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(h)
|
||||
t.Cleanup(srv.Close)
|
||||
return New(append([]Option{WithAPIKey("test-key"), WithBaseURL(srv.URL)}, opts...)...)
|
||||
}
|
||||
|
||||
func mustModel(t *testing.T, p *Provider, id string, opts ...llm.ModelOption) llm.Model {
|
||||
t.Helper()
|
||||
m, err := p.Model(id, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Model(%q): %v", id, err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func generate(t *testing.T, m llm.Model, req llm.Request, opts ...llm.Option) *llm.Response {
|
||||
t.Helper()
|
||||
resp, err := m.Generate(context.Background(), req, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func TestRequestHeadersAndPath(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
|
||||
if c.method != http.MethodPost {
|
||||
t.Errorf("method = %q, want POST", c.method)
|
||||
}
|
||||
if c.path != "/v1/messages" {
|
||||
t.Errorf("path = %q, want /v1/messages", c.path)
|
||||
}
|
||||
for header, want := range map[string]string{
|
||||
"x-api-key": "test-key",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json",
|
||||
} {
|
||||
if got := c.header.Get(header); got != want {
|
||||
t.Errorf("header %s = %q, want %q", header, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemFold(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
generate(t, m, llm.Request{
|
||||
System: "base prompt",
|
||||
Messages: []llm.Message{
|
||||
llm.SystemText("first extra"),
|
||||
llm.UserText("hi"),
|
||||
llm.SystemText("second extra"),
|
||||
},
|
||||
})
|
||||
|
||||
body := c.bodyMap(t)
|
||||
if got, want := body["system"], "base prompt\n\nfirst extra\n\nsecond extra"; got != want {
|
||||
t.Errorf("system = %q, want %q", got, want)
|
||||
}
|
||||
msgs := body["messages"].([]any)
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("messages length = %d, want 1 (system messages must be excluded)", len(msgs))
|
||||
}
|
||||
if role := msgs[0].(map[string]any)["role"]; role != "user" {
|
||||
t.Errorf("remaining message role = %q, want user", role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoSystemOmitsField(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
|
||||
if _, ok := c.bodyMap(t)["system"]; ok {
|
||||
t.Error("system key present, want omitted when empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxTokens(t *testing.T) {
|
||||
t.Run("default 4096", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if got := c.bodyMap(t)["max_tokens"].(float64); got != 4096 {
|
||||
t.Errorf("max_tokens = %v, want 4096", got)
|
||||
}
|
||||
})
|
||||
t.Run("explicit wins", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
MaxTokens: 123,
|
||||
})
|
||||
if got := c.bodyMap(t)["max_tokens"].(float64); got != 123 {
|
||||
t.Errorf("max_tokens = %v, want 123", got)
|
||||
}
|
||||
})
|
||||
t.Run("WithDefaultMaxTokens overrides default", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithDefaultMaxTokens(99))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if got := c.bodyMap(t)["max_tokens"].(float64); got != 99 {
|
||||
t.Errorf("max_tokens = %v, want 99", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestImageBlock(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
raw := []byte{0x01, 0x02, 0x03}
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{
|
||||
llm.UserParts(llm.Text("look at this"), llm.Image("image/png", raw)),
|
||||
}})
|
||||
|
||||
msgs := c.bodyMap(t)["messages"].([]any)
|
||||
content := msgs[0].(map[string]any)["content"].([]any)
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("content blocks = %d, want 2", len(content))
|
||||
}
|
||||
img := content[1].(map[string]any)
|
||||
if img["type"] != "image" {
|
||||
t.Fatalf("block type = %v, want image", img["type"])
|
||||
}
|
||||
src := img["source"].(map[string]any)
|
||||
if src["type"] != "base64" {
|
||||
t.Errorf("source type = %v, want base64", src["type"])
|
||||
}
|
||||
if src["media_type"] != "image/png" {
|
||||
t.Errorf("media_type = %v, want image/png", src["media_type"])
|
||||
}
|
||||
if want := base64.StdEncoding.EncodeToString(raw); src["data"] != want {
|
||||
t.Errorf("data = %v, want %q", src["data"], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolUseToolResultRoundTrip(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{
|
||||
llm.UserText("weather?"),
|
||||
{
|
||||
Role: llm.RoleAssistant,
|
||||
Parts: []llm.Part{llm.Text("checking")},
|
||||
ToolCalls: []llm.ToolCall{
|
||||
{ID: "toolu_1", Name: "get_weather", Arguments: json.RawMessage(`{"location":"Paris"}`)},
|
||||
{ID: "toolu_2", Name: "noop"}, // empty args must become {}
|
||||
},
|
||||
},
|
||||
llm.ToolResultsMessage(
|
||||
llm.ToolResult{ID: "toolu_1", Name: "get_weather", Content: "72F and sunny"},
|
||||
llm.ToolResult{ID: "toolu_2", Name: "noop", Content: "boom", IsError: true},
|
||||
),
|
||||
}})
|
||||
|
||||
msgs := c.bodyMap(t)["messages"].([]any)
|
||||
if len(msgs) != 3 {
|
||||
t.Fatalf("messages = %d, want 3", len(msgs))
|
||||
}
|
||||
|
||||
asst := msgs[1].(map[string]any)
|
||||
if asst["role"] != "assistant" {
|
||||
t.Errorf("messages[1].role = %v, want assistant", asst["role"])
|
||||
}
|
||||
asstContent := asst["content"].([]any)
|
||||
if len(asstContent) != 3 {
|
||||
t.Fatalf("assistant blocks = %d, want 3 (text + 2 tool_use)", len(asstContent))
|
||||
}
|
||||
tu := asstContent[1].(map[string]any)
|
||||
if tu["type"] != "tool_use" || tu["id"] != "toolu_1" || tu["name"] != "get_weather" {
|
||||
t.Errorf("tool_use block = %v", tu)
|
||||
}
|
||||
if loc := tu["input"].(map[string]any)["location"]; loc != "Paris" {
|
||||
t.Errorf("tool_use input.location = %v, want Paris", loc)
|
||||
}
|
||||
if input := asstContent[2].(map[string]any)["input"].(map[string]any); len(input) != 0 {
|
||||
t.Errorf("empty-args tool_use input = %v, want {}", input)
|
||||
}
|
||||
|
||||
// RoleTool → ONE user message with one tool_result block per result.
|
||||
toolMsg := msgs[2].(map[string]any)
|
||||
if toolMsg["role"] != "user" {
|
||||
t.Errorf("messages[2].role = %v, want user", toolMsg["role"])
|
||||
}
|
||||
results := toolMsg["content"].([]any)
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("tool_result blocks = %d, want 2", len(results))
|
||||
}
|
||||
first := results[0].(map[string]any)
|
||||
if first["type"] != "tool_result" || first["tool_use_id"] != "toolu_1" || first["content"] != "72F and sunny" {
|
||||
t.Errorf("first tool_result = %v", first)
|
||||
}
|
||||
if _, ok := first["is_error"]; ok {
|
||||
t.Error("first tool_result has is_error, want omitted when false")
|
||||
}
|
||||
second := results[1].(map[string]any)
|
||||
if second["tool_use_id"] != "toolu_2" || second["is_error"] != true {
|
||||
t.Errorf("second tool_result = %v, want is_error true", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolDefinitions(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"q":{"type":"string"}},"required":["q"]}`)
|
||||
generate(t, m, llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
Tools: []llm.Tool{
|
||||
{Name: "search", Description: "Search the web.", Parameters: schema},
|
||||
{Name: "ping"}, // nil Parameters → default empty object schema
|
||||
},
|
||||
})
|
||||
|
||||
tools := c.bodyMap(t)["tools"].([]any)
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("tools = %d, want 2", len(tools))
|
||||
}
|
||||
search := tools[0].(map[string]any)
|
||||
if search["name"] != "search" || search["description"] != "Search the web." {
|
||||
t.Errorf("tool[0] = %v", search)
|
||||
}
|
||||
if typ := search["input_schema"].(map[string]any)["type"]; typ != "object" {
|
||||
t.Errorf("input_schema.type = %v, want object", typ)
|
||||
}
|
||||
ping := tools[1].(map[string]any)
|
||||
if typ := ping["input_schema"].(map[string]any)["type"]; typ != "object" {
|
||||
t.Errorf("nil-Parameters input_schema.type = %v, want object", typ)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolChoiceForms(t *testing.T) {
|
||||
cases := []struct {
|
||||
choice string
|
||||
wantType string // "" means the field must be absent
|
||||
wantName string
|
||||
}{
|
||||
{choice: "", wantType: ""},
|
||||
{choice: "auto", wantType: "auto"},
|
||||
{choice: "required", wantType: "any"},
|
||||
{choice: "none", wantType: "none"},
|
||||
{choice: "get_weather", wantType: "tool", wantName: "get_weather"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run("choice="+tc.choice, func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
ToolChoice: tc.choice,
|
||||
})
|
||||
body := c.bodyMap(t)
|
||||
raw, present := body["tool_choice"]
|
||||
if tc.wantType == "" {
|
||||
if present {
|
||||
t.Fatalf("tool_choice present (%v), want omitted", raw)
|
||||
}
|
||||
return
|
||||
}
|
||||
choice := raw.(map[string]any)
|
||||
if choice["type"] != tc.wantType {
|
||||
t.Errorf("tool_choice.type = %v, want %q", choice["type"], tc.wantType)
|
||||
}
|
||||
if tc.wantName != "" && choice["name"] != tc.wantName {
|
||||
t.Errorf("tool_choice.name = %v, want %q", choice["name"], tc.wantName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputConfigFormat(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"],"additionalProperties":false}`)
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}},
|
||||
llm.WithSchema(schema, "person"))
|
||||
|
||||
body := c.bodyMap(t)
|
||||
format := body["output_config"].(map[string]any)["format"].(map[string]any)
|
||||
if format["type"] != "json_schema" {
|
||||
t.Errorf("output_config.format.type = %v, want json_schema", format["type"])
|
||||
}
|
||||
// Normalize both sides through any → Marshal (sorted keys) to compare.
|
||||
got, _ := json.Marshal(format["schema"])
|
||||
var want any
|
||||
_ = json.Unmarshal(schema, &want)
|
||||
wantJSON, _ := json.Marshal(want)
|
||||
if string(got) != string(wantJSON) {
|
||||
t.Errorf("schema = %s, want %s", got, wantJSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputConfigOmittedWithoutSchema(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if _, ok := c.bodyMap(t)["output_config"]; ok {
|
||||
t.Error("output_config present, want omitted when Schema is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSamplingKnobs(t *testing.T) {
|
||||
t.Run("omitted when unset", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
body := c.bodyMap(t)
|
||||
if _, ok := body["temperature"]; ok {
|
||||
t.Error("temperature present, want omitted when unset")
|
||||
}
|
||||
if _, ok := body["top_p"]; ok {
|
||||
t.Error("top_p present, want omitted when unset")
|
||||
}
|
||||
if _, ok := body["stop_sequences"]; ok {
|
||||
t.Error("stop_sequences present, want omitted when unset")
|
||||
}
|
||||
})
|
||||
t.Run("present when set", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}},
|
||||
llm.WithTemperature(0), // explicit zero must still be sent
|
||||
llm.WithTopP(0.9),
|
||||
llm.WithStopSequences("END"))
|
||||
body := c.bodyMap(t)
|
||||
if got, ok := body["temperature"]; !ok || got.(float64) != 0 {
|
||||
t.Errorf("temperature = %v (present=%v), want explicit 0", got, ok)
|
||||
}
|
||||
if got := body["top_p"].(float64); got != 0.9 {
|
||||
t.Errorf("top_p = %v, want 0.9", got)
|
||||
}
|
||||
stops := body["stop_sequences"].([]any)
|
||||
if len(stops) != 1 || stops[0] != "END" {
|
||||
t.Errorf("stop_sequences = %v, want [END]", stops)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamFieldOmittedOnGenerate(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if _, ok := c.bodyMap(t)["stream"]; ok {
|
||||
t.Error("stream key present on Generate, want omitted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseParse(t *testing.T) {
|
||||
const body = `{
|
||||
"id": "msg_02",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "pondering...", "signature": "sig"},
|
||||
{"type": "text", "text": "I'll check the weather."},
|
||||
{"type": "tool_use", "id": "toolu_9", "name": "get_weather", "input": {"location": "Paris"}}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"usage": {
|
||||
"input_tokens": 3,
|
||||
"output_tokens": 7,
|
||||
"cache_creation_input_tokens": 10,
|
||||
"cache_read_input_tokens": 20
|
||||
}
|
||||
}`
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, body))
|
||||
resp := generate(t, mustModel(t, p, "claude-test"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
|
||||
if len(resp.Parts) != 1 {
|
||||
t.Fatalf("parts = %d, want 1 (thinking blocks must be skipped)", len(resp.Parts))
|
||||
}
|
||||
if got := resp.Text(); got != "I'll check the weather." {
|
||||
t.Errorf("text = %q", got)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("tool calls = %d, want 1", len(resp.ToolCalls))
|
||||
}
|
||||
call := resp.ToolCalls[0]
|
||||
if call.ID != "toolu_9" || call.Name != "get_weather" {
|
||||
t.Errorf("tool call = %+v", call)
|
||||
}
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal(call.Arguments, &args); err != nil || args["location"] != "Paris" {
|
||||
t.Errorf("arguments = %s (err %v), want location Paris", call.Arguments, err)
|
||||
}
|
||||
if resp.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("finish = %q, want %q", resp.FinishReason, llm.FinishToolCalls)
|
||||
}
|
||||
// Total real input = input + cache_creation + cache_read.
|
||||
if resp.Usage.InputTokens != 33 || resp.Usage.OutputTokens != 7 {
|
||||
t.Errorf("usage = %+v, want {33 7}", resp.Usage)
|
||||
}
|
||||
if resp.Model != "anthropic/claude-test" {
|
||||
t.Errorf("model = %q, want anthropic/claude-test", resp.Model)
|
||||
}
|
||||
if resp.Raw == nil {
|
||||
t.Error("Raw = nil, want wire response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopReasonMapping(t *testing.T) {
|
||||
cases := map[string]llm.FinishReason{
|
||||
"end_turn": llm.FinishStop,
|
||||
"stop_sequence": llm.FinishStop,
|
||||
"max_tokens": llm.FinishLength,
|
||||
"model_context_window_exceeded": llm.FinishLength,
|
||||
"tool_use": llm.FinishToolCalls,
|
||||
"refusal": llm.FinishContentFilter,
|
||||
"pause_turn": llm.FinishOther,
|
||||
"some_future_reason": llm.FinishOther,
|
||||
}
|
||||
for stop, want := range cases {
|
||||
if got := mapStopReason(stop); got != want {
|
||||
t.Errorf("mapStopReason(%q) = %q, want %q", stop, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPErrorMapping(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
status int
|
||||
body string
|
||||
wantCode string
|
||||
wantClass llm.ErrorClass
|
||||
}{
|
||||
{
|
||||
name: "429 rate limit is transient",
|
||||
status: http.StatusTooManyRequests,
|
||||
body: `{"type":"error","error":{"type":"rate_limit_error","message":"slow down"}}`,
|
||||
wantCode: "rate_limit_error", wantClass: llm.ClassTransient,
|
||||
},
|
||||
{
|
||||
name: "529 overloaded is transient",
|
||||
status: 529,
|
||||
body: `{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`,
|
||||
wantCode: "overloaded_error", wantClass: llm.ClassTransient,
|
||||
},
|
||||
{
|
||||
name: "401 auth is permanent",
|
||||
status: http.StatusUnauthorized,
|
||||
body: `{"type":"error","error":{"type":"authentication_error","message":"invalid x-api-key"}}`,
|
||||
wantCode: "authentication_error", wantClass: llm.ClassPermanent,
|
||||
},
|
||||
{
|
||||
name: "404 is permanent",
|
||||
status: http.StatusNotFound,
|
||||
body: `{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`,
|
||||
wantCode: "not_found_error", wantClass: llm.ClassPermanent,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(tc.status, tc.body))
|
||||
_, err := mustModel(t, p, "claude-test").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil {
|
||||
t.Fatal("Generate succeeded, want error")
|
||||
}
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T (%v), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Provider != "anthropic" || apiErr.Model != "claude-test" {
|
||||
t.Errorf("provider/model = %s/%s", apiErr.Provider, apiErr.Model)
|
||||
}
|
||||
if apiErr.Status != tc.status {
|
||||
t.Errorf("status = %d, want %d", apiErr.Status, tc.status)
|
||||
}
|
||||
if apiErr.Code != tc.wantCode {
|
||||
t.Errorf("code = %q, want %q", apiErr.Code, tc.wantCode)
|
||||
}
|
||||
if apiErr.Message == "" {
|
||||
t.Error("message empty, want provider message")
|
||||
}
|
||||
if got := llm.Classify(err); got != tc.wantClass {
|
||||
t.Errorf("Classify = %v, want %v", got, tc.wantClass)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("404 unwraps to ErrModelNotFound", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusNotFound,
|
||||
`{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`))
|
||||
_, err := mustModel(t, p, "missing").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if !errors.Is(err, llm.ErrModelNotFound) {
|
||||
t.Errorf("errors.Is(err, ErrModelNotFound) = false for %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-JSON error body falls back to raw text", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusBadGateway, "upstream exploded"))
|
||||
_, err := mustModel(t, p, "m").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T, want *llm.APIError", err)
|
||||
}
|
||||
if apiErr.Status != http.StatusBadGateway || apiErr.Message != "upstream exploded" {
|
||||
t.Errorf("apiErr = %+v", apiErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMissingAPIKey(t *testing.T) {
|
||||
t.Setenv("ANTHROPIC_API_KEY", "") // isolate from any real environment
|
||||
|
||||
var c capture
|
||||
srv := httptest.NewServer(c.handler(http.StatusOK, okBody))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := New(WithBaseURL(srv.URL)) // construction must not fail
|
||||
_, err := mustModel(t, p, "claude-test").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T (%v), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" {
|
||||
t.Errorf("apiErr = %+v, want 401 authentication_error", apiErr)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassPermanent {
|
||||
t.Error("missing key must classify permanent")
|
||||
}
|
||||
if c.hits != 0 {
|
||||
t.Errorf("server hits = %d, want 0 (no request without a key)", c.hits)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyFromEnv(t *testing.T) {
|
||||
t.Setenv("ANTHROPIC_API_KEY", "env-key")
|
||||
|
||||
var c capture
|
||||
srv := httptest.NewServer(c.handler(http.StatusOK, okBody))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := New(WithBaseURL(srv.URL))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if got := c.header.Get("x-api-key"); got != "env-key" {
|
||||
t.Errorf("x-api-key = %q, want env-key", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapabilityEnforcement(t *testing.T) {
|
||||
img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) }
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
caps *llm.Capabilities // nil = provider defaults
|
||||
req llm.Request
|
||||
}{
|
||||
{
|
||||
name: "images unsupported",
|
||||
caps: &llm.Capabilities{}, // MaxImagesPerReq 0 = no images
|
||||
req: llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 4))}},
|
||||
},
|
||||
{
|
||||
name: "too many images",
|
||||
caps: &llm.Capabilities{MaxImagesPerReq: 1},
|
||||
req: llm.Request{Messages: []llm.Message{
|
||||
llm.UserParts(img("image/png", 4), img("image/png", 4)),
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "disallowed MIME",
|
||||
req: llm.Request{Messages: []llm.Message{llm.UserParts(img("image/bmp", 4))}},
|
||||
},
|
||||
{
|
||||
name: "image too large",
|
||||
caps: &llm.Capabilities{MaxImagesPerReq: 1, MaxImageBytes: 2},
|
||||
req: llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 3))}},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
var opts []llm.ModelOption
|
||||
if tc.caps != nil {
|
||||
opts = append(opts, llm.WithCapabilities(*tc.caps))
|
||||
}
|
||||
m := mustModel(t, p, "claude-test", opts...)
|
||||
|
||||
_, err := m.Generate(context.Background(), tc.req)
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Errorf("Generate err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
_, err = m.Stream(context.Background(), tc.req)
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Errorf("Stream err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if c.hits != 0 {
|
||||
t.Errorf("server hits = %d, want 0 (rejected before sending)", c.hits)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("within limits passes", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{
|
||||
Messages: []llm.Message{llm.UserParts(llm.Text("ok"), img("image/jpeg", 16))},
|
||||
})
|
||||
if c.hits != 1 {
|
||||
t.Errorf("server hits = %d, want 1", c.hits)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCompatEndpointWithNameAndBaseURL(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithName("compat"))
|
||||
if p.Name() != "compat" {
|
||||
t.Errorf("Name() = %q, want compat", p.Name())
|
||||
}
|
||||
resp := generate(t, mustModel(t, p, "claude-test"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if resp.Model != "compat/claude-test" {
|
||||
t.Errorf("resp.Model = %q, want compat/claude-test", resp.Model)
|
||||
}
|
||||
|
||||
var ec capture
|
||||
pe := newTestProvider(t, ec.handler(http.StatusTooManyRequests,
|
||||
`{"type":"error","error":{"type":"rate_limit_error","message":"x"}}`), WithName("compat"))
|
||||
_, err := mustModel(t, pe, "m").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok || apiErr.Provider != "compat" {
|
||||
t.Errorf("error provider = %v, want compat (err %v)", apiErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapabilitiesDefaultsAndOverrides(t *testing.T) {
|
||||
p := New(WithAPIKey("k"))
|
||||
m := mustModel(t, p, "m")
|
||||
caps := m.Capabilities()
|
||||
if !caps.SupportsTools || !caps.SupportsStructured || !caps.SupportsStreaming {
|
||||
t.Errorf("default feature flags = %+v, want all true", caps)
|
||||
}
|
||||
if caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 10<<20 || caps.MaxImageDimension != 8000 {
|
||||
t.Errorf("default image limits = %+v", caps)
|
||||
}
|
||||
wantMIME := []string{"image/jpeg", "image/png", "image/gif", "image/webp"}
|
||||
if len(caps.AllowedImageMIME) != len(wantMIME) {
|
||||
t.Fatalf("AllowedImageMIME = %v, want %v", caps.AllowedImageMIME, wantMIME)
|
||||
}
|
||||
for i, mime := range wantMIME {
|
||||
if caps.AllowedImageMIME[i] != mime {
|
||||
t.Errorf("AllowedImageMIME[%d] = %q, want %q", i, caps.AllowedImageMIME[i], mime)
|
||||
}
|
||||
}
|
||||
|
||||
custom := llm.Capabilities{SupportsStreaming: true, MaxImagesPerReq: 1}
|
||||
p2 := New(WithAPIKey("k"), WithDefaultCapabilities(custom))
|
||||
if got := mustModel(t, p2, "m").Capabilities(); got.MaxImagesPerReq != 1 || got.SupportsTools {
|
||||
t.Errorf("WithDefaultCapabilities not applied: %+v", got)
|
||||
}
|
||||
|
||||
perModel := llm.Capabilities{SupportsTools: true}
|
||||
if got := mustModel(t, p2, "m", llm.WithCapabilities(perModel)).Capabilities(); !got.SupportsTools || got.MaxImagesPerReq != 0 {
|
||||
t.Errorf("per-model capabilities not applied: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportErrorNotAPIError(t *testing.T) {
|
||||
// Point at a server that is immediately closed: the connection failure
|
||||
// must surface as a wrapped transport error, not *llm.APIError.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
url := srv.URL
|
||||
srv.Close()
|
||||
|
||||
p := New(WithAPIKey("k"), WithBaseURL(url))
|
||||
_, err := mustModel(t, p, "m").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil {
|
||||
t.Fatal("Generate succeeded, want transport error")
|
||||
}
|
||||
if _, ok := errors.AsType[*llm.APIError](err); ok {
|
||||
t.Errorf("transport error wrapped in APIError: %v", err)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassTransient {
|
||||
t.Errorf("connection failure must classify transient: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// wireStreamEvent is the union of all SSE data payloads the Messages API
|
||||
// emits. Dispatch is on Type (the data always carries one), so the SSE
|
||||
// "event:" line is informational only.
|
||||
type wireStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Index int `json:"index"`
|
||||
|
||||
// message_start
|
||||
Message *struct {
|
||||
Usage wireUsage `json:"usage"`
|
||||
} `json:"message"`
|
||||
|
||||
// content_block_start
|
||||
ContentBlock *struct {
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"content_block"`
|
||||
|
||||
// content_block_delta / message_delta
|
||||
Delta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
PartialJSON string `json:"partial_json"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
} `json:"delta"`
|
||||
|
||||
// message_delta
|
||||
Usage *wireUsage `json:"usage"`
|
||||
|
||||
// error
|
||||
Error *struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// stream adapts the Messages API SSE stream to llm.Stream.
|
||||
//
|
||||
// Why single-threaded pull (no reader goroutine): Next is already the
|
||||
// consumer's pull point, so parsing lazily inside Next keeps cancellation,
|
||||
// buffering, and error propagation trivial — Close just closes the body and
|
||||
// the next read fails.
|
||||
type stream struct {
|
||||
provider string
|
||||
model string
|
||||
full string // provider/model
|
||||
body io.ReadCloser
|
||||
scanner *bufio.Scanner
|
||||
|
||||
// accumulated response
|
||||
parts []llm.Part
|
||||
toolCalls []llm.ToolCall
|
||||
usage llm.Usage
|
||||
finish llm.FinishReason
|
||||
|
||||
// current content block state
|
||||
blockType string
|
||||
textBuf strings.Builder
|
||||
toolID string
|
||||
toolName string
|
||||
argsBuf strings.Builder
|
||||
|
||||
done bool // final Response event emitted
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func newStream(m *model, body io.ReadCloser) *stream {
|
||||
sc := bufio.NewScanner(body)
|
||||
// Why a large limit: one SSE line carries one whole delta; default 64K
|
||||
// can be exceeded by large structured-output or tool-argument deltas.
|
||||
sc.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)
|
||||
return &stream{
|
||||
provider: m.provider.name,
|
||||
model: m.id,
|
||||
full: m.fullName(),
|
||||
body: body,
|
||||
scanner: sc,
|
||||
finish: llm.FinishOther,
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements llm.Stream. Safe to call at any time and more than once.
|
||||
func (s *stream) Close() error {
|
||||
s.closeOnce.Do(func() { s.closeErr = s.body.Close() })
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
// Next implements llm.Stream. It emits TextDelta fragments as they arrive,
|
||||
// fully-assembled ToolCalls at content_block_stop, exactly one final
|
||||
// Response event at message_stop, then io.EOF.
|
||||
func (s *stream) Next() (llm.StreamEvent, error) {
|
||||
if s.done {
|
||||
return llm.StreamEvent{}, io.EOF
|
||||
}
|
||||
for {
|
||||
data, err := s.nextData()
|
||||
if err != nil {
|
||||
return llm.StreamEvent{}, err
|
||||
}
|
||||
var ev wireStreamEvent
|
||||
if err := json.Unmarshal([]byte(data), &ev); err != nil {
|
||||
return llm.StreamEvent{}, fmt.Errorf("%s: decode stream event: %w", s.provider, err)
|
||||
}
|
||||
|
||||
switch ev.Type {
|
||||
case "message_start":
|
||||
if ev.Message != nil {
|
||||
s.usage = ev.Message.Usage.toUsage()
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
s.blockType = ""
|
||||
s.textBuf.Reset()
|
||||
s.argsBuf.Reset()
|
||||
if ev.ContentBlock != nil {
|
||||
s.blockType = ev.ContentBlock.Type
|
||||
if s.blockType == "tool_use" {
|
||||
s.toolID = ev.ContentBlock.ID
|
||||
s.toolName = ev.ContentBlock.Name
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
switch ev.Delta.Type {
|
||||
case "text_delta":
|
||||
s.textBuf.WriteString(ev.Delta.Text)
|
||||
return llm.StreamEvent{TextDelta: ev.Delta.Text}, nil
|
||||
case "input_json_delta":
|
||||
// Buffer partial JSON internally; consumers never see it.
|
||||
s.argsBuf.WriteString(ev.Delta.PartialJSON)
|
||||
default:
|
||||
// thinking_delta / signature_delta: tolerated, skipped.
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
if event, ok := s.finishBlock(); ok {
|
||||
return event, nil
|
||||
}
|
||||
|
||||
case "message_delta":
|
||||
if ev.Delta.StopReason != "" {
|
||||
s.finish = mapStopReason(ev.Delta.StopReason)
|
||||
}
|
||||
if ev.Usage != nil {
|
||||
// Output tokens arrive cumulatively in the final delta;
|
||||
// input tokens were reported in message_start.
|
||||
s.usage.OutputTokens = ev.Usage.OutputTokens
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
s.done = true
|
||||
return llm.StreamEvent{Response: &llm.Response{
|
||||
Parts: s.parts,
|
||||
ToolCalls: s.toolCalls,
|
||||
FinishReason: s.finish,
|
||||
Usage: s.usage,
|
||||
Model: s.full,
|
||||
}}, nil
|
||||
|
||||
case "error":
|
||||
// Mid-stream failure after the 200 (e.g. overloaded_error).
|
||||
// Status stays 0: there is no HTTP status for it, and the
|
||||
// default Classify treats it as transient, which fits overload.
|
||||
apiErr := &llm.APIError{Provider: s.provider, Model: s.model}
|
||||
if ev.Error != nil {
|
||||
apiErr.Code = ev.Error.Type
|
||||
apiErr.Message = ev.Error.Message
|
||||
}
|
||||
return llm.StreamEvent{}, apiErr
|
||||
|
||||
default:
|
||||
// ping and unknown event types: ignored.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// finishBlock closes out the current content block, appending its result to
|
||||
// the accumulated response. Tool-use blocks produce a stream event.
|
||||
func (s *stream) finishBlock() (llm.StreamEvent, bool) {
|
||||
defer func() {
|
||||
s.blockType = ""
|
||||
s.textBuf.Reset()
|
||||
s.argsBuf.Reset()
|
||||
}()
|
||||
switch s.blockType {
|
||||
case "text":
|
||||
if s.textBuf.Len() > 0 {
|
||||
s.parts = append(s.parts, llm.TextPart{Text: s.textBuf.String()})
|
||||
}
|
||||
case "tool_use":
|
||||
args := s.argsBuf.String()
|
||||
if args == "" {
|
||||
// A tool called with no arguments streams zero (or empty)
|
||||
// input_json_delta fragments; the canonical form is "{}".
|
||||
args = "{}"
|
||||
}
|
||||
call := llm.ToolCall{ID: s.toolID, Name: s.toolName, Arguments: json.RawMessage(args)}
|
||||
s.toolCalls = append(s.toolCalls, call)
|
||||
return llm.StreamEvent{ToolCall: &call}, true
|
||||
}
|
||||
return llm.StreamEvent{}, false
|
||||
}
|
||||
|
||||
// nextData reads SSE lines until one complete event's data is assembled
|
||||
// (multi-line data fields are joined with "\n" per the SSE spec). "event:"
|
||||
// lines and comments are ignored; dispatch keys off the JSON "type" field.
|
||||
func (s *stream) nextData() (string, error) {
|
||||
var data strings.Builder
|
||||
for s.scanner.Scan() {
|
||||
line := s.scanner.Text()
|
||||
if line == "" {
|
||||
if data.Len() > 0 {
|
||||
return data.String(), nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
if rest, ok := strings.CutPrefix(line, "data:"); ok {
|
||||
if data.Len() > 0 {
|
||||
data.WriteByte('\n')
|
||||
}
|
||||
data.WriteString(strings.TrimPrefix(rest, " "))
|
||||
}
|
||||
}
|
||||
if err := s.scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("%s: read stream: %w", s.provider, err)
|
||||
}
|
||||
if data.Len() > 0 {
|
||||
return data.String(), nil
|
||||
}
|
||||
// EOF before message_stop: the connection dropped mid-response.
|
||||
return "", fmt.Errorf("%s: stream ended before message_stop: %w", s.provider, io.ErrUnexpectedEOF)
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// sse joins data payloads into an SSE body. Each payload becomes one event
|
||||
// ("event:" name derived from the JSON type field is what the real API
|
||||
// sends, but the client dispatches on the data, so a generic name is fine).
|
||||
func sse(payloads ...string) string {
|
||||
var b strings.Builder
|
||||
for _, p := range payloads {
|
||||
b.WriteString("event: event\n")
|
||||
b.WriteString("data: ")
|
||||
b.WriteString(p)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func sseServer(t *testing.T, c *capture, body string) *Provider {
|
||||
t.Helper()
|
||||
return newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
c.mu.Lock()
|
||||
c.hits++
|
||||
c.header = r.Header.Clone()
|
||||
c.body = raw
|
||||
c.mu.Unlock()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
}
|
||||
|
||||
// drain collects all events until io.EOF, failing the test on any error.
|
||||
func drain(t *testing.T, s llm.Stream) []llm.StreamEvent {
|
||||
t.Helper()
|
||||
var events []llm.StreamEvent
|
||||
for {
|
||||
ev, err := s.Next()
|
||||
if err == io.EOF {
|
||||
return events
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Next: %v", err)
|
||||
}
|
||||
events = append(events, ev)
|
||||
}
|
||||
}
|
||||
|
||||
func openStream(t *testing.T, p *Provider, modelID string) llm.Stream {
|
||||
t.Helper()
|
||||
s, err := mustModel(t, p, modelID).Stream(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func TestStreamTextDeltas(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"m","usage":{"input_tokens":10,"cache_creation_input_tokens":2,"cache_read_input_tokens":3,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"ping"}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hel"}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"lo"}}`,
|
||||
`{"type":"content_block_stop","index":0}`,
|
||||
`{"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":" world"}}`,
|
||||
`{"type":"content_block_stop","index":1}`,
|
||||
`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":12}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
s := openStream(t, p, "claude-test")
|
||||
events := drain(t, s)
|
||||
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("events = %d, want 4 (3 deltas + final response)", len(events))
|
||||
}
|
||||
for i, want := range []string{"Hel", "lo", " world"} {
|
||||
if events[i].TextDelta != want {
|
||||
t.Errorf("event[%d].TextDelta = %q, want %q", i, events[i].TextDelta, want)
|
||||
}
|
||||
}
|
||||
|
||||
final := events[3].Response
|
||||
if final == nil {
|
||||
t.Fatal("last event has no Response")
|
||||
}
|
||||
if len(final.Parts) != 2 {
|
||||
t.Fatalf("final parts = %d, want 2 (one per text block)", len(final.Parts))
|
||||
}
|
||||
if final.Text() != "Hello world" {
|
||||
t.Errorf("final text = %q, want %q", final.Text(), "Hello world")
|
||||
}
|
||||
if final.FinishReason != llm.FinishStop {
|
||||
t.Errorf("finish = %q, want stop", final.FinishReason)
|
||||
}
|
||||
// Input = 10+2+3 from message_start; output = 12 from message_delta.
|
||||
if final.Usage.InputTokens != 15 || final.Usage.OutputTokens != 12 {
|
||||
t.Errorf("usage = %+v, want {15 12}", final.Usage)
|
||||
}
|
||||
if final.Model != "anthropic/claude-test" {
|
||||
t.Errorf("model = %q, want anthropic/claude-test", final.Model)
|
||||
}
|
||||
|
||||
// Past EOF, Next keeps returning io.EOF.
|
||||
if _, err := s.Next(); err != io.EOF {
|
||||
t.Errorf("Next after EOF = %v, want io.EOF", err)
|
||||
}
|
||||
|
||||
// The request must carry "stream": true.
|
||||
if streamFlag := c.bodyMap(t)["stream"]; streamFlag != true {
|
||||
t.Errorf("request stream = %v, want true", streamFlag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamToolCallAssembly(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":8,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Checking."}}`,
|
||||
`{"type":"content_block_stop","index":0}`,
|
||||
`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_9","name":"get_weather","input":{}}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":""}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\":"}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" \"San Francisco, CA\"}"}}`,
|
||||
`{"type":"content_block_stop","index":1}`,
|
||||
`{"type":"content_block_start","index":2,"content_block":{"type":"tool_use","id":"toolu_10","name":"noop","input":{}}}`,
|
||||
`{"type":"content_block_stop","index":2}`,
|
||||
`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":21}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
events := drain(t, openStream(t, p, "claude-test"))
|
||||
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("events = %d, want 4 (text, 2 tool calls, final)", len(events))
|
||||
}
|
||||
if events[0].TextDelta != "Checking." {
|
||||
t.Errorf("event[0] = %+v, want text delta", events[0])
|
||||
}
|
||||
|
||||
call := events[1].ToolCall
|
||||
if call == nil {
|
||||
t.Fatal("event[1] has no ToolCall")
|
||||
}
|
||||
if call.ID != "toolu_9" || call.Name != "get_weather" {
|
||||
t.Errorf("tool call = %+v", call)
|
||||
}
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal(call.Arguments, &args); err != nil {
|
||||
t.Fatalf("assembled arguments invalid JSON: %v (%s)", err, call.Arguments)
|
||||
}
|
||||
if args["location"] != "San Francisco, CA" {
|
||||
t.Errorf("arguments = %v", args)
|
||||
}
|
||||
|
||||
empty := events[2].ToolCall
|
||||
if empty == nil || empty.ID != "toolu_10" {
|
||||
t.Fatalf("event[2] = %+v, want second tool call", events[2])
|
||||
}
|
||||
if string(empty.Arguments) != "{}" {
|
||||
t.Errorf("empty tool call arguments = %s, want {}", empty.Arguments)
|
||||
}
|
||||
|
||||
final := events[3].Response
|
||||
if final == nil {
|
||||
t.Fatal("last event has no Response")
|
||||
}
|
||||
if len(final.ToolCalls) != 2 {
|
||||
t.Errorf("final tool calls = %d, want 2", len(final.ToolCalls))
|
||||
}
|
||||
if final.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("finish = %q, want tool_calls", final.FinishReason)
|
||||
}
|
||||
if final.Text() != "Checking." {
|
||||
t.Errorf("final text = %q", final.Text())
|
||||
}
|
||||
if final.Usage.InputTokens != 8 || final.Usage.OutputTokens != 21 {
|
||||
t.Errorf("usage = %+v, want {8 21}", final.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamThinkingSkipped(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":"hmm"}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":"sig"}}`,
|
||||
`{"type":"content_block_stop","index":0}`,
|
||||
`{"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":"hi"}}`,
|
||||
`{"type":"content_block_stop","index":1}`,
|
||||
`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":2}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
events := drain(t, openStream(t, p, "claude-test"))
|
||||
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("events = %d, want 2 (thinking produces none)", len(events))
|
||||
}
|
||||
if events[0].TextDelta != "hi" {
|
||||
t.Errorf("event[0] = %+v, want TextDelta hi", events[0])
|
||||
}
|
||||
final := events[1].Response
|
||||
if final == nil || len(final.Parts) != 1 || final.Text() != "hi" {
|
||||
t.Errorf("final = %+v, want single text part %q", final, "hi")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamMidStreamError(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"par"}}`,
|
||||
`{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
s := openStream(t, p, "claude-test")
|
||||
|
||||
ev, err := s.Next()
|
||||
if err != nil || ev.TextDelta != "par" {
|
||||
t.Fatalf("first Next = (%+v, %v), want text delta", ev, err)
|
||||
}
|
||||
_, err = s.Next()
|
||||
if err == nil {
|
||||
t.Fatal("second Next succeeded, want mid-stream error")
|
||||
}
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T (%v), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Code != "overloaded_error" || apiErr.Message != "Overloaded" || apiErr.Status != 0 {
|
||||
t.Errorf("apiErr = %+v", apiErr)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassTransient {
|
||||
t.Error("overloaded_error must classify transient")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHTTPErrorBeforeEvents(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(529,
|
||||
`{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`))
|
||||
_, err := mustModel(t, p, "claude-test").Stream(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil {
|
||||
t.Fatal("Stream succeeded, want APIError before any events")
|
||||
}
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T (%v), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != 529 || apiErr.Code != "overloaded_error" {
|
||||
t.Errorf("apiErr = %+v, want 529 overloaded_error", apiErr)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassTransient {
|
||||
t.Error("529 must classify transient")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTruncatedBody(t *testing.T) {
|
||||
// Stream ends without message_stop: Next must surface unexpected EOF.
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
s := openStream(t, p, "claude-test")
|
||||
|
||||
if ev, err := s.Next(); err != nil || ev.TextDelta != "hi" {
|
||||
t.Fatalf("first Next = (%+v, %v)", ev, err)
|
||||
}
|
||||
if _, err := s.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Errorf("Next on truncated stream = %v, want io.ErrUnexpectedEOF", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamCloseIsSafe(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}`,
|
||||
`{"type":"content_block_stop","index":0}`,
|
||||
`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":2}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
s := openStream(t, p, "claude-test")
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("first Close: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("second Close: %v", err)
|
||||
}
|
||||
|
||||
// After EOF, Close is still fine.
|
||||
s2 := openStream(t, p, "claude-test")
|
||||
drain(t, s2)
|
||||
if err := s2.Close(); err != nil {
|
||||
t.Errorf("Close after EOF: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// Wire types mirror the Messages API JSON shapes (June 2026 docs). Only the
|
||||
// fields majordomo uses are modeled; unknown response fields are ignored by
|
||||
// encoding/json.
|
||||
|
||||
type wireRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []wireMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []wireTool `json:"tools,omitempty"`
|
||||
ToolChoice *wireToolChoice `json:"tool_choice,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
OutputConfig *wireOutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
type wireMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []wireBlock `json:"content"`
|
||||
}
|
||||
|
||||
// wireBlock is a request-side content block. Exactly one shape is populated
|
||||
// per block, keyed by Type: text, image, tool_use, or tool_result.
|
||||
type wireBlock struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// text
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// image
|
||||
Source *wireImageSource `json:"source,omitempty"`
|
||||
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
|
||||
// tool_result
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
type wireImageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type wireTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"`
|
||||
}
|
||||
|
||||
type wireToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type wireOutputConfig struct {
|
||||
Format *wireOutputFormat `json:"format,omitempty"`
|
||||
}
|
||||
|
||||
type wireOutputFormat struct {
|
||||
Type string `json:"type"`
|
||||
Schema json.RawMessage `json:"schema"`
|
||||
}
|
||||
|
||||
type wireResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Content []wireRespBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Usage wireUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type wireRespBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Input json.RawMessage `json:"input"`
|
||||
}
|
||||
|
||||
type wireUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
}
|
||||
|
||||
// toUsage maps API token accounting onto the canonical Usage. Why the sum:
|
||||
// the API's input_tokens counts only tokens after the last cache breakpoint;
|
||||
// real total input is input + cache_creation + cache_read.
|
||||
func (u wireUsage) toUsage() llm.Usage {
|
||||
return llm.Usage{
|
||||
InputTokens: u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens,
|
||||
OutputTokens: u.OutputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
type wireErrorEnvelope struct {
|
||||
Type string `json:"type"`
|
||||
Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// buildWireRequest translates the canonical request into the Messages API
|
||||
// shape.
|
||||
//
|
||||
// Request.ReasoningEffort is intentionally ignored: the current Messages API
|
||||
// has no low/medium/high reasoning knob — thinking is adaptive on current
|
||||
// models, and the legacy budget/disable parameters 400 on them. The llm
|
||||
// contract says providers ignore ReasoningEffort where no mapping exists.
|
||||
//
|
||||
// Request.SchemaName is likewise ignored: output_config.format takes a bare
|
||||
// schema with no name field.
|
||||
func buildWireRequest(modelID string, req llm.Request, defaultMax int, stream bool) wireRequest {
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
// max_tokens is required by the API; 0 means "provider default".
|
||||
maxTokens = defaultMax
|
||||
}
|
||||
|
||||
wr := wireRequest{
|
||||
Model: modelID,
|
||||
MaxTokens: maxTokens,
|
||||
System: foldSystem(req),
|
||||
Messages: toWireMessages(req.Messages),
|
||||
Stream: stream,
|
||||
Tools: toWireTools(req.Tools),
|
||||
ToolChoice: toWireToolChoice(req.ToolChoice),
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
StopSequences: req.StopSequences,
|
||||
}
|
||||
if req.Schema != nil {
|
||||
wr.OutputConfig = &wireOutputConfig{Format: &wireOutputFormat{
|
||||
Type: "json_schema",
|
||||
Schema: req.Schema,
|
||||
}}
|
||||
}
|
||||
return wr
|
||||
}
|
||||
|
||||
// foldSystem joins Request.System with the text of every RoleSystem message
|
||||
// (System field first, original order, "\n\n" separators). Why: the API
|
||||
// takes the system prompt as a top-level field and rejects system roles
|
||||
// inside messages, so canonical RoleSystem messages must fold in here.
|
||||
func foldSystem(req llm.Request) string {
|
||||
parts := make([]string, 0, 2)
|
||||
if req.System != "" {
|
||||
parts = append(parts, req.System)
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
if msg.Role != llm.RoleSystem {
|
||||
continue
|
||||
}
|
||||
if text := msg.Text(); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
func toWireMessages(msgs []llm.Message) []wireMessage {
|
||||
out := make([]wireMessage, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
switch msg.Role {
|
||||
case llm.RoleSystem:
|
||||
// Folded into the top-level system field by foldSystem.
|
||||
continue
|
||||
|
||||
case llm.RoleTool:
|
||||
// One user message carrying one tool_result block per result.
|
||||
blocks := make([]wireBlock, 0, len(msg.ToolResults))
|
||||
for _, res := range msg.ToolResults {
|
||||
blocks = append(blocks, wireBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: res.ID,
|
||||
Content: res.Content,
|
||||
IsError: res.IsError,
|
||||
})
|
||||
}
|
||||
out = append(out, wireMessage{Role: "user", Content: blocks})
|
||||
|
||||
case llm.RoleAssistant:
|
||||
blocks := toWireBlocks(msg.Parts)
|
||||
for _, call := range msg.ToolCalls {
|
||||
args := call.Arguments
|
||||
if len(args) == 0 {
|
||||
// The API requires input to be a JSON object.
|
||||
args = json.RawMessage("{}")
|
||||
}
|
||||
blocks = append(blocks, wireBlock{
|
||||
Type: "tool_use",
|
||||
ID: call.ID,
|
||||
Name: call.Name,
|
||||
Input: args,
|
||||
})
|
||||
}
|
||||
out = append(out, wireMessage{Role: "assistant", Content: blocks})
|
||||
|
||||
default: // llm.RoleUser and anything unrecognized
|
||||
out = append(out, wireMessage{Role: "user", Content: toWireBlocks(msg.Parts)})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toWireBlocks(parts []llm.Part) []wireBlock {
|
||||
blocks := make([]wireBlock, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
switch p := part.(type) {
|
||||
case llm.TextPart:
|
||||
blocks = append(blocks, wireBlock{Type: "text", Text: p.Text})
|
||||
case llm.ImagePart:
|
||||
blocks = append(blocks, wireBlock{Type: "image", Source: &wireImageSource{
|
||||
Type: "base64",
|
||||
MediaType: p.MIME,
|
||||
Data: base64.StdEncoding.EncodeToString(p.Data),
|
||||
}})
|
||||
}
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
func toWireTools(tools []llm.Tool) []wireTool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]wireTool, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
schema := t.Parameters
|
||||
if len(schema) == 0 {
|
||||
// Why: input_schema is required by the API; a tool with no
|
||||
// arguments still needs an (empty) object schema.
|
||||
schema = json.RawMessage(`{"type":"object","properties":{}}`)
|
||||
}
|
||||
out = append(out, wireTool{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// toWireToolChoice maps the canonical tool-choice policy. "" omits the field
|
||||
// (API default is auto); any value other than the three keywords names the
|
||||
// one tool the model must call.
|
||||
func toWireToolChoice(choice string) *wireToolChoice {
|
||||
switch choice {
|
||||
case "":
|
||||
return nil
|
||||
case "auto":
|
||||
return &wireToolChoice{Type: "auto"}
|
||||
case "required":
|
||||
return &wireToolChoice{Type: "any"}
|
||||
case "none":
|
||||
return &wireToolChoice{Type: "none"}
|
||||
default:
|
||||
return &wireToolChoice{Type: "tool", Name: choice}
|
||||
}
|
||||
}
|
||||
|
||||
// mapStopReason maps the API stop_reason onto the canonical FinishReason.
|
||||
func mapStopReason(stop string) llm.FinishReason {
|
||||
switch stop {
|
||||
case "end_turn", "stop_sequence":
|
||||
return llm.FinishStop
|
||||
case "max_tokens", "model_context_window_exceeded":
|
||||
return llm.FinishLength
|
||||
case "tool_use":
|
||||
return llm.FinishToolCalls
|
||||
case "refusal":
|
||||
return llm.FinishContentFilter
|
||||
default:
|
||||
// pause_turn and any future provider-specific reasons.
|
||||
return llm.FinishOther
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
// Package ollama implements majordomo's provider contract over Ollama's
|
||||
// native chat API (POST {base}/api/chat), targeted at three backends that
|
||||
// share one wire protocol:
|
||||
//
|
||||
// - a local Ollama instance (preset Local: OLLAMA_HOST or
|
||||
// http://localhost:11434, no auth),
|
||||
// - Ollama Cloud (preset Cloud: https://ollama.com, bearer key from
|
||||
// OLLAMA_API_KEY), and
|
||||
// - foreman, Steve's native-Ollama queue daemon (preset Foreman: explicit
|
||||
// base URL + bearer token).
|
||||
//
|
||||
// Wire surface verified against docs.ollama.com and ollama/ollama
|
||||
// docs/api.md + api/types.go (June 2026): NDJSON streaming (stream defaults
|
||||
// true server-side — Generate always sends stream:false explicitly);
|
||||
// tool_calls carry arguments as a JSON OBJECT (not a string); tool results
|
||||
// return as {"role":"tool","content",...,"tool_name"}; structured output
|
||||
// via "format" (a full JSON-schema object); thinking via the bool-or-string
|
||||
// "think" field; errors as {"error":"message"} with a non-2xx status.
|
||||
//
|
||||
// foreman deviation (verified in its source): sync /api/chat does not
|
||||
// stream — a stream:true request yields ONE buffered application/json
|
||||
// object. The NDJSON reader here handles that transparently (a single JSON
|
||||
// line parses as the final chunk).
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// DefaultLocalBaseURL is the default base URL for a locally-running Ollama.
|
||||
const DefaultLocalBaseURL = "http://localhost:11434"
|
||||
|
||||
// DefaultCloudBaseURL is the base URL for Ollama Cloud.
|
||||
const DefaultCloudBaseURL = "https://ollama.com"
|
||||
|
||||
// defaultCapabilities is the conservative provider-wide default; individual
|
||||
// models (e.g. high-resolution vision tags) override via llm.WithCapabilities.
|
||||
var defaultCapabilities = llm.Capabilities{
|
||||
SupportsTools: true,
|
||||
SupportsStructured: true,
|
||||
SupportsStreaming: true,
|
||||
MaxImagesPerReq: 8,
|
||||
MaxImageBytes: 20 << 20,
|
||||
MaxImageDimension: 2048,
|
||||
AllowedImageMIME: []string{"image/jpeg", "image/png"},
|
||||
}
|
||||
|
||||
// Provider is a native-Ollama chat client bound to one base URL.
|
||||
type Provider struct {
|
||||
name string
|
||||
baseURL string
|
||||
token string
|
||||
client *http.Client
|
||||
caps llm.Capabilities
|
||||
}
|
||||
|
||||
// Option configures the provider.
|
||||
type Option func(*Provider)
|
||||
|
||||
// WithName overrides the registry name (default "ollama").
|
||||
func WithName(name string) Option { return func(p *Provider) { p.name = name } }
|
||||
|
||||
// WithBaseURL sets the backend base URL (scheme://host[:port][/path]).
|
||||
func WithBaseURL(u string) Option {
|
||||
return func(p *Provider) { p.baseURL = strings.TrimRight(u, "/") }
|
||||
}
|
||||
|
||||
// WithToken sets the bearer token (Ollama Cloud key / foreman token).
|
||||
// Empty means no Authorization header (local mode).
|
||||
func WithToken(token string) Option { return func(p *Provider) { p.token = token } }
|
||||
|
||||
// WithHTTPClient overrides the HTTP client (proxies, test TLS, timeouts —
|
||||
// note foreman sync chat long-polls; prefer context deadlines over client
|
||||
// timeouts).
|
||||
func WithHTTPClient(c *http.Client) Option { return func(p *Provider) { p.client = c } }
|
||||
|
||||
// WithDefaultCapabilities overrides the provider-wide default capabilities.
|
||||
func WithDefaultCapabilities(caps llm.Capabilities) Option {
|
||||
return func(p *Provider) { p.caps = caps }
|
||||
}
|
||||
|
||||
// New creates a generic native-Ollama provider. Most callers want one of
|
||||
// the presets (Local, Cloud, Foreman) or an LLM_* env DSN instead.
|
||||
// Construction never fails; a missing base URL surfaces at request time.
|
||||
func New(opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
name: "ollama",
|
||||
client: &http.Client{},
|
||||
caps: defaultCapabilities,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Local returns the local-Ollama preset: name "ollama", base URL from
|
||||
// OLLAMA_HOST (normalized per Ollama conventions) or localhost:11434.
|
||||
func Local(opts ...Option) *Provider {
|
||||
base := DefaultLocalBaseURL
|
||||
if h := os.Getenv("OLLAMA_HOST"); h != "" {
|
||||
base = NormalizeHost(h)
|
||||
}
|
||||
return New(append([]Option{WithBaseURL(base)}, opts...)...)
|
||||
}
|
||||
|
||||
// Cloud returns the Ollama Cloud preset: name "ollama-cloud",
|
||||
// https://ollama.com, bearer key from OLLAMA_API_KEY.
|
||||
func Cloud(opts ...Option) *Provider {
|
||||
return New(append([]Option{
|
||||
WithName("ollama-cloud"),
|
||||
WithBaseURL(DefaultCloudBaseURL),
|
||||
WithToken(os.Getenv("OLLAMA_API_KEY")),
|
||||
}, opts...)...)
|
||||
}
|
||||
|
||||
// Foreman returns a foreman preset bound to the given daemon.
|
||||
func Foreman(baseURL, token string, opts ...Option) *Provider {
|
||||
return New(append([]Option{
|
||||
WithName("foreman"),
|
||||
WithBaseURL(baseURL),
|
||||
WithToken(token),
|
||||
}, opts...)...)
|
||||
}
|
||||
|
||||
// NormalizeHost turns an OLLAMA_HOST-style value into a base URL:
|
||||
// "host" → http://host:11434, "host:port" → http://host:port, full URLs
|
||||
// pass through (trailing slash trimmed).
|
||||
func NormalizeHost(h string) string {
|
||||
h = strings.TrimRight(strings.TrimSpace(h), "/")
|
||||
if strings.Contains(h, "://") {
|
||||
return h
|
||||
}
|
||||
if !strings.Contains(h, ":") {
|
||||
h += ":11434"
|
||||
}
|
||||
return "http://" + h
|
||||
}
|
||||
|
||||
// Name implements llm.Provider.
|
||||
func (p *Provider) Name() string { return p.name }
|
||||
|
||||
// BaseURL reports the configured backend base URL (diagnostics).
|
||||
func (p *Provider) BaseURL() string { return p.baseURL }
|
||||
|
||||
// Model implements llm.Provider; the id passes through verbatim.
|
||||
func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) {
|
||||
cfg := llm.ApplyModelOptions(opts)
|
||||
caps := p.caps
|
||||
if cfg.Capabilities != nil {
|
||||
caps = *cfg.Capabilities
|
||||
}
|
||||
return &model{provider: p, id: id, caps: caps}, nil
|
||||
}
|
||||
|
||||
// checkReady reports a usable configuration (a base URL is the only hard
|
||||
// requirement; auth problems surface as 401s from the backend).
|
||||
func (p *Provider) checkReady() error {
|
||||
if p.baseURL == "" {
|
||||
return fmt.Errorf("ollama provider %q: no base URL configured (set one via the preset, WithBaseURL, or an LLM_* env DSN)", p.name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,492 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// capture spins up an httptest server that records the request and replies
|
||||
// with the given handler.
|
||||
type captured struct {
|
||||
auth string
|
||||
contentType string
|
||||
path string
|
||||
body map[string]any
|
||||
raw []byte
|
||||
}
|
||||
|
||||
func serve(t *testing.T, status int, respond func(w http.ResponseWriter)) (*Provider, *captured) {
|
||||
t.Helper()
|
||||
cap := &captured{}
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cap.auth = r.Header.Get("Authorization")
|
||||
cap.contentType = r.Header.Get("Content-Type")
|
||||
cap.path = r.URL.Path
|
||||
cap.raw, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(cap.raw, &cap.body)
|
||||
w.WriteHeader(status)
|
||||
respond(w)
|
||||
}))
|
||||
t.Cleanup(ts.Close)
|
||||
return New(WithBaseURL(ts.URL), WithToken("test-token")), cap
|
||||
}
|
||||
|
||||
func jsonReply(obj string) func(w http.ResponseWriter) {
|
||||
return func(w http.ResponseWriter) { _, _ = io.WriteString(w, obj) }
|
||||
}
|
||||
|
||||
func basicRequest() llm.Request {
|
||||
return llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
|
||||
}
|
||||
|
||||
func TestGenerateRoundTrip(t *testing.T) {
|
||||
p, cap := serve(t, 200, jsonReply(`{
|
||||
"model":"qwen3:30b",
|
||||
"message":{"role":"assistant","content":"hello there"},
|
||||
"done":true,"done_reason":"stop",
|
||||
"prompt_eval_count":12,"eval_count":7
|
||||
}`))
|
||||
|
||||
m, _ := p.Model("qwen3:30b")
|
||||
temp := 0.2
|
||||
resp, err := m.Generate(context.Background(), llm.Request{
|
||||
System: "be terse",
|
||||
Messages: []llm.Message{llm.SystemText("extra sys"), llm.UserText("hi")},
|
||||
Temperature: &temp,
|
||||
MaxTokens: 64,
|
||||
StopSequences: []string{"END"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
|
||||
// Wire assertions.
|
||||
if cap.path != "/api/chat" {
|
||||
t.Errorf("path = %q", cap.path)
|
||||
}
|
||||
if cap.auth != "Bearer test-token" {
|
||||
t.Errorf("auth = %q", cap.auth)
|
||||
}
|
||||
if cap.body["model"] != "qwen3:30b" {
|
||||
t.Errorf("model = %v", cap.body["model"])
|
||||
}
|
||||
if stream, ok := cap.body["stream"].(bool); !ok || stream {
|
||||
t.Errorf("stream must be explicit false, got %v", cap.body["stream"])
|
||||
}
|
||||
msgs := cap.body["messages"].([]any)
|
||||
first := msgs[0].(map[string]any)
|
||||
if first["role"] != "system" || first["content"] != "be terse\n\nextra sys" {
|
||||
t.Errorf("system fold = %v", first)
|
||||
}
|
||||
second := msgs[1].(map[string]any)
|
||||
if second["role"] != "user" || second["content"] != "hi" {
|
||||
t.Errorf("user msg = %v", second)
|
||||
}
|
||||
opts := cap.body["options"].(map[string]any)
|
||||
if opts["temperature"] != 0.2 || opts["num_predict"] != float64(64) {
|
||||
t.Errorf("options = %v", opts)
|
||||
}
|
||||
|
||||
// Response assertions.
|
||||
if resp.Text() != "hello there" {
|
||||
t.Errorf("text = %q", resp.Text())
|
||||
}
|
||||
if resp.FinishReason != llm.FinishStop {
|
||||
t.Errorf("finish = %v", resp.FinishReason)
|
||||
}
|
||||
if resp.Usage.InputTokens != 12 || resp.Usage.OutputTokens != 7 {
|
||||
t.Errorf("usage = %+v", resp.Usage)
|
||||
}
|
||||
if resp.Model != "ollama/qwen3:30b" {
|
||||
t.Errorf("resp.Model = %q", resp.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImagesEncodeAsBase64(t *testing.T) {
|
||||
p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"a cat"},"done":true,"done_reason":"stop"}`))
|
||||
imgBytes := []byte{0xFF, 0xD8, 0xFF, 0xE0, 1, 2, 3}
|
||||
|
||||
m, _ := p.Model("llava")
|
||||
_, err := m.Generate(context.Background(), llm.Request{
|
||||
Messages: []llm.Message{llm.UserParts(llm.Text("describe"), llm.Image("image/jpeg", imgBytes))},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
msgs := cap.body["messages"].([]any)
|
||||
user := msgs[0].(map[string]any)
|
||||
images := user["images"].([]any)
|
||||
if len(images) != 1 || images[0] != base64.StdEncoding.EncodeToString(imgBytes) {
|
||||
t.Errorf("images = %v", images)
|
||||
}
|
||||
if strings.Contains(images[0].(string), "data:") {
|
||||
t.Error("images must be raw base64 without data: prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolsAndToolCallRoundTrip(t *testing.T) {
|
||||
p, cap := serve(t, 200, jsonReply(`{
|
||||
"message":{"role":"assistant","content":"","tool_calls":[
|
||||
{"function":{"index":0,"name":"get_weather","arguments":{"city":"Tokyo"}}}
|
||||
]},
|
||||
"done":true,"done_reason":"stop"
|
||||
}`))
|
||||
|
||||
tool := llm.Tool{
|
||||
Name: "get_weather", Description: "weather",
|
||||
Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}`),
|
||||
}
|
||||
m, _ := p.Model("qwen3")
|
||||
resp, err := m.Generate(context.Background(), basicRequest(), llm.WithTools(tool))
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
|
||||
// Tools serialize with parameters as an object.
|
||||
tools := cap.body["tools"].([]any)
|
||||
fn := tools[0].(map[string]any)["function"].(map[string]any)
|
||||
if fn["name"] != "get_weather" {
|
||||
t.Errorf("tool fn = %v", fn)
|
||||
}
|
||||
if _, ok := fn["parameters"].(map[string]any); !ok {
|
||||
t.Errorf("parameters must be an object, got %T", fn["parameters"])
|
||||
}
|
||||
|
||||
// Tool call comes back with arguments as a JSON object → RawMessage.
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("tool calls = %v", resp.ToolCalls)
|
||||
}
|
||||
tc := resp.ToolCalls[0]
|
||||
if tc.Name != "get_weather" || tc.ID == "" {
|
||||
t.Errorf("call = %+v (id must be synthesized)", tc)
|
||||
}
|
||||
var args struct {
|
||||
City string `json:"city"`
|
||||
}
|
||||
if err := json.Unmarshal(tc.Arguments, &args); err != nil || args.City != "Tokyo" {
|
||||
t.Errorf("arguments = %s (%v)", tc.Arguments, err)
|
||||
}
|
||||
if resp.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("finish = %v, want tool_calls", resp.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultsAndHistoryToolCalls(t *testing.T) {
|
||||
p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"21C"},"done":true,"done_reason":"stop"}`))
|
||||
|
||||
m, _ := p.Model("qwen3")
|
||||
_, err := m.Generate(context.Background(), llm.Request{
|
||||
Messages: []llm.Message{
|
||||
llm.UserText("weather?"),
|
||||
{Role: llm.RoleAssistant, ToolCalls: []llm.ToolCall{
|
||||
{ID: "call_0", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Tokyo"}`)},
|
||||
}},
|
||||
llm.ToolResultsMessage(
|
||||
llm.ToolResult{ID: "call_0", Name: "get_weather", Content: `{"temp":21}`},
|
||||
llm.ToolResult{ID: "call_1", Name: "broken_tool", Content: "boom", IsError: true},
|
||||
),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
msgs := cap.body["messages"].([]any)
|
||||
if len(msgs) != 4 {
|
||||
t.Fatalf("messages = %d, want 4 (user, assistant, 2 tool results)", len(msgs))
|
||||
}
|
||||
asst := msgs[1].(map[string]any)
|
||||
calls := asst["tool_calls"].([]any)
|
||||
args := calls[0].(map[string]any)["function"].(map[string]any)["arguments"]
|
||||
if _, ok := args.(map[string]any); !ok {
|
||||
t.Errorf("history tool-call arguments must be an object, got %T", args)
|
||||
}
|
||||
tr1 := msgs[2].(map[string]any)
|
||||
if tr1["role"] != "tool" || tr1["tool_name"] != "get_weather" || tr1["content"] != `{"temp":21}` {
|
||||
t.Errorf("tool result 1 = %v", tr1)
|
||||
}
|
||||
tr2 := msgs[3].(map[string]any)
|
||||
if tr2["content"] != "ERROR: boom" {
|
||||
t.Errorf("error result content = %v", tr2["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructuredOutputFormat(t *testing.T) {
|
||||
p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"{\"name\":\"Ada\"}"},"done":true,"done_reason":"stop"}`))
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"]}`)
|
||||
|
||||
m, _ := p.Model("qwen3")
|
||||
resp, err := m.Generate(context.Background(), basicRequest(), llm.WithSchema(schema, "person"))
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
format, ok := cap.body["format"].(map[string]any)
|
||||
if !ok || format["type"] != "object" {
|
||||
t.Errorf("format = %v, want the schema object", cap.body["format"])
|
||||
}
|
||||
if resp.Text() != `{"name":"Ada"}` {
|
||||
t.Errorf("text = %q", resp.Text())
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkMapping(t *testing.T) {
|
||||
p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"ok"},"done":true,"done_reason":"stop"}`))
|
||||
m, _ := p.Model("gpt-oss:120b")
|
||||
_, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("high"))
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if cap.body["think"] != "high" {
|
||||
t.Errorf("think = %v", cap.body["think"])
|
||||
}
|
||||
|
||||
if _, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("max")); err == nil {
|
||||
t.Error("invalid reasoning effort should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolChoiceNoneDropsTools(t *testing.T) {
|
||||
p, cap := serve(t, 200, jsonReply(`{"message":{"role":"assistant","content":"ok"},"done":true,"done_reason":"stop"}`))
|
||||
m, _ := p.Model("qwen3")
|
||||
_, err := m.Generate(context.Background(), basicRequest(),
|
||||
llm.WithTools(llm.Tool{Name: "t"}), llm.WithToolChoice("none"))
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if _, present := cap.body["tools"]; present {
|
||||
t.Error("tool_choice none must omit tools")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingNDJSON(t *testing.T) {
|
||||
p, _ := serve(t, 200, func(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
_, _ = io.WriteString(w, `{"message":{"role":"assistant","content":"Hel"},"done":false}
|
||||
{"message":{"role":"assistant","content":"lo"},"done":false}
|
||||
{"message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"ping","arguments":{}}}]},"done":false}
|
||||
{"message":{"role":"assistant","content":""},"done":true,"done_reason":"stop","prompt_eval_count":5,"eval_count":9}
|
||||
`)
|
||||
})
|
||||
|
||||
m, _ := p.Model("qwen3")
|
||||
s, err := m.Stream(context.Background(), basicRequest())
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
var text strings.Builder
|
||||
var toolCalls []llm.ToolCall
|
||||
var final *llm.Response
|
||||
for {
|
||||
ev, err := s.Next()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Next: %v", err)
|
||||
}
|
||||
text.WriteString(ev.TextDelta)
|
||||
if ev.ToolCall != nil {
|
||||
toolCalls = append(toolCalls, *ev.ToolCall)
|
||||
}
|
||||
if ev.Response != nil {
|
||||
final = ev.Response
|
||||
}
|
||||
}
|
||||
if text.String() != "Hello" {
|
||||
t.Errorf("text = %q", text.String())
|
||||
}
|
||||
if len(toolCalls) != 1 || toolCalls[0].Name != "ping" {
|
||||
t.Errorf("tool calls = %+v", toolCalls)
|
||||
}
|
||||
if final == nil {
|
||||
t.Fatal("no final response event")
|
||||
}
|
||||
if final.Usage.InputTokens != 5 || final.Usage.OutputTokens != 9 {
|
||||
t.Errorf("final usage = %+v", final.Usage)
|
||||
}
|
||||
if final.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("final finish = %v", final.FinishReason)
|
||||
}
|
||||
if final.Text() != "Hello" {
|
||||
t.Errorf("final text = %q", final.Text())
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamingForemanSingleObject: foreman returns one buffered JSON
|
||||
// object to a stream:true request; the stream must still deliver the text
|
||||
// and a final response.
|
||||
func TestStreamingForemanSingleObject(t *testing.T) {
|
||||
p, cap := serve(t, 200, func(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"message":{"role":"assistant","content":"queued answer"},"done":true,"done_reason":"stop","prompt_eval_count":3,"eval_count":4}`)
|
||||
})
|
||||
|
||||
m, _ := p.Model("qwen3:30b")
|
||||
s, err := m.Stream(context.Background(), basicRequest())
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
if stream, ok := cap.body["stream"].(bool); !ok || !stream {
|
||||
t.Errorf("stream flag = %v, want true", cap.body["stream"])
|
||||
}
|
||||
|
||||
var text strings.Builder
|
||||
var final *llm.Response
|
||||
for {
|
||||
ev, err := s.Next()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Next: %v", err)
|
||||
}
|
||||
text.WriteString(ev.TextDelta)
|
||||
if ev.Response != nil {
|
||||
final = ev.Response
|
||||
}
|
||||
}
|
||||
if text.String() != "queued answer" || final == nil || final.Usage.OutputTokens != 4 {
|
||||
t.Errorf("text=%q final=%+v", text.String(), final)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorMapping(t *testing.T) {
|
||||
t.Run("404 is model-not-found", func(t *testing.T) {
|
||||
p, _ := serve(t, 404, jsonReply(`{"error":"model not found"}`))
|
||||
m, _ := p.Model("nope")
|
||||
_, err := m.Generate(context.Background(), basicRequest())
|
||||
if !errors.Is(err, llm.ErrModelNotFound) {
|
||||
t.Errorf("error = %v, want ErrModelNotFound", err)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassPermanent {
|
||||
t.Error("404 must classify permanent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("503 transient with message", func(t *testing.T) {
|
||||
p, _ := serve(t, 503, jsonReply(`{"error":"request cancelled while waiting"}`))
|
||||
m, _ := p.Model("qwen3")
|
||||
_, err := m.Generate(context.Background(), basicRequest())
|
||||
var apiErr *llm.APIError
|
||||
if !errors.As(err, &apiErr) || apiErr.Status != 503 || !strings.Contains(apiErr.Message, "cancelled") {
|
||||
t.Errorf("error = %v", err)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassTransient {
|
||||
t.Error("503 must classify transient")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-JSON error body", func(t *testing.T) {
|
||||
p, _ := serve(t, 500, jsonReply(`upstream exploded`))
|
||||
m, _ := p.Model("qwen3")
|
||||
_, err := m.Generate(context.Background(), basicRequest())
|
||||
var apiErr *llm.APIError
|
||||
if !errors.As(err, &apiErr) || !strings.Contains(apiErr.Message, "upstream exploded") {
|
||||
t.Errorf("error = %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCapabilityEnforcement(t *testing.T) {
|
||||
p, _ := serve(t, 200, jsonReply(`{"message":{"content":"x"},"done":true}`))
|
||||
|
||||
t.Run("too many images", func(t *testing.T) {
|
||||
m, _ := p.Model("llava", llm.WithCapabilities(llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/png"}}))
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{
|
||||
llm.UserParts(llm.Image("image/png", []byte{1}), llm.Image("image/png", []byte{2})),
|
||||
}})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Errorf("error = %v, want ErrUnsupported", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("images on text-only model", func(t *testing.T) {
|
||||
m, _ := p.Model("qwen3", llm.WithCapabilities(llm.Capabilities{}))
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{
|
||||
llm.UserParts(llm.Image("image/png", []byte{1})),
|
||||
}})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Errorf("error = %v, want ErrUnsupported", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disallowed mime", func(t *testing.T) {
|
||||
m, _ := p.Model("llava") // default caps: jpeg/png only
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{
|
||||
llm.UserParts(llm.Image("image/tiff", []byte{1})),
|
||||
}})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Errorf("error = %v, want ErrUnsupported", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNoBaseURL(t *testing.T) {
|
||||
p := New(WithBaseURL(""))
|
||||
m, _ := p.Model("x")
|
||||
if _, err := m.Generate(context.Background(), basicRequest()); err == nil ||
|
||||
!strings.Contains(err.Error(), "no base URL") {
|
||||
t.Errorf("error = %v, want a clear no-base-URL message", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeHost(t *testing.T) {
|
||||
for in, want := range map[string]string{
|
||||
"myhost": "http://myhost:11434",
|
||||
"myhost:8080": "http://myhost:8080",
|
||||
"http://myhost:8080/": "http://myhost:8080",
|
||||
"https://ollama.com": "https://ollama.com",
|
||||
" 127.0.0.1:11434 ": "http://127.0.0.1:11434",
|
||||
} {
|
||||
if got := NormalizeHost(in); got != want {
|
||||
t.Errorf("NormalizeHost(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPresets(t *testing.T) {
|
||||
t.Run("cloud", func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_API_KEY", "cloud-key")
|
||||
p := Cloud()
|
||||
if p.Name() != "ollama-cloud" || p.baseURL != DefaultCloudBaseURL || p.token != "cloud-key" {
|
||||
t.Errorf("cloud preset = %+v", p)
|
||||
}
|
||||
})
|
||||
t.Run("local respects OLLAMA_HOST", func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", "box.lan:9999")
|
||||
p := Local()
|
||||
if p.Name() != "ollama" || p.baseURL != "http://box.lan:9999" || p.token != "" {
|
||||
t.Errorf("local preset = %+v", p)
|
||||
}
|
||||
})
|
||||
t.Run("foreman", func(t *testing.T) {
|
||||
p := Foreman("http://foreman-m1:8080", "tok")
|
||||
if p.Name() != "foreman" || p.baseURL != "http://foreman-m1:8080" || p.token != "tok" {
|
||||
t.Errorf("foreman preset = %+v", p)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLocalNoAuthHeader(t *testing.T) {
|
||||
p, cap := serve(t, 200, jsonReply(`{"message":{"content":"x"},"done":true}`))
|
||||
p.token = "" // simulate local mode on the test server
|
||||
m, _ := p.Model("llama3")
|
||||
if _, err := m.Generate(context.Background(), basicRequest()); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if cap.auth != "" {
|
||||
t.Errorf("auth header = %q, want none in local mode", cap.auth)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// Stream implements llm.Model over Ollama's NDJSON streaming. It also
|
||||
// transparently handles foreman's non-streaming degradation (a single
|
||||
// buffered JSON object): one JSON line parses as the final chunk.
|
||||
func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) {
|
||||
req = req.Apply(opts...)
|
||||
if err := m.enforceCapabilities(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wireReq, err := m.buildRequest(req, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := m.do(ctx, wireReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sc := bufio.NewScanner(resp.Body)
|
||||
// Single NDJSON lines can far exceed the 64KB default (thinking dumps,
|
||||
// tool payloads, foreman's whole-response-as-one-line degradation).
|
||||
sc.Buffer(make([]byte, 64<<10), 16<<20)
|
||||
|
||||
return &stream{model: m, body: resp.Body, scanner: sc}, nil
|
||||
}
|
||||
|
||||
type stream struct {
|
||||
model *model
|
||||
body io.Closer
|
||||
scanner *bufio.Scanner
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
finished bool
|
||||
toolCalls []llm.ToolCall
|
||||
text []byte
|
||||
pending []llm.StreamEvent
|
||||
usage llm.Usage
|
||||
doneReason string
|
||||
}
|
||||
|
||||
func (s *stream) Next() (llm.StreamEvent, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for {
|
||||
if len(s.pending) > 0 {
|
||||
ev := s.pending[0]
|
||||
s.pending = s.pending[1:]
|
||||
return ev, nil
|
||||
}
|
||||
if s.finished {
|
||||
return llm.StreamEvent{}, io.EOF
|
||||
}
|
||||
if !s.scanner.Scan() {
|
||||
if err := s.scanner.Err(); err != nil {
|
||||
return llm.StreamEvent{}, fmt.Errorf("ollama %s: read stream: %w", s.model.qualified(), err)
|
||||
}
|
||||
// EOF without a done chunk: synthesize the final response from
|
||||
// what we accumulated rather than losing it.
|
||||
s.queueFinal()
|
||||
continue
|
||||
}
|
||||
line := s.scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var chunk chatResponse
|
||||
if err := json.Unmarshal(line, &chunk); err != nil {
|
||||
return llm.StreamEvent{}, fmt.Errorf("ollama %s: decode stream chunk: %w", s.model.qualified(), err)
|
||||
}
|
||||
|
||||
if chunk.Message.Content != "" {
|
||||
s.text = append(s.text, chunk.Message.Content...)
|
||||
s.pending = append(s.pending, llm.StreamEvent{TextDelta: chunk.Message.Content})
|
||||
}
|
||||
// Tool calls arrive complete per chunk (no partial-argument deltas
|
||||
// in the native protocol).
|
||||
base := len(s.toolCalls)
|
||||
for i, tc := range chunk.Message.ToolCalls {
|
||||
id := tc.ID
|
||||
if id == "" {
|
||||
id = "call_" + strconv.Itoa(base+i)
|
||||
}
|
||||
args := tc.Function.Arguments
|
||||
if len(args) == 0 {
|
||||
args = json.RawMessage("{}")
|
||||
}
|
||||
call := llm.ToolCall{ID: id, Name: tc.Function.Name, Arguments: args}
|
||||
s.toolCalls = append(s.toolCalls, call)
|
||||
s.pending = append(s.pending, llm.StreamEvent{ToolCall: &s.toolCalls[len(s.toolCalls)-1]})
|
||||
}
|
||||
if chunk.Done {
|
||||
s.usage = llm.Usage{InputTokens: chunk.PromptEvalCount, OutputTokens: chunk.EvalCount}
|
||||
s.doneReason = chunk.DoneReason
|
||||
s.queueFinal()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// queueFinal appends the final Response event and marks the stream done.
|
||||
func (s *stream) queueFinal() {
|
||||
resp := &llm.Response{
|
||||
Model: s.model.qualified(),
|
||||
Usage: s.usage,
|
||||
FinishReason: finishReason(s.doneReason, len(s.toolCalls) > 0),
|
||||
}
|
||||
if len(s.text) > 0 {
|
||||
resp.Parts = append(resp.Parts, llm.Text(string(s.text)))
|
||||
}
|
||||
if len(s.toolCalls) > 0 {
|
||||
resp.ToolCalls = append([]llm.ToolCall(nil), s.toolCalls...)
|
||||
}
|
||||
s.pending = append(s.pending, llm.StreamEvent{Response: resp})
|
||||
s.finished = true
|
||||
}
|
||||
|
||||
func (s *stream) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
return s.body.Close()
|
||||
}
|
||||
@@ -0,0 +1,343 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// ---- wire types (field names per ollama api/types.go) ----
|
||||
|
||||
type chatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Tools []toolDef `json:"tools,omitempty"`
|
||||
Format json.RawMessage `json:"format,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
// Stream has no omitempty on purpose: the server default is true, so
|
||||
// Generate must send an explicit false.
|
||||
Stream bool `json:"stream"`
|
||||
// Think is bool-or-string on the wire ("low"/"medium"/"high" or a bool).
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Images []string `json:"images,omitempty"` // raw base64, no data: prefix
|
||||
ToolCalls []toolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"` // on role:"tool" results
|
||||
}
|
||||
|
||||
type toolDef struct {
|
||||
Type string `json:"type"`
|
||||
Function toolDefFunc `json:"function"`
|
||||
}
|
||||
|
||||
type toolDefFunc struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type toolCall struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Function toolCallFunc `json:"function"`
|
||||
}
|
||||
|
||||
type toolCallFunc struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
Name string `json:"name"`
|
||||
// Arguments is a JSON OBJECT on the wire (unlike OpenAI's string).
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
Model string `json:"model"`
|
||||
Message respMessage `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason string `json:"done_reason"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
}
|
||||
|
||||
type respMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Thinking string `json:"thinking"`
|
||||
ToolCalls []toolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
type errorBody struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// ---- model ----
|
||||
|
||||
type model struct {
|
||||
provider *Provider
|
||||
id string
|
||||
caps llm.Capabilities
|
||||
}
|
||||
|
||||
func (m *model) Capabilities() llm.Capabilities { return m.caps }
|
||||
|
||||
func (m *model) qualified() string { return m.provider.name + "/" + m.id }
|
||||
|
||||
// enforceCapabilities is the backstop check (the media layer normalizes
|
||||
// before requests get here; see ADR-0009).
|
||||
func (m *model) enforceCapabilities(req llm.Request) error {
|
||||
count := 0
|
||||
for _, msg := range req.Messages {
|
||||
for _, part := range msg.Parts {
|
||||
img, ok := part.(llm.ImagePart)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
if !m.caps.SupportsImages() {
|
||||
return fmt.Errorf("%w: %s does not accept image input", llm.ErrUnsupported, m.qualified())
|
||||
}
|
||||
if !m.caps.MIMEAllowed(img.MIME) {
|
||||
return fmt.Errorf("%w: %s does not accept %s images", llm.ErrUnsupported, m.qualified(), img.MIME)
|
||||
}
|
||||
if m.caps.MaxImageBytes > 0 && len(img.Data) > m.caps.MaxImageBytes {
|
||||
return fmt.Errorf("%w: image of %d bytes exceeds %s limit of %d",
|
||||
llm.ErrUnsupported, len(img.Data), m.qualified(), m.caps.MaxImageBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
if count > 0 && m.caps.MaxImagesPerReq > 0 && count > m.caps.MaxImagesPerReq {
|
||||
return fmt.Errorf("%w: %d images exceed %s limit of %d",
|
||||
llm.ErrUnsupported, count, m.qualified(), m.caps.MaxImagesPerReq)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildRequest maps the canonical request onto the wire shape.
|
||||
func (m *model) buildRequest(req llm.Request, stream bool) (*chatRequest, error) {
|
||||
out := &chatRequest{Model: m.id, Stream: stream}
|
||||
|
||||
// System prompt: dedicated field first, then folded RoleSystem messages.
|
||||
var sys []string
|
||||
if req.System != "" {
|
||||
sys = append(sys, req.System)
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
if msg.Role == llm.RoleSystem {
|
||||
if t := msg.Text(); t != "" {
|
||||
sys = append(sys, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(sys) > 0 {
|
||||
out.Messages = append(out.Messages, chatMessage{
|
||||
Role: "system", Content: strings.Join(sys, "\n\n"),
|
||||
})
|
||||
}
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
switch msg.Role {
|
||||
case llm.RoleSystem:
|
||||
// Already folded above.
|
||||
case llm.RoleTool:
|
||||
for _, res := range msg.ToolResults {
|
||||
content := res.Content
|
||||
if res.IsError {
|
||||
content = "ERROR: " + content
|
||||
}
|
||||
out.Messages = append(out.Messages, chatMessage{
|
||||
Role: "tool", Content: content, ToolName: res.Name,
|
||||
})
|
||||
}
|
||||
default:
|
||||
cm := chatMessage{Role: string(msg.Role), Content: msg.Text()}
|
||||
for _, part := range msg.Parts {
|
||||
if img, ok := part.(llm.ImagePart); ok {
|
||||
cm.Images = append(cm.Images, base64.StdEncoding.EncodeToString(img.Data))
|
||||
}
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
args := tc.Arguments
|
||||
if len(args) == 0 {
|
||||
args = json.RawMessage("{}")
|
||||
}
|
||||
cm.ToolCalls = append(cm.ToolCalls, toolCall{
|
||||
ID: tc.ID,
|
||||
Function: toolCallFunc{Name: tc.Name, Arguments: args},
|
||||
})
|
||||
}
|
||||
out.Messages = append(out.Messages, cm)
|
||||
}
|
||||
}
|
||||
|
||||
// Tools. Ollama has no tool_choice: "none" maps to omitting the tools;
|
||||
// "required"/named choices have no wire equivalent and are best-effort
|
||||
// ignored (documented in the README support matrix).
|
||||
if req.ToolChoice != "none" {
|
||||
for _, t := range req.Tools {
|
||||
params := t.Parameters
|
||||
if len(params) == 0 {
|
||||
params = json.RawMessage(`{"type":"object","properties":{}}`)
|
||||
}
|
||||
out.Tools = append(out.Tools, toolDef{
|
||||
Type: "function",
|
||||
Function: toolDefFunc{Name: t.Name, Description: t.Description, Parameters: params},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.Schema) > 0 {
|
||||
out.Format = req.Schema
|
||||
}
|
||||
|
||||
opts := make(map[string]any)
|
||||
if req.Temperature != nil {
|
||||
opts["temperature"] = *req.Temperature
|
||||
}
|
||||
if req.TopP != nil {
|
||||
opts["top_p"] = *req.TopP
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
opts["num_predict"] = req.MaxTokens
|
||||
}
|
||||
if len(req.StopSequences) > 0 {
|
||||
opts["stop"] = req.StopSequences
|
||||
}
|
||||
if len(opts) > 0 {
|
||||
out.Options = opts
|
||||
}
|
||||
|
||||
switch req.ReasoningEffort {
|
||||
case "":
|
||||
case "low", "medium", "high":
|
||||
out.Think = json.RawMessage(strconv.Quote(req.ReasoningEffort))
|
||||
default:
|
||||
return nil, fmt.Errorf("ollama: invalid reasoning effort %q (want low/medium/high)", req.ReasoningEffort)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// do POSTs /api/chat and returns the response body on 2xx, or a classified
|
||||
// error.
|
||||
func (m *model) do(ctx context.Context, wireReq *chatRequest) (*http.Response, error) {
|
||||
p := m.provider
|
||||
if err := p.checkReady(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, err := json.Marshal(wireReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama: encode request: %w", err)
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/chat", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if p.token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+p.token)
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama %s: do request: %w", m.qualified(), err)
|
||||
}
|
||||
if resp.StatusCode/100 != 2 {
|
||||
defer resp.Body.Close()
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||
var eb errorBody
|
||||
_ = json.Unmarshal(raw, &eb)
|
||||
msg := eb.Error
|
||||
if msg == "" {
|
||||
msg = strings.TrimSpace(string(raw))
|
||||
}
|
||||
return nil, &llm.APIError{
|
||||
Provider: p.name, Model: m.id,
|
||||
Status: resp.StatusCode, Message: msg,
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Generate implements llm.Model.
|
||||
func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) {
|
||||
req = req.Apply(opts...)
|
||||
if err := m.enforceCapabilities(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wireReq, err := m.buildRequest(req, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := m.do(ctx, wireReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var cr chatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&cr); err != nil {
|
||||
return nil, fmt.Errorf("ollama %s: decode response: %w", m.qualified(), err)
|
||||
}
|
||||
return m.toResponse(&cr), nil
|
||||
}
|
||||
|
||||
// toResponse converts a final wire chunk into the canonical response.
|
||||
func (m *model) toResponse(cr *chatResponse) *llm.Response {
|
||||
out := &llm.Response{
|
||||
Model: m.qualified(),
|
||||
Usage: llm.Usage{InputTokens: cr.PromptEvalCount, OutputTokens: cr.EvalCount},
|
||||
Raw: cr,
|
||||
}
|
||||
if cr.Message.Content != "" {
|
||||
out.Parts = append(out.Parts, llm.Text(cr.Message.Content))
|
||||
}
|
||||
out.ToolCalls = convertToolCalls(cr.Message.ToolCalls)
|
||||
out.FinishReason = finishReason(cr.DoneReason, len(out.ToolCalls) > 0)
|
||||
return out
|
||||
}
|
||||
|
||||
// convertToolCalls maps wire tool calls, synthesizing ids where the model
|
||||
// omitted them (ids are optional in Ollama's shape but required by our
|
||||
// agent loop to match results to calls).
|
||||
func convertToolCalls(calls []toolCall) []llm.ToolCall {
|
||||
out := make([]llm.ToolCall, 0, len(calls))
|
||||
for i, tc := range calls {
|
||||
id := tc.ID
|
||||
if id == "" {
|
||||
id = "call_" + strconv.Itoa(i)
|
||||
}
|
||||
args := tc.Function.Arguments
|
||||
if len(args) == 0 {
|
||||
args = json.RawMessage("{}")
|
||||
}
|
||||
out = append(out, llm.ToolCall{ID: id, Name: tc.Function.Name, Arguments: args})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func finishReason(doneReason string, hasToolCalls bool) llm.FinishReason {
|
||||
if hasToolCalls {
|
||||
return llm.FinishToolCalls
|
||||
}
|
||||
switch doneReason {
|
||||
case "stop", "":
|
||||
return llm.FinishStop
|
||||
case "length":
|
||||
return llm.FinishLength
|
||||
default:
|
||||
return llm.FinishOther
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// model is one provider-bound target.
|
||||
type model struct {
|
||||
p *Provider
|
||||
id string
|
||||
caps llm.Capabilities
|
||||
}
|
||||
|
||||
// Capabilities implements llm.Model.
|
||||
func (m *model) Capabilities() llm.Capabilities { return m.caps }
|
||||
|
||||
// Generate implements llm.Model.
|
||||
func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) {
|
||||
req = req.Apply(opts...)
|
||||
if err := checkRequest(m.caps, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpResp, err := m.do(ctx, req, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
if httpResp.StatusCode/100 != 2 {
|
||||
return nil, m.apiError(httpResp)
|
||||
}
|
||||
var wire chatResponse
|
||||
if err := json.NewDecoder(httpResp.Body).Decode(&wire); err != nil {
|
||||
return nil, fmt.Errorf("openai: decode response: %w", err)
|
||||
}
|
||||
return m.toResponse(&wire), nil
|
||||
}
|
||||
|
||||
// Stream implements llm.Model.
|
||||
func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) {
|
||||
req = req.Apply(opts...)
|
||||
if !m.caps.SupportsStreaming {
|
||||
return nil, fmt.Errorf("%w: streaming not supported by %s/%s", llm.ErrUnsupported, m.p.name, m.id)
|
||||
}
|
||||
if err := checkRequest(m.caps, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpResp, err := m.do(ctx, req, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if httpResp.StatusCode/100 != 2 {
|
||||
defer httpResp.Body.Close()
|
||||
return nil, m.apiError(httpResp)
|
||||
}
|
||||
sc := bufio.NewScanner(httpResp.Body)
|
||||
// Why: a single SSE data line carries a whole JSON chunk; tool-call
|
||||
// argument fragments can make lines far larger than Scanner's 64 KiB
|
||||
// default cap.
|
||||
sc.Buffer(make([]byte, 0, 64*1024), 16<<20)
|
||||
return &stream{m: m, body: httpResp.Body, sc: sc}, nil
|
||||
}
|
||||
|
||||
// do builds and performs the HTTP request. Transport failures are wrapped
|
||||
// raw (never as *llm.APIError) so llm.Classify still sees net.Error,
|
||||
// syscall errnos, and context errors underneath.
|
||||
func (m *model) do(ctx context.Context, req llm.Request, stream bool) (*http.Response, error) {
|
||||
if m.p.apiKey == "" {
|
||||
// Why a synthetic 401: the constructor never fails, so a missing
|
||||
// key must surface at request time as the auth failure it is —
|
||||
// permanent under llm.Classify, like a real 401.
|
||||
return nil, &llm.APIError{
|
||||
Provider: m.p.name,
|
||||
Model: m.id,
|
||||
Status: http.StatusUnauthorized,
|
||||
Code: "missing_api_key",
|
||||
Message: "no API key configured: set OPENAI_API_KEY or use WithAPIKey",
|
||||
}
|
||||
}
|
||||
body, err := json.Marshal(m.buildRequest(req, stream))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai: encode request: %w", err)
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, m.p.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+m.p.apiKey)
|
||||
if stream {
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
httpResp, err := m.p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai: do request: %w", err)
|
||||
}
|
||||
return httpResp, nil
|
||||
}
|
||||
|
||||
// apiError converts a non-2xx response into *llm.APIError, pulling code and
|
||||
// message from the {"error":{...}} body when it parses.
|
||||
func (m *model) apiError(httpResp *http.Response) error {
|
||||
apiErr := &llm.APIError{Provider: m.p.name, Model: m.id, Status: httpResp.StatusCode}
|
||||
body, _ := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20))
|
||||
var env errorEnvelope
|
||||
if err := json.Unmarshal(body, &env); err == nil &&
|
||||
(env.Error.Message != "" || env.Error.Type != "" || env.Error.Code != "") {
|
||||
apiErr.Message = env.Error.Message
|
||||
apiErr.Code = env.Error.Code
|
||||
if apiErr.Code == "" {
|
||||
apiErr.Code = env.Error.Type
|
||||
}
|
||||
} else {
|
||||
// Why: compat servers emit all sorts of error bodies; a raw snippet
|
||||
// beats silence when the canonical envelope is absent.
|
||||
apiErr.Message = strings.TrimSpace(string(body))
|
||||
}
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// toResponse maps the wire response onto the canonical llm.Response.
|
||||
func (m *model) toResponse(wire *chatResponse) *llm.Response {
|
||||
resp := &llm.Response{Model: m.p.name + "/" + m.id, Raw: wire}
|
||||
if wire.Usage != nil {
|
||||
resp.Usage = llm.Usage{
|
||||
InputTokens: wire.Usage.PromptTokens,
|
||||
OutputTokens: wire.Usage.CompletionTokens,
|
||||
}
|
||||
}
|
||||
if len(wire.Choices) == 0 {
|
||||
resp.FinishReason = llm.FinishOther
|
||||
return resp
|
||||
}
|
||||
choice := wire.Choices[0]
|
||||
if choice.Message.Content != "" {
|
||||
resp.Parts = append(resp.Parts, llm.TextPart{Text: choice.Message.Content})
|
||||
}
|
||||
for i, tc := range choice.Message.ToolCalls {
|
||||
id := tc.ID
|
||||
if id == "" {
|
||||
// Why: ToolResult.ID must echo ToolCall.ID, so calls from compat
|
||||
// servers that omit ids get synthesized ones.
|
||||
id = fmt.Sprintf("call_%d", i)
|
||||
}
|
||||
resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{
|
||||
ID: id,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: json.RawMessage(tc.Function.Arguments),
|
||||
})
|
||||
}
|
||||
resp.FinishReason = mapFinish(choice.FinishReason, len(resp.ToolCalls) > 0)
|
||||
return resp
|
||||
}
|
||||
|
||||
// mapFinish maps a wire finish_reason to the canonical enum. Tool-call
|
||||
// presence wins over the reported reason: a forced (named tool_choice) call
|
||||
// can finish with "stop" while still carrying tool_calls.
|
||||
func mapFinish(reason string, hasToolCalls bool) llm.FinishReason {
|
||||
if hasToolCalls {
|
||||
return llm.FinishToolCalls
|
||||
}
|
||||
switch reason {
|
||||
case "stop":
|
||||
return llm.FinishStop
|
||||
case "length":
|
||||
return llm.FinishLength
|
||||
case "tool_calls":
|
||||
return llm.FinishToolCalls
|
||||
case "content_filter":
|
||||
return llm.FinishContentFilter
|
||||
default:
|
||||
return llm.FinishOther
|
||||
}
|
||||
}
|
||||
|
||||
// checkRequest enforces the model's effective capabilities. Why enforcement
|
||||
// rather than normalization: a separate media layer resizes/transcodes
|
||||
// images BEFORE requests reach the provider; this check is the honest
|
||||
// backstop that refuses, with llm.ErrUnsupported, what the target
|
||||
// declaredly cannot serve (chains advance past it penalty-free).
|
||||
func checkRequest(caps llm.Capabilities, req llm.Request) error {
|
||||
if len(req.Tools) > 0 && !caps.SupportsTools {
|
||||
return fmt.Errorf("%w: tools not supported", llm.ErrUnsupported)
|
||||
}
|
||||
if len(req.Schema) > 0 && !caps.SupportsStructured {
|
||||
return fmt.Errorf("%w: structured output not supported", llm.ErrUnsupported)
|
||||
}
|
||||
images := 0
|
||||
for _, msg := range req.Messages {
|
||||
for _, part := range msg.Parts {
|
||||
img, ok := part.(llm.ImagePart)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
images++
|
||||
if !caps.SupportsImages() {
|
||||
return fmt.Errorf("%w: image input not supported", llm.ErrUnsupported)
|
||||
}
|
||||
if !caps.MIMEAllowed(img.MIME) {
|
||||
return fmt.Errorf("%w: image MIME type %q not allowed (allowed: %s)",
|
||||
llm.ErrUnsupported, img.MIME, strings.Join(caps.AllowedImageMIME, ", "))
|
||||
}
|
||||
if caps.MaxImageBytes > 0 && len(img.Data) > caps.MaxImageBytes {
|
||||
return fmt.Errorf("%w: image is %d bytes, limit is %d",
|
||||
llm.ErrUnsupported, len(img.Data), caps.MaxImageBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
if images > caps.MaxImagesPerReq {
|
||||
return fmt.Errorf("%w: request carries %d images, limit is %d",
|
||||
llm.ErrUnsupported, images, caps.MaxImagesPerReq)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
// Package openai implements llm.Provider for the OpenAI Chat Completions
|
||||
// API and, via WithBaseURL/WithName, any OpenAI-compatible endpoint
|
||||
// (vLLM, Groq, Together, LM Studio, Ollama's /v1 shim, ...).
|
||||
//
|
||||
// Targeted API surface (verified against developers.openai.com, June 2026):
|
||||
// POST {base}/chat/completions with
|
||||
// - messages: plain-string content for text-only turns, part arrays with
|
||||
// base64 data-URL image_url entries for multimodal turns, assistant
|
||||
// tool_calls history, and {"role":"tool","tool_call_id",...} results;
|
||||
// - tools as {"type":"function","function":{...}} with tool_choice
|
||||
// "auto"/"none"/"required" or a named-function object;
|
||||
// - response_format {"type":"json_schema",...} structured output;
|
||||
// - max_completion_tokens (or legacy max_tokens via WithLegacyMaxTokens
|
||||
// for compat servers), temperature, top_p, stop, reasoning_effort;
|
||||
// - data-only SSE streaming with stream_options.include_usage, the
|
||||
// "data: [DONE]" sentinel, and tool-call deltas accumulated by index.
|
||||
//
|
||||
// Newer response fields (refusal, annotations, usage *_details, delta
|
||||
// obfuscation) are tolerated and ignored so both api.openai.com and older
|
||||
// compat servers decode cleanly.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
const defaultBaseURL = "https://api.openai.com/v1"
|
||||
|
||||
// Provider is an llm.Provider backed by an OpenAI Chat Completions endpoint.
|
||||
type Provider struct {
|
||||
name string
|
||||
apiKey string
|
||||
baseURL string
|
||||
client *http.Client
|
||||
caps llm.Capabilities
|
||||
legacyMaxTokens bool
|
||||
}
|
||||
|
||||
// Option configures the provider at construction.
|
||||
type Option func(*Provider)
|
||||
|
||||
// WithAPIKey sets the API key. When absent, New reads OPENAI_API_KEY from
|
||||
// the environment at construction time.
|
||||
func WithAPIKey(key string) Option {
|
||||
return func(p *Provider) { p.apiKey = key }
|
||||
}
|
||||
|
||||
// WithBaseURL points the client at a different endpoint (compat servers).
|
||||
// The path "/chat/completions" is appended; a trailing slash is trimmed.
|
||||
func WithBaseURL(u string) Option {
|
||||
return func(p *Provider) { p.baseURL = u }
|
||||
}
|
||||
|
||||
// WithHTTPClient substitutes the HTTP client (timeouts, proxies, tests).
|
||||
func WithHTTPClient(c *http.Client) Option {
|
||||
return func(p *Provider) {
|
||||
if c != nil {
|
||||
p.client = c
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithName overrides the registry name ("openai" by default). Why: the same
|
||||
// client serves many OpenAI-compatible endpoints, and each needs a distinct
|
||||
// name in "provider/model" specs and error reporting.
|
||||
func WithName(name string) Option {
|
||||
return func(p *Provider) { p.name = name }
|
||||
}
|
||||
|
||||
// WithDefaultCapabilities replaces the provider-default capabilities.
|
||||
// Per-model overrides via llm.WithCapabilities still take precedence.
|
||||
func WithDefaultCapabilities(caps llm.Capabilities) Option {
|
||||
return func(p *Provider) { p.caps = caps }
|
||||
}
|
||||
|
||||
// WithLegacyMaxTokens sends Request.MaxTokens as "max_tokens" instead of
|
||||
// "max_completion_tokens". Why: OpenAI deprecated max_tokens, but many
|
||||
// third-party compat servers still only honor the legacy field.
|
||||
func WithLegacyMaxTokens() Option {
|
||||
return func(p *Provider) { p.legacyMaxTokens = true }
|
||||
}
|
||||
|
||||
// defaultCapabilities reflects OpenAI's current vision-capable chat models.
|
||||
// Why these limits: the published per-request caps (1500 images, 512 MB)
|
||||
// are far beyond what compat servers accept; 100 images / 20 MB each is a
|
||||
// conservative envelope, and the MIME list is the documented set (PNG,
|
||||
// JPEG, WEBP, non-animated GIF).
|
||||
func defaultCapabilities() llm.Capabilities {
|
||||
return llm.Capabilities{
|
||||
SupportsTools: true,
|
||||
SupportsStructured: true,
|
||||
SupportsStreaming: true,
|
||||
MaxImagesPerReq: 100,
|
||||
MaxImageBytes: 20 << 20,
|
||||
AllowedImageMIME: []string{"image/jpeg", "image/png", "image/webp", "image/gif"},
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a Provider. It never fails: a missing API key surfaces as a
|
||||
// 401-style *llm.APIError at request time, not at construction.
|
||||
func New(opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
name: "openai",
|
||||
apiKey: os.Getenv("OPENAI_API_KEY"),
|
||||
baseURL: defaultBaseURL,
|
||||
client: http.DefaultClient,
|
||||
caps: defaultCapabilities(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
p.baseURL = strings.TrimRight(p.baseURL, "/")
|
||||
return p
|
||||
}
|
||||
|
||||
// Name implements llm.Provider.
|
||||
func (p *Provider) Name() string { return p.name }
|
||||
|
||||
// Model implements llm.Provider. The id is passed through verbatim — no
|
||||
// catalog validation; unknown models fail at request time with the
|
||||
// backend's own error.
|
||||
func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) {
|
||||
cfg := llm.ApplyModelOptions(opts)
|
||||
caps := p.caps
|
||||
if cfg.Capabilities != nil {
|
||||
caps = *cfg.Capabilities
|
||||
}
|
||||
return &model{p: p, id: id, caps: caps}, nil
|
||||
}
|
||||
@@ -0,0 +1,614 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
var (
|
||||
_ llm.Provider = (*Provider)(nil)
|
||||
_ llm.Model = (*model)(nil)
|
||||
_ llm.Stream = (*stream)(nil)
|
||||
)
|
||||
|
||||
const textResponse = `{
|
||||
"id": "chatcmpl-1", "object": "chat.completion", "created": 1741570283, "model": "gpt-test",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "hello", "refusal": null, "annotations": []},
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 19, "completion_tokens": 10, "total_tokens": 29,
|
||||
"prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
|
||||
"completion_tokens_details": {"reasoning_tokens": 0}
|
||||
},
|
||||
"service_tier": "default", "system_fingerprint": "fp_x"
|
||||
}`
|
||||
|
||||
// recorded captures the last request a test server received.
|
||||
type recorded struct {
|
||||
body map[string]any
|
||||
header http.Header
|
||||
path string
|
||||
hits int
|
||||
}
|
||||
|
||||
// newServer starts a test server that records the request and replies with
|
||||
// a fixed status and body.
|
||||
func newServer(t *testing.T, status int, respBody string) (*httptest.Server, *recorded) {
|
||||
t.Helper()
|
||||
rec := &recorded{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rec.hits++
|
||||
rec.header = r.Header.Clone()
|
||||
rec.path = r.URL.Path
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read request body: %v", err)
|
||||
}
|
||||
if len(raw) > 0 {
|
||||
if err := json.Unmarshal(raw, &rec.body); err != nil {
|
||||
t.Errorf("request body is not JSON: %v\n%s", err, raw)
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
io.WriteString(w, respBody)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
return srv, rec
|
||||
}
|
||||
|
||||
func testModel(t *testing.T, srv *httptest.Server, popts []Option, mopts ...llm.ModelOption) llm.Model {
|
||||
t.Helper()
|
||||
opts := append([]Option{WithAPIKey("test-key"), WithBaseURL(srv.URL)}, popts...)
|
||||
m, err := New(opts...).Model("gpt-test", mopts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func fptr(f float64) *float64 { return &f }
|
||||
|
||||
func TestGenerateRequestShape(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
|
||||
req := llm.Request{
|
||||
System: "base system",
|
||||
Messages: []llm.Message{
|
||||
llm.SystemText("folded system"),
|
||||
llm.UserParts(llm.Text("look:"), llm.Image("image/png", []byte{1, 2, 3})),
|
||||
{
|
||||
Role: llm.RoleAssistant,
|
||||
Parts: []llm.Part{llm.Text("checking")},
|
||||
ToolCalls: []llm.ToolCall{
|
||||
{ID: "call_1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Boston"}`)},
|
||||
},
|
||||
},
|
||||
llm.ToolResultsMessage(
|
||||
llm.ToolResult{ID: "call_1", Name: "get_weather", Content: "72F"},
|
||||
llm.ToolResult{ID: "call_2", Name: "get_weather", Content: "boom", IsError: true},
|
||||
),
|
||||
llm.UserText("thanks"),
|
||||
},
|
||||
Tools: []llm.Tool{{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
|
||||
}},
|
||||
ToolChoice: "auto",
|
||||
Temperature: fptr(0.5),
|
||||
TopP: fptr(0.9),
|
||||
MaxTokens: 256,
|
||||
StopSequences: []string{"END"},
|
||||
ReasoningEffort: "high",
|
||||
Schema: json.RawMessage(`{"type":"object","properties":{"ok":{"type":"boolean"}}}`),
|
||||
SchemaName: "verdict",
|
||||
}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
|
||||
want := map[string]any{
|
||||
"model": "gpt-test",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "system", "content": "base system\n\nfolded system"},
|
||||
map[string]any{"role": "user", "content": []any{
|
||||
map[string]any{"type": "text", "text": "look:"},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,AQID"}},
|
||||
}},
|
||||
map[string]any{"role": "assistant", "content": "checking", "tool_calls": []any{
|
||||
map[string]any{"id": "call_1", "type": "function", "function": map[string]any{
|
||||
"name": "get_weather", "arguments": `{"city":"Boston"}`,
|
||||
}},
|
||||
}},
|
||||
map[string]any{"role": "tool", "content": "72F", "tool_call_id": "call_1"},
|
||||
map[string]any{"role": "tool", "content": "ERROR: boom", "tool_call_id": "call_2"},
|
||||
map[string]any{"role": "user", "content": "thanks"},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{"type": "function", "function": map[string]any{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather",
|
||||
"parameters": map[string]any{"type": "object", "properties": map[string]any{"city": map[string]any{"type": "string"}}},
|
||||
}},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.9,
|
||||
"max_completion_tokens": float64(256),
|
||||
"stop": []any{"END"},
|
||||
"reasoning_effort": "high",
|
||||
"response_format": map[string]any{"type": "json_schema", "json_schema": map[string]any{
|
||||
"name": "verdict",
|
||||
"schema": map[string]any{"type": "object", "properties": map[string]any{"ok": map[string]any{"type": "boolean"}}},
|
||||
}},
|
||||
}
|
||||
if !reflect.DeepEqual(rec.body, want) {
|
||||
got, _ := json.MarshalIndent(rec.body, "", " ")
|
||||
exp, _ := json.MarshalIndent(want, "", " ")
|
||||
t.Errorf("request body mismatch\ngot:\n%s\nwant:\n%s", got, exp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolChoiceForms(t *testing.T) {
|
||||
tests := []struct {
|
||||
choice string
|
||||
want any // nil = key absent
|
||||
}{
|
||||
{"", nil},
|
||||
{"auto", "auto"},
|
||||
{"none", "none"},
|
||||
{"required", "required"},
|
||||
{"get_weather", map[string]any{"type": "function", "function": map[string]any{"name": "get_weather"}}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run("choice="+tt.choice, func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
req := llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
Tools: []llm.Tool{{Name: "get_weather"}},
|
||||
ToolChoice: tt.choice,
|
||||
}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
got, present := rec.body["tool_choice"]
|
||||
if tt.want == nil {
|
||||
if present {
|
||||
t.Errorf("tool_choice present, want omitted: %v", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("tool_choice = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxTokensField(t *testing.T) {
|
||||
t.Run("default uses max_completion_tokens", func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := rec.body["max_completion_tokens"]; got != float64(64) {
|
||||
t.Errorf("max_completion_tokens = %v, want 64", got)
|
||||
}
|
||||
if _, present := rec.body["max_tokens"]; present {
|
||||
t.Error("max_tokens present, want omitted")
|
||||
}
|
||||
})
|
||||
t.Run("WithLegacyMaxTokens uses max_tokens", func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, []Option{WithLegacyMaxTokens()})
|
||||
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := rec.body["max_tokens"]; got != float64(64) {
|
||||
t.Errorf("max_tokens = %v, want 64", got)
|
||||
}
|
||||
if _, present := rec.body["max_completion_tokens"]; present {
|
||||
t.Error("max_completion_tokens present, want omitted")
|
||||
}
|
||||
})
|
||||
t.Run("zero omits both", func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if _, present := rec.body["max_tokens"]; present {
|
||||
t.Error("max_tokens present, want omitted")
|
||||
}
|
||||
if _, present := rec.body["max_completion_tokens"]; present {
|
||||
t.Error("max_completion_tokens present, want omitted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSchemaNameDefault(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
req := llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
Schema: json.RawMessage(`{"type":"object"}`),
|
||||
}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
rf, ok := rec.body["response_format"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("response_format missing: %v", rec.body)
|
||||
}
|
||||
js, ok := rf["json_schema"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("json_schema missing: %v", rf)
|
||||
}
|
||||
if js["name"] != "response" {
|
||||
t.Errorf("schema name = %v, want %q", js["name"], "response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTextResponse(t *testing.T) {
|
||||
srv, _ := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := resp.Text(); got != "hello" {
|
||||
t.Errorf("Text = %q, want %q", got, "hello")
|
||||
}
|
||||
if resp.FinishReason != llm.FinishStop {
|
||||
t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishStop)
|
||||
}
|
||||
if resp.Usage != (llm.Usage{InputTokens: 19, OutputTokens: 10}) {
|
||||
t.Errorf("Usage = %+v, want {19 10}", resp.Usage)
|
||||
}
|
||||
if resp.Model != "openai/gpt-test" {
|
||||
t.Errorf("Model = %q, want %q", resp.Model, "openai/gpt-test")
|
||||
}
|
||||
if len(resp.ToolCalls) != 0 {
|
||||
t.Errorf("ToolCalls = %v, want none", resp.ToolCalls)
|
||||
}
|
||||
if resp.Raw == nil {
|
||||
t.Error("Raw is nil, want wire response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateToolCallResponse(t *testing.T) {
|
||||
const body = `{
|
||||
"id": "chatcmpl-2", "object": "chat.completion", "created": 1, "model": "gpt-test",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": null, "tool_calls": [
|
||||
{"id": "call_9", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Boston\"}"}},
|
||||
{"id": "", "type": "function", "function": {"name": "get_time", "arguments": "{}"}}
|
||||
]},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}
|
||||
}`
|
||||
srv, _ := newServer(t, http.StatusOK, body)
|
||||
m := testModel(t, srv, nil)
|
||||
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 2 {
|
||||
t.Fatalf("ToolCalls = %d, want 2", len(resp.ToolCalls))
|
||||
}
|
||||
tc := resp.ToolCalls[0]
|
||||
if tc.ID != "call_9" || tc.Name != "get_weather" || string(tc.Arguments) != `{"city":"Boston"}` {
|
||||
t.Errorf("ToolCalls[0] = %+v", tc)
|
||||
}
|
||||
if resp.ToolCalls[1].ID != "call_1" {
|
||||
t.Errorf("synthesized ID = %q, want %q", resp.ToolCalls[1].ID, "call_1")
|
||||
}
|
||||
// finish_reason "stop" with tool_calls present: presence wins.
|
||||
if resp.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishToolCalls)
|
||||
}
|
||||
if len(resp.Parts) != 0 {
|
||||
t.Errorf("Parts = %v, want none", resp.Parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishReasonMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
wire string
|
||||
want llm.FinishReason
|
||||
}{
|
||||
{"stop", llm.FinishStop},
|
||||
{"length", llm.FinishLength},
|
||||
{"tool_calls", llm.FinishToolCalls},
|
||||
{"content_filter", llm.FinishContentFilter},
|
||||
{"function_call", llm.FinishOther},
|
||||
{"weird_new_reason", llm.FinishOther},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.wire, func(t *testing.T) {
|
||||
body := `{"choices":[{"index":0,"message":{"role":"assistant","content":"x"},"finish_reason":"` + tt.wire + `"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`
|
||||
srv, _ := newServer(t, http.StatusOK, body)
|
||||
m := testModel(t, srv, nil)
|
||||
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if resp.FinishReason != tt.want {
|
||||
t.Errorf("FinishReason = %v, want %v", resp.FinishReason, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIErrorMapping(t *testing.T) {
|
||||
t.Run("429 rate limit is transient", func(t *testing.T) {
|
||||
const body = `{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}`
|
||||
srv, _ := newServer(t, http.StatusTooManyRequests, body)
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusTooManyRequests {
|
||||
t.Errorf("Status = %d, want 429", apiErr.Status)
|
||||
}
|
||||
if apiErr.Code != "rate_limit_exceeded" {
|
||||
t.Errorf("Code = %q, want %q", apiErr.Code, "rate_limit_exceeded")
|
||||
}
|
||||
if apiErr.Message != "Rate limit reached" {
|
||||
t.Errorf("Message = %q", apiErr.Message)
|
||||
}
|
||||
if apiErr.Provider != "openai" || apiErr.Model != "gpt-test" {
|
||||
t.Errorf("Provider/Model = %q/%q", apiErr.Provider, apiErr.Model)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassTransient {
|
||||
t.Errorf("Classify = %v, want transient", got)
|
||||
}
|
||||
})
|
||||
t.Run("401 code null falls back to type, permanent", func(t *testing.T) {
|
||||
const body = `{"error":{"message":"Incorrect API key provided","type":"authentication_error","param":null,"code":null}}`
|
||||
srv, _ := newServer(t, http.StatusUnauthorized, body)
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" {
|
||||
t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassPermanent {
|
||||
t.Errorf("Classify = %v, want permanent", got)
|
||||
}
|
||||
})
|
||||
t.Run("non-JSON body becomes message", func(t *testing.T) {
|
||||
srv, _ := newServer(t, http.StatusServiceUnavailable, "upstream exploded\n")
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusServiceUnavailable || apiErr.Message != "upstream exploded" {
|
||||
t.Errorf("Status/Message = %d/%q", apiErr.Status, apiErr.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMissingAPIKey(t *testing.T) {
|
||||
t.Setenv("OPENAI_API_KEY", "")
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m, err := New(WithBaseURL(srv.URL)).Model("gpt-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
_, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "missing_api_key" {
|
||||
t.Errorf("Status/Code = %d/%q, want 401/missing_api_key", apiErr.Status, apiErr.Code)
|
||||
}
|
||||
if rec.hits != 0 {
|
||||
t.Errorf("server hit %d times, want 0", rec.hits)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvAPIKeyReadAtConstruction(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
t.Setenv("OPENAI_API_KEY", "env-secret")
|
||||
p := New(WithBaseURL(srv.URL))
|
||||
t.Setenv("OPENAI_API_KEY", "changed-later") // must not affect p
|
||||
m, err := p.Model("gpt-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := rec.header.Get("Authorization"); got != "Bearer env-secret" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer env-secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAndContentTypeHeaders(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := rec.header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer test-key")
|
||||
}
|
||||
if got := rec.header.Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", got)
|
||||
}
|
||||
if rec.path != "/chat/completions" {
|
||||
t.Errorf("path = %q, want /chat/completions", rec.path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompatEndpointNameAndBaseURL(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
p := New(WithName("groq"), WithAPIKey("k"), WithBaseURL(srv.URL+"/openai/v1/"))
|
||||
if p.Name() != "groq" {
|
||||
t.Errorf("Name = %q, want groq", p.Name())
|
||||
}
|
||||
m, err := p.Model("llama-3.3-70b")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if rec.path != "/openai/v1/chat/completions" {
|
||||
t.Errorf("path = %q, want /openai/v1/chat/completions (trailing slash trimmed)", rec.path)
|
||||
}
|
||||
if resp.Model != "groq/llama-3.3-70b" {
|
||||
t.Errorf("Model = %q, want groq/llama-3.3-70b", resp.Model)
|
||||
}
|
||||
if rec.body["model"] != "llama-3.3-70b" {
|
||||
t.Errorf("wire model = %v, want llama-3.3-70b (verbatim)", rec.body["model"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapabilityEnforcement(t *testing.T) {
|
||||
img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) }
|
||||
tests := []struct {
|
||||
name string
|
||||
caps *llm.Capabilities // nil = provider defaults
|
||||
msg llm.Message
|
||||
}{
|
||||
{
|
||||
name: "images unsupported",
|
||||
caps: &llm.Capabilities{SupportsTools: true, SupportsStreaming: true},
|
||||
msg: llm.UserParts(img("image/png", 4)),
|
||||
},
|
||||
{
|
||||
name: "too many images",
|
||||
caps: &llm.Capabilities{MaxImagesPerReq: 1},
|
||||
msg: llm.UserParts(img("image/png", 4), img("image/png", 4)),
|
||||
},
|
||||
{
|
||||
name: "disallowed MIME under defaults",
|
||||
msg: llm.UserParts(img("image/bmp", 4)),
|
||||
},
|
||||
{
|
||||
name: "image too large",
|
||||
caps: &llm.Capabilities{MaxImagesPerReq: 4, MaxImageBytes: 2},
|
||||
msg: llm.UserParts(img("image/png", 3)),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
var mopts []llm.ModelOption
|
||||
if tt.caps != nil {
|
||||
mopts = append(mopts, llm.WithCapabilities(*tt.caps))
|
||||
}
|
||||
m := testModel(t, srv, nil, mopts...)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{tt.msg}})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassPermanent {
|
||||
t.Errorf("Classify = %v, want permanent", got)
|
||||
}
|
||||
if rec.hits != 0 {
|
||||
t.Errorf("server hit %d times, want 0 (must refuse before sending)", rec.hits)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("streaming unsupported", func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil, llm.WithCapabilities(llm.Capabilities{SupportsTools: true}))
|
||||
_, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if rec.hits != 0 {
|
||||
t.Errorf("server hit %d times, want 0", rec.hits)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelCapabilitiesOverride(t *testing.T) {
|
||||
p := New(WithAPIKey("k"))
|
||||
def, err := p.Model("a")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
if caps := def.Capabilities(); !caps.SupportsTools || caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 20<<20 {
|
||||
t.Errorf("default caps = %+v", caps)
|
||||
}
|
||||
custom := llm.Capabilities{SupportsStreaming: true, ContextWindow: 8192}
|
||||
ovr, err := p.Model("b", llm.WithCapabilities(custom))
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
if got := ovr.Capabilities(); !reflect.DeepEqual(got, custom) {
|
||||
t.Errorf("override caps = %+v, want %+v", got, custom)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportErrorIsNotAPIError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
url := srv.URL
|
||||
srv.Close() // guarantee connection refused
|
||||
p := New(WithAPIKey("k"), WithBaseURL(url))
|
||||
m, err := p.Model("gpt-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
_, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil {
|
||||
t.Fatal("Generate succeeded against closed server")
|
||||
}
|
||||
if _, ok := errors.AsType[*llm.APIError](err); ok {
|
||||
t.Errorf("transport error wrapped in APIError: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "openai: do request") {
|
||||
t.Errorf("err = %v, want openai: do request context", err)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassTransient {
|
||||
t.Errorf("Classify = %v, want transient (net error must stay visible)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeErrorWrapped(t *testing.T) {
|
||||
srv, _ := newServer(t, http.StatusOK, "{not json")
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil || !strings.Contains(err.Error(), "openai: decode response") {
|
||||
t.Errorf("err = %v, want decode response context", err)
|
||||
}
|
||||
if _, ok := errors.AsType[*llm.APIError](err); ok {
|
||||
t.Errorf("decode error wrapped in APIError: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// stream consumes the data-only SSE stream of chat.completion.chunk events.
|
||||
//
|
||||
// Delivery contract: TextDelta events as content fragments arrive; ToolCall
|
||||
// events only once fully assembled (fragments are buffered internally and
|
||||
// flushed at stream end — simplest correct handling of interleaved parallel
|
||||
// calls); exactly one final Response event; then io.EOF.
|
||||
type stream struct {
|
||||
m *model
|
||||
body io.ReadCloser
|
||||
sc *bufio.Scanner
|
||||
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
|
||||
queue []llm.StreamEvent
|
||||
done bool // finalize ran; drain queue then io.EOF
|
||||
|
||||
text strings.Builder
|
||||
calls []*toolCallAcc // first-appearance order
|
||||
byIndex map[int]*toolCallAcc
|
||||
finish string
|
||||
usage llm.Usage
|
||||
}
|
||||
|
||||
// toolCallAcc accumulates one tool call's fragments. The id and name arrive
|
||||
// on the first fragment for an index; arguments arrive as string pieces to
|
||||
// concatenate.
|
||||
type toolCallAcc struct {
|
||||
id string
|
||||
name string
|
||||
args strings.Builder
|
||||
}
|
||||
|
||||
// Next implements llm.Stream.
|
||||
func (s *stream) Next() (llm.StreamEvent, error) {
|
||||
for {
|
||||
if len(s.queue) > 0 {
|
||||
ev := s.queue[0]
|
||||
s.queue = s.queue[1:]
|
||||
return ev, nil
|
||||
}
|
||||
if s.done {
|
||||
return llm.StreamEvent{}, io.EOF
|
||||
}
|
||||
if !s.sc.Scan() {
|
||||
if err := s.sc.Err(); err != nil {
|
||||
return llm.StreamEvent{}, fmt.Errorf("openai: read stream: %w", err)
|
||||
}
|
||||
// Why: some compat servers close the body without a [DONE]
|
||||
// sentinel; a clean EOF still finalizes with what arrived.
|
||||
s.finalize()
|
||||
continue
|
||||
}
|
||||
line := strings.TrimSpace(s.sc.Text())
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue // SSE comments, event:/id: fields, blank separators
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" {
|
||||
continue
|
||||
}
|
||||
if payload == "[DONE]" {
|
||||
s.finalize()
|
||||
continue
|
||||
}
|
||||
if err := s.handleChunk([]byte(payload)); err != nil {
|
||||
return llm.StreamEvent{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleChunk folds one chat.completion.chunk into the stream state,
|
||||
// queueing any events it produces.
|
||||
func (s *stream) handleChunk(data []byte) error {
|
||||
var chunk streamChunk
|
||||
if err := json.Unmarshal(data, &chunk); err != nil {
|
||||
return fmt.Errorf("openai: decode stream chunk: %w", err)
|
||||
}
|
||||
if chunk.Error != nil {
|
||||
// Mid-stream error event on an otherwise-200 stream. Status stays 0:
|
||||
// there is no failing HTTP status to report.
|
||||
apiErr := &llm.APIError{
|
||||
Provider: s.m.p.name,
|
||||
Model: s.m.id,
|
||||
Code: chunk.Error.Code,
|
||||
Message: chunk.Error.Message,
|
||||
}
|
||||
if apiErr.Code == "" {
|
||||
apiErr.Code = chunk.Error.Type
|
||||
}
|
||||
return apiErr
|
||||
}
|
||||
if chunk.Usage != nil {
|
||||
s.usage = llm.Usage{
|
||||
InputTokens: chunk.Usage.PromptTokens,
|
||||
OutputTokens: chunk.Usage.CompletionTokens,
|
||||
}
|
||||
}
|
||||
// Why the guard: the include_usage chunk arrives with an EMPTY choices
|
||||
// array; indexing choices[0] unconditionally would panic on it.
|
||||
if len(chunk.Choices) == 0 {
|
||||
return nil
|
||||
}
|
||||
choice := chunk.Choices[0]
|
||||
if choice.FinishReason != "" {
|
||||
s.finish = choice.FinishReason
|
||||
}
|
||||
if choice.Delta.Content != "" {
|
||||
s.text.WriteString(choice.Delta.Content)
|
||||
s.queue = append(s.queue, llm.StreamEvent{TextDelta: choice.Delta.Content})
|
||||
}
|
||||
for _, tc := range choice.Delta.ToolCalls {
|
||||
acc := s.byIndex[tc.Index]
|
||||
if acc == nil {
|
||||
if s.byIndex == nil {
|
||||
s.byIndex = make(map[int]*toolCallAcc)
|
||||
}
|
||||
acc = &toolCallAcc{}
|
||||
s.byIndex[tc.Index] = acc
|
||||
s.calls = append(s.calls, acc)
|
||||
}
|
||||
if tc.ID != "" {
|
||||
acc.id = tc.ID
|
||||
}
|
||||
if tc.Function.Name != "" {
|
||||
acc.name = tc.Function.Name
|
||||
}
|
||||
acc.args.WriteString(tc.Function.Arguments)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// finalize assembles the buffered tool calls and the final Response, queues
|
||||
// them (ToolCall events first, Response last), and marks the stream done.
|
||||
func (s *stream) finalize() {
|
||||
if s.done {
|
||||
return
|
||||
}
|
||||
s.done = true
|
||||
resp := &llm.Response{Model: s.m.p.name + "/" + s.m.id, Usage: s.usage}
|
||||
if s.text.Len() > 0 {
|
||||
resp.Parts = []llm.Part{llm.TextPart{Text: s.text.String()}}
|
||||
}
|
||||
for i, acc := range s.calls {
|
||||
id := acc.id
|
||||
if id == "" {
|
||||
// Why: ToolResult.ID must echo ToolCall.ID; synthesize for
|
||||
// compat servers that stream calls without ids.
|
||||
id = fmt.Sprintf("call_%d", i)
|
||||
}
|
||||
resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{
|
||||
ID: id,
|
||||
Name: acc.name,
|
||||
Arguments: json.RawMessage(acc.args.String()),
|
||||
})
|
||||
}
|
||||
resp.FinishReason = mapFinish(s.finish, len(resp.ToolCalls) > 0)
|
||||
for i := range resp.ToolCalls {
|
||||
tc := resp.ToolCalls[i] // copy so the event doesn't alias the slice
|
||||
s.queue = append(s.queue, llm.StreamEvent{ToolCall: &tc})
|
||||
}
|
||||
s.queue = append(s.queue, llm.StreamEvent{Response: resp})
|
||||
}
|
||||
|
||||
// Close implements llm.Stream. Closing the body unblocks any in-flight read
|
||||
// and aborts the HTTP stream; safe to call at any time, including twice.
|
||||
func (s *stream) Close() error {
|
||||
s.closeOnce.Do(func() { s.closeErr = s.body.Close() })
|
||||
return s.closeErr
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// sseServer streams each payload as one "data: <payload>" SSE event and
|
||||
// records the request like newServer.
|
||||
func sseServer(t *testing.T, payloads ...string) (*httptest.Server, *recorded) {
|
||||
t.Helper()
|
||||
rec := &recorded{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rec.hits++
|
||||
rec.header = r.Header.Clone()
|
||||
rec.path = r.URL.Path
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read request body: %v", err)
|
||||
}
|
||||
if len(raw) > 0 {
|
||||
if err := json.Unmarshal(raw, &rec.body); err != nil {
|
||||
t.Errorf("request body is not JSON: %v\n%s", err, raw)
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
for _, p := range payloads {
|
||||
io.WriteString(w, "data: "+p+"\n\n")
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
return srv, rec
|
||||
}
|
||||
|
||||
// collect drains a stream to io.EOF, failing the test on any other error.
|
||||
func collect(t *testing.T, s llm.Stream) []llm.StreamEvent {
|
||||
t.Helper()
|
||||
var events []llm.StreamEvent
|
||||
for {
|
||||
ev, err := s.Next()
|
||||
if err == io.EOF {
|
||||
return events
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Next: %v", err)
|
||||
}
|
||||
events = append(events, ev)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamText(t *testing.T) {
|
||||
srv, rec := sseServer(t,
|
||||
`{"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-test","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}],"obfuscation":"xK9q"}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"lo"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
`{"choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
|
||||
`[DONE]`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
events := collect(t, s)
|
||||
|
||||
// Request shape: stream flag, usage opt-in, SSE accept header.
|
||||
if rec.body["stream"] != true {
|
||||
t.Errorf("stream = %v, want true", rec.body["stream"])
|
||||
}
|
||||
so, _ := rec.body["stream_options"].(map[string]any)
|
||||
if so == nil || so["include_usage"] != true {
|
||||
t.Errorf("stream_options = %v, want include_usage true", rec.body["stream_options"])
|
||||
}
|
||||
if got := rec.header.Get("Accept"); got != "text/event-stream" {
|
||||
t.Errorf("Accept = %q, want text/event-stream", got)
|
||||
}
|
||||
|
||||
if len(events) != 3 {
|
||||
t.Fatalf("got %d events, want 3: %+v", len(events), events)
|
||||
}
|
||||
if events[0].TextDelta != "Hel" || events[1].TextDelta != "lo" {
|
||||
t.Errorf("deltas = %q, %q, want Hel, lo", events[0].TextDelta, events[1].TextDelta)
|
||||
}
|
||||
final := events[2].Response
|
||||
if final == nil {
|
||||
t.Fatal("last event has no Response")
|
||||
}
|
||||
if got := final.Text(); got != "Hello" {
|
||||
t.Errorf("final text = %q, want Hello", got)
|
||||
}
|
||||
if final.FinishReason != llm.FinishStop {
|
||||
t.Errorf("FinishReason = %v, want stop", final.FinishReason)
|
||||
}
|
||||
if final.Usage != (llm.Usage{InputTokens: 5, OutputTokens: 2}) {
|
||||
t.Errorf("Usage = %+v, want {5 2}", final.Usage)
|
||||
}
|
||||
if final.Model != "openai/gpt-test" {
|
||||
t.Errorf("Model = %q, want openai/gpt-test", final.Model)
|
||||
}
|
||||
|
||||
// Next after EOF keeps returning EOF; Close is idempotent.
|
||||
if _, err := s.Next(); err != io.EOF {
|
||||
t.Errorf("Next after EOF = %v, want io.EOF", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("first Close: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("second Close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParallelToolCalls(t *testing.T) {
|
||||
// Two interleaved calls with distinct indexes; id/name only on the first
|
||||
// fragment of each; arguments split across fragments.
|
||||
srv, _ := sseServer(t,
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_a","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"city\":"}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_b","type":"function","function":{"name":"get_time","arguments":"{\"tz\":"}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Boston\"}"}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"\"EST\"}"}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
`{"choices":[],"usage":{"prompt_tokens":11,"completion_tokens":9,"total_tokens":20}}`,
|
||||
`[DONE]`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
events := collect(t, s)
|
||||
|
||||
if len(events) != 3 {
|
||||
t.Fatalf("got %d events, want 3 (two tool calls + response): %+v", len(events), events)
|
||||
}
|
||||
a, b := events[0].ToolCall, events[1].ToolCall
|
||||
if a == nil || b == nil {
|
||||
t.Fatalf("events 0/1 are not tool calls: %+v", events)
|
||||
}
|
||||
if a.ID != "call_a" || a.Name != "get_weather" || string(a.Arguments) != `{"city":"Boston"}` {
|
||||
t.Errorf("first call = %+v", a)
|
||||
}
|
||||
if b.ID != "call_b" || b.Name != "get_time" || string(b.Arguments) != `{"tz":"EST"}` {
|
||||
t.Errorf("second call = %+v", b)
|
||||
}
|
||||
final := events[2].Response
|
||||
if final == nil {
|
||||
t.Fatal("last event has no Response")
|
||||
}
|
||||
if len(final.ToolCalls) != 2 {
|
||||
t.Fatalf("final ToolCalls = %d, want 2", len(final.ToolCalls))
|
||||
}
|
||||
if final.ToolCalls[0].ID != "call_a" || final.ToolCalls[1].ID != "call_b" {
|
||||
t.Errorf("final ToolCalls order = %q, %q", final.ToolCalls[0].ID, final.ToolCalls[1].ID)
|
||||
}
|
||||
if final.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("FinishReason = %v, want tool_calls", final.FinishReason)
|
||||
}
|
||||
if final.Usage != (llm.Usage{InputTokens: 11, OutputTokens: 9}) {
|
||||
t.Errorf("Usage = %+v, want {11 9}", final.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamMidStreamError(t *testing.T) {
|
||||
srv, _ := sseServer(t,
|
||||
`{"choices":[{"index":0,"delta":{"content":"par"},"finish_reason":null}]}`,
|
||||
`{"error":{"message":"The server had an error while processing your request","type":"server_error","param":null,"code":null}}`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
ev, err := s.Next()
|
||||
if err != nil || ev.TextDelta != "par" {
|
||||
t.Fatalf("first event = %+v, %v; want TextDelta par", ev, err)
|
||||
}
|
||||
_, err = s.Next()
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Code != "server_error" {
|
||||
t.Errorf("Code = %q, want server_error", apiErr.Code)
|
||||
}
|
||||
if apiErr.Message != "The server had an error while processing your request" {
|
||||
t.Errorf("Message = %q", apiErr.Message)
|
||||
}
|
||||
if apiErr.Status != 0 {
|
||||
t.Errorf("Status = %d, want 0 (the HTTP stream was 200)", apiErr.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHTTPError(t *testing.T) {
|
||||
srv, _ := newServer(t, http.StatusTooManyRequests,
|
||||
`{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}`)
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError from Stream itself", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusTooManyRequests || apiErr.Code != "rate_limit_exceeded" {
|
||||
t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassTransient {
|
||||
t.Errorf("Classify = %v, want transient", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamWithoutDoneSentinel(t *testing.T) {
|
||||
// Why: some compat servers close the connection without "data: [DONE]";
|
||||
// a clean EOF must still produce the final Response.
|
||||
srv, _ := sseServer(t,
|
||||
`{"choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
events := collect(t, s)
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("got %d events, want 2: %+v", len(events), events)
|
||||
}
|
||||
final := events[1].Response
|
||||
if final == nil || final.Text() != "ok" || final.FinishReason != llm.FinishStop {
|
||||
t.Errorf("final = %+v", final)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamCloseEarly(t *testing.T) {
|
||||
srv, _ := sseServer(t,
|
||||
`{"choices":[{"index":0,"delta":{"content":"a"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"b"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
`[DONE]`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
if _, err := s.Next(); err != nil {
|
||||
t.Fatalf("Next: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("Close mid-stream: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("Close again: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,321 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// --- request wire shapes ---
|
||||
|
||||
type chatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []wireMessage `json:"messages"`
|
||||
Tools []wireTool `json:"tools,omitempty"`
|
||||
// ToolChoice is "auto"/"none"/"required" (string) or a named-function
|
||||
// object; any avoids two fields for one wire key.
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
ResponseFormat *wireRespFormat `json:"response_format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *wireStreamOptions `json:"stream_options,omitempty"`
|
||||
}
|
||||
|
||||
type wireMessage struct {
|
||||
Role string `json:"role"`
|
||||
// Content is a string for text-only turns, a part array for multimodal
|
||||
// turns, or nil (wire null) for assistant turns that only call tools.
|
||||
Content any `json:"content"`
|
||||
ToolCalls []wireToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type wireTextPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type wireImagePart struct {
|
||||
Type string `json:"type"`
|
||||
ImageURL wireImageURL `json:"image_url"`
|
||||
}
|
||||
|
||||
type wireImageURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type wireToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function wireFunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
type wireFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
// Arguments is a JSON-encoded STRING per the wire format, not an object.
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type wireTool struct {
|
||||
Type string `json:"type"`
|
||||
Function wireToolFunction `json:"function"`
|
||||
}
|
||||
|
||||
type wireToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type wireNamedToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Function wireToolName `json:"function"`
|
||||
}
|
||||
|
||||
type wireToolName struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type wireRespFormat struct {
|
||||
Type string `json:"type"`
|
||||
JSONSchema *wireJSONSchema `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
// wireJSONSchema omits the strict flag on purpose: strict mode imposes
|
||||
// schema rewrites (every property required, additionalProperties:false at
|
||||
// every level) that belong to the caller, not the transport.
|
||||
type wireJSONSchema struct {
|
||||
Name string `json:"name"`
|
||||
Schema json.RawMessage `json:"schema"`
|
||||
}
|
||||
|
||||
type wireStreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage"`
|
||||
}
|
||||
|
||||
// --- response wire shapes (loose: unknown fields ignored) ---
|
||||
|
||||
type chatResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []chatChoice `json:"choices"`
|
||||
Usage *wireUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type chatChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message wireRespMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type wireRespMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"` // null decodes to ""
|
||||
Refusal string `json:"refusal"` // tolerated, unused
|
||||
ToolCalls []wireToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
type wireUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type errorEnvelope struct {
|
||||
Error wireError `json:"error"`
|
||||
}
|
||||
|
||||
type wireError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"` // null decodes to ""
|
||||
}
|
||||
|
||||
// --- streaming wire shapes ---
|
||||
|
||||
type streamChunk struct {
|
||||
Choices []streamChoice `json:"choices"`
|
||||
Usage *wireUsage `json:"usage"`
|
||||
Error *wireError `json:"error"` // mid-stream error event
|
||||
}
|
||||
|
||||
type streamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta streamDelta `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"` // null decodes to ""
|
||||
}
|
||||
|
||||
type streamDelta struct {
|
||||
Content string `json:"content"` // null decodes to ""
|
||||
ToolCalls []streamToolCallDelta `json:"tool_calls"`
|
||||
}
|
||||
|
||||
// streamToolCallDelta is one tool-call fragment. The id and name appear only
|
||||
// on a call's first fragment; later fragments carry just index + an
|
||||
// arguments substring. Accumulation keys on Index, never ID.
|
||||
type streamToolCallDelta struct {
|
||||
Index int `json:"index"`
|
||||
ID string `json:"id"`
|
||||
Function wireFunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// --- mapping: llm.Request -> chatRequest ---
|
||||
|
||||
// buildRequest translates the canonical request to the wire shape. The
|
||||
// capability check has already passed by the time this runs.
|
||||
func (m *model) buildRequest(req llm.Request, stream bool) *chatRequest {
|
||||
out := &chatRequest{
|
||||
Model: m.id,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stop: req.StopSequences,
|
||||
ReasoningEffort: req.ReasoningEffort,
|
||||
}
|
||||
|
||||
// Fold Request.System and every RoleSystem message into one leading
|
||||
// system message, System field first. Why: the canonical contract allows
|
||||
// system content in both places; OpenAI wants one system mechanism.
|
||||
var sys []string
|
||||
if req.System != "" {
|
||||
sys = append(sys, req.System)
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
if msg.Role == llm.RoleSystem {
|
||||
if t := msg.Text(); t != "" {
|
||||
sys = append(sys, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
if joined := strings.Join(sys, "\n\n"); joined != "" {
|
||||
out.Messages = append(out.Messages, wireMessage{Role: "system", Content: joined})
|
||||
}
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
switch msg.Role {
|
||||
case llm.RoleSystem:
|
||||
// Folded above; excluded from the normal message list.
|
||||
case llm.RoleUser:
|
||||
out.Messages = append(out.Messages, wireMessage{Role: "user", Content: contentValue(msg.Parts)})
|
||||
case llm.RoleAssistant:
|
||||
wm := wireMessage{Role: "assistant"}
|
||||
if text := msg.Text(); text != "" {
|
||||
wm.Content = text
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
args := string(tc.Arguments)
|
||||
if args == "" {
|
||||
// Why: arguments must be a valid JSON document string;
|
||||
// an empty string is not one.
|
||||
args = "{}"
|
||||
}
|
||||
wm.ToolCalls = append(wm.ToolCalls, wireToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Function: wireFunctionCall{Name: tc.Name, Arguments: args},
|
||||
})
|
||||
}
|
||||
out.Messages = append(out.Messages, wm)
|
||||
case llm.RoleTool:
|
||||
// One wire message per result: the API pairs each tool output
|
||||
// with its call via tool_call_id, one message each.
|
||||
for _, tr := range msg.ToolResults {
|
||||
content := tr.Content
|
||||
if tr.IsError {
|
||||
content = "ERROR: " + content
|
||||
}
|
||||
out.Messages = append(out.Messages, wireMessage{
|
||||
Role: "tool",
|
||||
Content: content,
|
||||
ToolCallID: tr.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range req.Tools {
|
||||
out.Tools = append(out.Tools, wireTool{
|
||||
Type: "function",
|
||||
Function: wireToolFunction{Name: t.Name, Description: t.Description, Parameters: t.Parameters},
|
||||
})
|
||||
}
|
||||
|
||||
switch req.ToolChoice {
|
||||
case "":
|
||||
// Omit: provider default ("auto" when tools are present).
|
||||
case "auto", "none", "required":
|
||||
out.ToolChoice = req.ToolChoice
|
||||
default:
|
||||
// Any other value names the one tool the model must call.
|
||||
out.ToolChoice = wireNamedToolChoice{Type: "function", Function: wireToolName{Name: req.ToolChoice}}
|
||||
}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
if m.p.legacyMaxTokens {
|
||||
out.MaxTokens = req.MaxTokens
|
||||
} else {
|
||||
out.MaxCompletionTokens = req.MaxTokens
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.Schema) > 0 {
|
||||
name := req.SchemaName
|
||||
if name == "" {
|
||||
name = "response"
|
||||
}
|
||||
out.ResponseFormat = &wireRespFormat{
|
||||
Type: "json_schema",
|
||||
JSONSchema: &wireJSONSchema{Name: name, Schema: req.Schema},
|
||||
}
|
||||
}
|
||||
|
||||
if stream {
|
||||
out.Stream = true
|
||||
// Why: without include_usage the stream never reports token counts;
|
||||
// the usage arrives in one extra chunk with an empty choices array.
|
||||
out.StreamOptions = &wireStreamOptions{IncludeUsage: true}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// contentValue renders message parts as the wire content value: a plain
|
||||
// string when text-only (maximum compat), a part array when images are
|
||||
// present.
|
||||
func contentValue(parts []llm.Part) any {
|
||||
multimodal := false
|
||||
for _, p := range parts {
|
||||
if _, ok := p.(llm.ImagePart); ok {
|
||||
multimodal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !multimodal {
|
||||
var b strings.Builder
|
||||
for _, p := range parts {
|
||||
if t, ok := p.(llm.TextPart); ok {
|
||||
b.WriteString(t.Text)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
out := make([]any, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
switch v := p.(type) {
|
||||
case llm.TextPart:
|
||||
out = append(out, wireTextPart{Type: "text", Text: v.Text})
|
||||
case llm.ImagePart:
|
||||
url := "data:" + v.MIME + ";base64," + base64.StdEncoding.EncodeToString(v.Data)
|
||||
out = append(out, wireImagePart{Type: "image_url", ImageURL: wireImageURL{URL: url}})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
+11
-1
@@ -2,6 +2,7 @@ package majordomo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -78,6 +79,7 @@ type registryConfig struct {
|
||||
envLookup func(string) string
|
||||
environ func() []string
|
||||
skipEnv bool
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// RegistryOption configures New.
|
||||
@@ -115,6 +117,14 @@ func WithoutEnvProviders() RegistryOption {
|
||||
return func(rc *registryConfig) { rc.skipEnv = true }
|
||||
}
|
||||
|
||||
// WithHTTPClient sets the HTTP client used by built-in providers and
|
||||
// env-DSN scheme factories created by this registry (proxies, custom TLS,
|
||||
// test servers). Providers registered explicitly via RegisterProvider keep
|
||||
// whatever client they were built with.
|
||||
func WithHTTPClient(c *http.Client) RegistryOption {
|
||||
return func(rc *registryConfig) { rc.httpClient = c }
|
||||
}
|
||||
|
||||
// New creates a Registry with all built-in providers and scheme factories
|
||||
// registered, then loads LLM_* env-DSN providers from the process
|
||||
// environment (unless WithoutEnvProviders is given). Malformed LLM_* entries
|
||||
@@ -139,7 +149,7 @@ func New(opts ...RegistryOption) *Registry {
|
||||
envLookup: cfg.envLookup,
|
||||
}
|
||||
|
||||
registerBuiltins(r)
|
||||
registerBuiltins(r, cfg.httpClient)
|
||||
|
||||
if !cfg.skipEnv {
|
||||
env := make(map[string]string)
|
||||
|
||||
Reference in New Issue
Block a user