Files
majordomo/provider/anthropic/anthropic_test.go
T
steve 043249e0e1 feat: OpenAI, Anthropic, and native-Ollama providers + media pipeline
Phase 3:
- provider/openai: Chat Completions for OpenAI + compat endpoints (SSE
  streaming with by-index tool-call assembly, response_format json_schema,
  legacy max_tokens option, reasoning_effort)
- provider/anthropic: Messages API (tool_use/tool_result, GA structured
  output via output_config.format, full SSE event parser, 529 transient)
- provider/ollama: one native /api/chat client behind the ollama,
  ollama-cloud, and foreman built-ins (presets; NDJSON streaming tolerant
  of foreman's buffered single-object responses; object tool arguments;
  format-schema structured output; think mapping)
- media/: capability normalization (sniff, downscale, transcode, byte
  ladder, ErrUnsupported), wired into the chain executor per target with
  penalty-free advance past incapable elements
- registry: real provider + scheme wiring, WithHTTPClient option, required
  env-foreman TLS chat round-trip test
- ADR-0009 multimodal strategy, ADR-0010 tools/structured mapping; README
  matrix + CLAUDE.md synced

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-10 12:58:08 +02:00

775 lines
25 KiB
Go

package anthropic
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
// okBody is a minimal successful Messages API response: a single text block
// containing "ok", stop_reason "end_turn", and usage of 3 input / 5 output
// tokens. Tests that only inspect the outgoing request use it as the canned
// server reply.
const okBody = `{
"id": "msg_01",
"type": "message",
"role": "assistant",
"model": "claude-test",
"content": [{"type": "text", "text": "ok"}],
"stop_reason": "end_turn",
"usage": {"input_tokens": 3, "output_tokens": 5}
}`
// capture remembers the most recent request seen by a test server, plus a
// running count of how many requests arrived. The handler goroutine writes
// the fields while holding mu.
type capture struct {
	mu     sync.Mutex
	hits   int         // number of requests handled so far
	method string      // HTTP method of the last request
	path   string      // URL path of the last request
	header http.Header // cloned headers of the last request
	body   []byte      // raw body of the last request
}

// handler returns an http.HandlerFunc that records each incoming request into
// c and then answers with the given status code and JSON body.
func (c *capture) handler(status int, respBody string) http.HandlerFunc {
	reply := []byte(respBody)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Drain the body before taking the lock; a read error is ignored and
		// whatever was read so far is what gets recorded.
		payload, _ := io.ReadAll(r.Body)

		func() {
			c.mu.Lock()
			defer c.mu.Unlock()
			c.hits++
			c.method, c.path = r.Method, r.URL.Path
			c.header = r.Header.Clone()
			c.body = payload
		}()

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		_, _ = w.Write(reply)
	})
}

// bodyMap decodes the last captured request body into a generic JSON object
// so tests can assert on key presence and values. It fails the test
// immediately when the body is not valid JSON for an object.
func (c *capture) bodyMap(t *testing.T) map[string]any {
	t.Helper()

	c.mu.Lock()
	raw := c.body
	c.mu.Unlock()

	var decoded map[string]any
	err := json.Unmarshal(raw, &decoded)
	if err != nil {
		t.Fatalf("decode captured body: %v\nbody: %s", err, raw)
	}
	return decoded
}
// newTestProvider starts an httptest server backed by h (closed automatically
// when the test finishes) and returns a Provider configured with the API key
// "test-key" and the server's URL as base URL. Caller-supplied opts are
// applied after those two options.
func newTestProvider(t *testing.T, h http.Handler, opts ...Option) *Provider {
	t.Helper()

	server := httptest.NewServer(h)
	t.Cleanup(server.Close)

	all := make([]Option, 0, len(opts)+2)
	all = append(all, WithAPIKey("test-key"), WithBaseURL(server.URL))
	all = append(all, opts...)
	return New(all...)
}
// mustModel resolves model id on p with the given options and fails the test
// immediately if the provider returns an error.
func mustModel(t *testing.T, p *Provider, id string, opts ...llm.ModelOption) llm.Model {
	t.Helper()

	model, err := p.Model(id, opts...)
	if err == nil {
		return model
	}
	t.Fatalf("Model(%q): %v", id, err)
	return model // not reached: Fatalf stops the test
}
// generate calls m.Generate with a background context and fails the test
// immediately on error, returning the response otherwise.
func generate(t *testing.T, m llm.Model, req llm.Request, opts ...llm.Option) *llm.Response {
	t.Helper()

	resp, err := m.Generate(context.Background(), req, opts...)
	if err == nil {
		return resp
	}
	t.Fatalf("Generate: %v", err)
	return resp // not reached: Fatalf stops the test
}
// TestRequestHeadersAndPath checks that Generate issues a POST to
// /v1/messages carrying the API key, API version and JSON content-type
// headers.
func TestRequestHeadersAndPath(t *testing.T) {
	var rec capture
	provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
	generate(t, mustModel(t, provider, "claude-test"),
		llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	if got := rec.method; got != http.MethodPost {
		t.Errorf("method = %q, want POST", got)
	}
	if got := rec.path; got != "/v1/messages" {
		t.Errorf("path = %q, want /v1/messages", got)
	}

	wantHeaders := []struct{ name, value string }{
		{"x-api-key", "test-key"},
		{"anthropic-version", "2023-06-01"},
		{"content-type", "application/json"},
	}
	for _, h := range wantHeaders {
		got := rec.header.Get(h.name)
		if got == h.value {
			continue
		}
		t.Errorf("header %s = %q, want %q", h.name, got, h.value)
	}
}
// TestSystemFold checks that Request.System and every system-role message
// end up joined by blank lines in the top-level "system" field, and that the
// system messages are dropped from the "messages" array.
func TestSystemFold(t *testing.T) {
	var rec capture
	provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
	model := mustModel(t, provider, "claude-test")

	req := llm.Request{
		System: "base prompt",
		Messages: []llm.Message{
			llm.SystemText("first extra"),
			llm.UserText("hi"),
			llm.SystemText("second extra"),
		},
	}
	generate(t, model, req)

	sent := rec.bodyMap(t)
	const wantSystem = "base prompt\n\nfirst extra\n\nsecond extra"
	if got := sent["system"]; got != wantSystem {
		t.Errorf("system = %q, want %q", got, wantSystem)
	}

	messages := sent["messages"].([]any)
	if n := len(messages); n != 1 {
		t.Fatalf("messages length = %d, want 1 (system messages must be excluded)", n)
	}
	remaining := messages[0].(map[string]any)
	if role := remaining["role"]; role != "user" {
		t.Errorf("remaining message role = %q, want user", role)
	}
}
// TestNoSystemOmitsField checks that a request without any system prompt is
// sent without a "system" key at all.
func TestNoSystemOmitsField(t *testing.T) {
	var rec capture
	provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
	generate(t, mustModel(t, provider, "claude-test"),
		llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	sent := rec.bodyMap(t)
	_, present := sent["system"]
	if present {
		t.Error("system key present, want omitted when empty")
	}
}
func TestMaxTokens(t *testing.T) {
t.Run("default 4096", func(t *testing.T) {
var c capture
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
if got := c.bodyMap(t)["max_tokens"].(float64); got != 4096 {
t.Errorf("max_tokens = %v, want 4096", got)
}
})
t.Run("explicit wins", func(t *testing.T) {
var c capture
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
generate(t, mustModel(t, p, "m"), llm.Request{
Messages: []llm.Message{llm.UserText("hi")},
MaxTokens: 123,
})
if got := c.bodyMap(t)["max_tokens"].(float64); got != 123 {
t.Errorf("max_tokens = %v, want 123", got)
}
})
t.Run("WithDefaultMaxTokens overrides default", func(t *testing.T) {
var c capture
p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithDefaultMaxTokens(99))
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
if got := c.bodyMap(t)["max_tokens"].(float64); got != 99 {
t.Errorf("max_tokens = %v, want 99", got)
}
})
}
// TestImageBlock checks that an image part is serialized as an "image"
// content block whose source is base64 with the original media type and the
// standard-base64 encoding of the raw bytes.
func TestImageBlock(t *testing.T) {
	var rec capture
	provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
	model := mustModel(t, provider, "claude-test")

	payload := []byte{0x01, 0x02, 0x03}
	generate(t, model, llm.Request{Messages: []llm.Message{
		llm.UserParts(llm.Text("look at this"), llm.Image("image/png", payload)),
	}})

	messages := rec.bodyMap(t)["messages"].([]any)
	blocks := messages[0].(map[string]any)["content"].([]any)
	if n := len(blocks); n != 2 {
		t.Fatalf("content blocks = %d, want 2", n)
	}

	image := blocks[1].(map[string]any)
	if kind := image["type"]; kind != "image" {
		t.Fatalf("block type = %v, want image", kind)
	}

	source := image["source"].(map[string]any)
	if kind := source["type"]; kind != "base64" {
		t.Errorf("source type = %v, want base64", kind)
	}
	if mediaType := source["media_type"]; mediaType != "image/png" {
		t.Errorf("media_type = %v, want image/png", mediaType)
	}
	wantData := base64.StdEncoding.EncodeToString(payload)
	if got := source["data"]; got != wantData {
		t.Errorf("data = %v, want %q", got, wantData)
	}
}
// TestToolUseToolResultRoundTrip checks the wire form of a tool-calling
// conversation: assistant tool calls become tool_use blocks after the text
// block (with empty arguments sent as an empty object), and a tool-results
// message becomes a single user message holding one tool_result block per
// result, with is_error only present when true.
func TestToolUseToolResultRoundTrip(t *testing.T) {
	var rec capture
	provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
	model := mustModel(t, provider, "claude-test")

	assistantTurn := llm.Message{
		Role:  llm.RoleAssistant,
		Parts: []llm.Part{llm.Text("checking")},
		ToolCalls: []llm.ToolCall{
			{ID: "toolu_1", Name: "get_weather", Arguments: json.RawMessage(`{"location":"Paris"}`)},
			{ID: "toolu_2", Name: "noop"}, // empty args must become {}
		},
	}
	toolTurn := llm.ToolResultsMessage(
		llm.ToolResult{ID: "toolu_1", Name: "get_weather", Content: "72F and sunny"},
		llm.ToolResult{ID: "toolu_2", Name: "noop", Content: "boom", IsError: true},
	)
	generate(t, model, llm.Request{Messages: []llm.Message{
		llm.UserText("weather?"),
		assistantTurn,
		toolTurn,
	}})

	messages := rec.bodyMap(t)["messages"].([]any)
	if n := len(messages); n != 3 {
		t.Fatalf("messages = %d, want 3", n)
	}

	// Assistant turn: text block followed by the two tool_use blocks.
	assistant := messages[1].(map[string]any)
	if role := assistant["role"]; role != "assistant" {
		t.Errorf("messages[1].role = %v, want assistant", role)
	}
	assistantBlocks := assistant["content"].([]any)
	if n := len(assistantBlocks); n != 3 {
		t.Fatalf("assistant blocks = %d, want 3 (text + 2 tool_use)", n)
	}
	toolUse := assistantBlocks[1].(map[string]any)
	toolUseOK := toolUse["type"] == "tool_use" &&
		toolUse["id"] == "toolu_1" &&
		toolUse["name"] == "get_weather"
	if !toolUseOK {
		t.Errorf("tool_use block = %v", toolUse)
	}
	toolUseInput := toolUse["input"].(map[string]any)
	if loc := toolUseInput["location"]; loc != "Paris" {
		t.Errorf("tool_use input.location = %v, want Paris", loc)
	}
	emptyInput := assistantBlocks[2].(map[string]any)["input"].(map[string]any)
	if len(emptyInput) != 0 {
		t.Errorf("empty-args tool_use input = %v, want {}", emptyInput)
	}

	// RoleTool → ONE user message with one tool_result block per result.
	toolMessage := messages[2].(map[string]any)
	if role := toolMessage["role"]; role != "user" {
		t.Errorf("messages[2].role = %v, want user", role)
	}
	resultBlocks := toolMessage["content"].([]any)
	if n := len(resultBlocks); n != 2 {
		t.Fatalf("tool_result blocks = %d, want 2", n)
	}

	okResult := resultBlocks[0].(map[string]any)
	okResultMatches := okResult["type"] == "tool_result" &&
		okResult["tool_use_id"] == "toolu_1" &&
		okResult["content"] == "72F and sunny"
	if !okResultMatches {
		t.Errorf("first tool_result = %v", okResult)
	}
	if _, hasFlag := okResult["is_error"]; hasFlag {
		t.Error("first tool_result has is_error, want omitted when false")
	}

	errResult := resultBlocks[1].(map[string]any)
	errResultMatches := errResult["tool_use_id"] == "toolu_2" && errResult["is_error"] == true
	if !errResultMatches {
		t.Errorf("second tool_result = %v, want is_error true", errResult)
	}
}
// TestToolDefinitions checks that tools are sent with their name, description
// and input_schema, and that a tool declared without Parameters still gets an
// input_schema whose type is "object".
func TestToolDefinitions(t *testing.T) {
	var rec capture
	provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
	model := mustModel(t, provider, "claude-test")

	searchSchema := json.RawMessage(`{"type":"object","properties":{"q":{"type":"string"}},"required":["q"]}`)
	req := llm.Request{
		Messages: []llm.Message{llm.UserText("hi")},
		Tools: []llm.Tool{
			{Name: "search", Description: "Search the web.", Parameters: searchSchema},
			{Name: "ping"}, // nil Parameters → default empty object schema
		},
	}
	generate(t, model, req)

	sentTools := rec.bodyMap(t)["tools"].([]any)
	if n := len(sentTools); n != 2 {
		t.Fatalf("tools = %d, want 2", n)
	}

	searchTool := sentTools[0].(map[string]any)
	searchOK := searchTool["name"] == "search" && searchTool["description"] == "Search the web."
	if !searchOK {
		t.Errorf("tool[0] = %v", searchTool)
	}
	searchInput := searchTool["input_schema"].(map[string]any)
	if kind := searchInput["type"]; kind != "object" {
		t.Errorf("input_schema.type = %v, want object", kind)
	}

	pingTool := sentTools[1].(map[string]any)
	pingInput := pingTool["input_schema"].(map[string]any)
	if kind := pingInput["type"]; kind != "object" {
		t.Errorf("nil-Parameters input_schema.type = %v, want object", kind)
	}
}
// TestToolChoiceForms checks how each Request.ToolChoice value is rendered in
// the "tool_choice" field: empty omits the field, "auto" and "none" pass
// through, "required" becomes type "any", and any other string becomes a
// named "tool" choice.
func TestToolChoiceForms(t *testing.T) {
	type expectation struct {
		choice   string
		wantType string // "" means the field must be absent
		wantName string
	}
	table := []expectation{
		{choice: "", wantType: ""},
		{choice: "auto", wantType: "auto"},
		{choice: "required", wantType: "any"},
		{choice: "none", wantType: "none"},
		{choice: "get_weather", wantType: "tool", wantName: "get_weather"},
	}

	for _, tc := range table {
		t.Run("choice="+tc.choice, func(t *testing.T) {
			var rec capture
			provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
			generate(t, mustModel(t, provider, "m"), llm.Request{
				Messages:   []llm.Message{llm.UserText("hi")},
				ToolChoice: tc.choice,
			})

			raw, present := rec.bodyMap(t)["tool_choice"]
			switch {
			case tc.wantType == "" && present:
				t.Fatalf("tool_choice present (%v), want omitted", raw)
			case tc.wantType == "":
				return
			}

			sent := raw.(map[string]any)
			if got := sent["type"]; got != tc.wantType {
				t.Errorf("tool_choice.type = %v, want %q", got, tc.wantType)
			}
			if tc.wantName == "" {
				return
			}
			if got := sent["name"]; got != tc.wantName {
				t.Errorf("tool_choice.name = %v, want %q", got, tc.wantName)
			}
		})
	}
}
// TestOutputConfigFormat checks that a schema supplied via llm.WithSchema is
// sent as output_config.format with type "json_schema" and a schema that is
// JSON-equivalent to the one provided.
//
// Every intermediate step is checked: a missing or mis-typed output_config /
// format field fails the test with a message instead of panicking on a type
// assertion, and JSON encode/decode errors are reported rather than dropped
// (a dropped error could otherwise make both sides compare equal as empty).
func TestOutputConfigFormat(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	m := mustModel(t, p, "claude-test")
	schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"],"additionalProperties":false}`)
	generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}},
		llm.WithSchema(schema, "person"))

	body := c.bodyMap(t)
	outputConfig, ok := body["output_config"].(map[string]any)
	if !ok {
		t.Fatalf("output_config = %v, want a JSON object", body["output_config"])
	}
	format, ok := outputConfig["format"].(map[string]any)
	if !ok {
		t.Fatalf("output_config.format = %v, want a JSON object", outputConfig["format"])
	}
	if format["type"] != "json_schema" {
		t.Errorf("output_config.format.type = %v, want json_schema", format["type"])
	}

	// Normalize both sides through any → Marshal (sorted keys) to compare.
	got, err := json.Marshal(format["schema"])
	if err != nil {
		t.Fatalf("marshal sent schema: %v", err)
	}
	var want any
	if err := json.Unmarshal(schema, &want); err != nil {
		t.Fatalf("unmarshal expected schema: %v", err)
	}
	wantJSON, err := json.Marshal(want)
	if err != nil {
		t.Fatalf("marshal expected schema: %v", err)
	}
	if string(got) != string(wantJSON) {
		t.Errorf("schema = %s, want %s", got, wantJSON)
	}
}
// TestOutputConfigOmittedWithoutSchema checks that a request made without a
// schema option carries no "output_config" key.
func TestOutputConfigOmittedWithoutSchema(t *testing.T) {
	var rec capture
	provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
	generate(t, mustModel(t, provider, "m"),
		llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	sent := rec.bodyMap(t)
	_, present := sent["output_config"]
	if present {
		t.Error("output_config present, want omitted when Schema is nil")
	}
}
// TestSamplingKnobs checks that temperature, top_p and stop_sequences are
// left out of the request when unset and are sent when set — including an
// explicit temperature of zero.
//
// Sent values are compared as interface values (or through a comma-ok
// assertion) so that a missing or mis-typed field yields a normal test
// failure rather than a type-assertion panic that aborts the test binary.
func TestSamplingKnobs(t *testing.T) {
	t.Run("omitted when unset", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusOK, okBody))
		generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

		body := c.bodyMap(t)
		for _, key := range []string{"temperature", "top_p", "stop_sequences"} {
			if _, ok := body[key]; ok {
				t.Errorf("%s present, want omitted when unset", key)
			}
		}
	})

	t.Run("present when set", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusOK, okBody))
		generate(t, mustModel(t, p, "m"),
			llm.Request{Messages: []llm.Message{llm.UserText("hi")}},
			llm.WithTemperature(0), // explicit zero must still be sent
			llm.WithTopP(0.9),
			llm.WithStopSequences("END"))

		body := c.bodyMap(t)
		// JSON numbers decode to float64; comparing the `any` directly is
		// false, not a panic, for a nil or differently-typed value.
		if got, ok := body["temperature"]; !ok || got != float64(0) {
			t.Errorf("temperature = %v (present=%v), want explicit 0", got, ok)
		}
		if got := body["top_p"]; got != float64(0.9) {
			t.Errorf("top_p = %v, want 0.9", got)
		}
		// A failed comma-ok assertion leaves stops nil, which fails the
		// length check below.
		stops, _ := body["stop_sequences"].([]any)
		if len(stops) != 1 || stops[0] != "END" {
			t.Errorf("stop_sequences = %v, want [END]", body["stop_sequences"])
		}
	})
}
// TestStreamFieldOmittedOnGenerate checks that a non-streaming Generate call
// sends no "stream" key in the request body.
func TestStreamFieldOmittedOnGenerate(t *testing.T) {
	var rec capture
	provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
	generate(t, mustModel(t, provider, "m"),
		llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	sent := rec.bodyMap(t)
	_, present := sent["stream"]
	if present {
		t.Error("stream key present on Generate, want omitted")
	}
}
// TestResponseParse checks decoding of a response containing thinking, text
// and tool_use blocks: the thinking block is skipped, the text and tool call
// are surfaced, stop_reason "tool_use" maps to FinishToolCalls, cache token
// counts are added to the input usage, the model is reported with the
// provider prefix, and Raw is populated.
func TestResponseParse(t *testing.T) {
	const wire = `{
"id": "msg_02",
"type": "message",
"role": "assistant",
"model": "claude-test",
"content": [
{"type": "thinking", "thinking": "pondering...", "signature": "sig"},
{"type": "text", "text": "I'll check the weather."},
{"type": "tool_use", "id": "toolu_9", "name": "get_weather", "input": {"location": "Paris"}}
],
"stop_reason": "tool_use",
"usage": {
"input_tokens": 3,
"output_tokens": 7,
"cache_creation_input_tokens": 10,
"cache_read_input_tokens": 20
}
}`
	var rec capture
	provider := newTestProvider(t, rec.handler(http.StatusOK, wire))
	model := mustModel(t, provider, "claude-test")
	resp := generate(t, model, llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	if n := len(resp.Parts); n != 1 {
		t.Fatalf("parts = %d, want 1 (thinking blocks must be skipped)", n)
	}
	if text := resp.Text(); text != "I'll check the weather." {
		t.Errorf("text = %q", text)
	}

	if n := len(resp.ToolCalls); n != 1 {
		t.Fatalf("tool calls = %d, want 1", n)
	}
	toolCall := resp.ToolCalls[0]
	if !(toolCall.ID == "toolu_9" && toolCall.Name == "get_weather") {
		t.Errorf("tool call = %+v", toolCall)
	}
	var decodedArgs map[string]any
	err := json.Unmarshal(toolCall.Arguments, &decodedArgs)
	if err != nil || decodedArgs["location"] != "Paris" {
		t.Errorf("arguments = %s (err %v), want location Paris", toolCall.Arguments, err)
	}

	if got := resp.FinishReason; got != llm.FinishToolCalls {
		t.Errorf("finish = %q, want %q", got, llm.FinishToolCalls)
	}
	// Total real input = input + cache_creation + cache_read.
	usageOK := resp.Usage.InputTokens == 33 && resp.Usage.OutputTokens == 7
	if !usageOK {
		t.Errorf("usage = %+v, want {33 7}", resp.Usage)
	}
	if got := resp.Model; got != "anthropic/claude-test" {
		t.Errorf("model = %q, want anthropic/claude-test", got)
	}
	if resp.Raw == nil {
		t.Error("Raw = nil, want wire response")
	}
}
// TestStopReasonMapping checks mapStopReason for every known stop_reason
// string and for an unrecognized one, which must map to FinishOther.
func TestStopReasonMapping(t *testing.T) {
	table := []struct {
		stop string
		want llm.FinishReason
	}{
		{"end_turn", llm.FinishStop},
		{"stop_sequence", llm.FinishStop},
		{"max_tokens", llm.FinishLength},
		{"model_context_window_exceeded", llm.FinishLength},
		{"tool_use", llm.FinishToolCalls},
		{"refusal", llm.FinishContentFilter},
		{"pause_turn", llm.FinishOther},
		{"some_future_reason", llm.FinishOther},
	}
	for _, row := range table {
		got := mapStopReason(row.stop)
		if got == row.want {
			continue
		}
		t.Errorf("mapStopReason(%q) = %q, want %q", row.stop, got, row.want)
	}
}
// TestHTTPErrorMapping checks that non-2xx responses surface as
// *llm.APIError carrying provider, model, status, code and message, and that
// llm.Classify treats 429 and 529 as transient and 401/404 as permanent. It
// also checks that a 404 matches llm.ErrModelNotFound and that a non-JSON
// error body is used verbatim as the message.
func TestHTTPErrorMapping(t *testing.T) {
	hiRequest := func() llm.Request {
		return llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	}

	type errorCase struct {
		name      string
		status    int
		body      string
		wantCode  string
		wantClass llm.ErrorClass
	}
	table := []errorCase{
		{
			name:      "429 rate limit is transient",
			status:    http.StatusTooManyRequests,
			body:      `{"type":"error","error":{"type":"rate_limit_error","message":"slow down"}}`,
			wantCode:  "rate_limit_error",
			wantClass: llm.ClassTransient,
		},
		{
			name:      "529 overloaded is transient",
			status:    529,
			body:      `{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`,
			wantCode:  "overloaded_error",
			wantClass: llm.ClassTransient,
		},
		{
			name:      "401 auth is permanent",
			status:    http.StatusUnauthorized,
			body:      `{"type":"error","error":{"type":"authentication_error","message":"invalid x-api-key"}}`,
			wantCode:  "authentication_error",
			wantClass: llm.ClassPermanent,
		},
		{
			name:      "404 is permanent",
			status:    http.StatusNotFound,
			body:      `{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`,
			wantCode:  "not_found_error",
			wantClass: llm.ClassPermanent,
		},
	}

	for _, tc := range table {
		t.Run(tc.name, func(t *testing.T) {
			var rec capture
			provider := newTestProvider(t, rec.handler(tc.status, tc.body))
			model := mustModel(t, provider, "claude-test")

			_, err := model.Generate(context.Background(), hiRequest())
			if err == nil {
				t.Fatal("Generate succeeded, want error")
			}
			apiErr, isAPIErr := errors.AsType[*llm.APIError](err)
			if !isAPIErr {
				t.Fatalf("error %T (%v), want *llm.APIError", err, err)
			}

			if !(apiErr.Provider == "anthropic" && apiErr.Model == "claude-test") {
				t.Errorf("provider/model = %s/%s", apiErr.Provider, apiErr.Model)
			}
			if got := apiErr.Status; got != tc.status {
				t.Errorf("status = %d, want %d", got, tc.status)
			}
			if got := apiErr.Code; got != tc.wantCode {
				t.Errorf("code = %q, want %q", got, tc.wantCode)
			}
			if len(apiErr.Message) == 0 {
				t.Error("message empty, want provider message")
			}
			if class := llm.Classify(err); class != tc.wantClass {
				t.Errorf("Classify = %v, want %v", class, tc.wantClass)
			}
		})
	}

	t.Run("404 unwraps to ErrModelNotFound", func(t *testing.T) {
		const notFound = `{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`
		var rec capture
		provider := newTestProvider(t, rec.handler(http.StatusNotFound, notFound))

		_, err := mustModel(t, provider, "missing").Generate(context.Background(), hiRequest())
		if matched := errors.Is(err, llm.ErrModelNotFound); !matched {
			t.Errorf("errors.Is(err, ErrModelNotFound) = false for %v", err)
		}
	})

	t.Run("non-JSON error body falls back to raw text", func(t *testing.T) {
		var rec capture
		provider := newTestProvider(t, rec.handler(http.StatusBadGateway, "upstream exploded"))

		_, err := mustModel(t, provider, "m").Generate(context.Background(), hiRequest())
		apiErr, isAPIErr := errors.AsType[*llm.APIError](err)
		if !isAPIErr {
			t.Fatalf("error %T, want *llm.APIError", err)
		}
		matches := apiErr.Status == http.StatusBadGateway && apiErr.Message == "upstream exploded"
		if !matches {
			t.Errorf("apiErr = %+v", apiErr)
		}
	})
}
// TestMissingAPIKey checks that, with no key configured and an empty
// ANTHROPIC_API_KEY, constructing the provider still works but Generate fails
// with a permanent 401 authentication_error without contacting the server.
func TestMissingAPIKey(t *testing.T) {
	t.Setenv("ANTHROPIC_API_KEY", "") // isolate from any real environment

	var rec capture
	server := httptest.NewServer(rec.handler(http.StatusOK, okBody))
	t.Cleanup(server.Close)

	provider := New(WithBaseURL(server.URL)) // construction must not fail
	model := mustModel(t, provider, "claude-test")
	_, err := model.Generate(context.Background(),
		llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	apiErr, isAPIErr := errors.AsType[*llm.APIError](err)
	if !isAPIErr {
		t.Fatalf("error %T (%v), want *llm.APIError", err, err)
	}
	isAuthFailure := apiErr.Status == http.StatusUnauthorized && apiErr.Code == "authentication_error"
	if !isAuthFailure {
		t.Errorf("apiErr = %+v, want 401 authentication_error", apiErr)
	}
	if class := llm.Classify(err); class != llm.ClassPermanent {
		t.Error("missing key must classify permanent")
	}
	if n := rec.hits; n != 0 {
		t.Errorf("server hits = %d, want 0 (no request without a key)", n)
	}
}
// TestAPIKeyFromEnv checks that, when no key option is given, the value of
// ANTHROPIC_API_KEY is sent in the x-api-key header.
func TestAPIKeyFromEnv(t *testing.T) {
	t.Setenv("ANTHROPIC_API_KEY", "env-key")

	var rec capture
	server := httptest.NewServer(rec.handler(http.StatusOK, okBody))
	t.Cleanup(server.Close)

	provider := New(WithBaseURL(server.URL))
	model := mustModel(t, provider, "m")
	generate(t, model, llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	sentKey := rec.header.Get("x-api-key")
	if sentKey != "env-key" {
		t.Errorf("x-api-key = %q, want env-key", sentKey)
	}
}
// TestCapabilityEnforcement checks that requests violating the model's image
// capabilities (no images allowed, too many images, disallowed MIME type,
// oversized image) are rejected with llm.ErrUnsupported by both Generate and
// Stream before any HTTP request is made, and that a request within the
// default limits is sent exactly once.
func TestCapabilityEnforcement(t *testing.T) {
	// img builds an image part of n zero bytes with the given MIME type.
	img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) }

	type rejection struct {
		name string
		caps *llm.Capabilities // nil = provider defaults
		req  llm.Request
	}
	table := []rejection{
		{
			name: "images unsupported",
			caps: &llm.Capabilities{}, // MaxImagesPerReq 0 = no images
			req:  llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 4))}},
		},
		{
			name: "too many images",
			caps: &llm.Capabilities{MaxImagesPerReq: 1},
			req: llm.Request{Messages: []llm.Message{
				llm.UserParts(img("image/png", 4), img("image/png", 4)),
			}},
		},
		{
			name: "disallowed MIME",
			req:  llm.Request{Messages: []llm.Message{llm.UserParts(img("image/bmp", 4))}},
		},
		{
			name: "image too large",
			caps: &llm.Capabilities{MaxImagesPerReq: 1, MaxImageBytes: 2},
			req:  llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 3))}},
		},
	}

	for _, tc := range table {
		t.Run(tc.name, func(t *testing.T) {
			var rec capture
			provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))

			var modelOpts []llm.ModelOption
			if tc.caps != nil {
				modelOpts = append(modelOpts, llm.WithCapabilities(*tc.caps))
			}
			model := mustModel(t, provider, "claude-test", modelOpts...)

			// Both entry points must reject the request; Generate runs first.
			attempts := []struct {
				label string
				call  func() error
			}{
				{"Generate", func() error {
					_, err := model.Generate(context.Background(), tc.req)
					return err
				}},
				{"Stream", func() error {
					_, err := model.Stream(context.Background(), tc.req)
					return err
				}},
			}
			for _, attempt := range attempts {
				if err := attempt.call(); !errors.Is(err, llm.ErrUnsupported) {
					t.Errorf("%s err = %v, want ErrUnsupported", attempt.label, err)
				}
			}

			if n := rec.hits; n != 0 {
				t.Errorf("server hits = %d, want 0 (rejected before sending)", n)
			}
		})
	}

	t.Run("within limits passes", func(t *testing.T) {
		var rec capture
		provider := newTestProvider(t, rec.handler(http.StatusOK, okBody))
		req := llm.Request{
			Messages: []llm.Message{llm.UserParts(llm.Text("ok"), img("image/jpeg", 16))},
		}
		generate(t, mustModel(t, provider, "m"), req)

		if n := rec.hits; n != 1 {
			t.Errorf("server hits = %d, want 1", n)
		}
	})
}
func TestCompatEndpointWithNameAndBaseURL(t *testing.T) {
var c capture
p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithName("compat"))
if p.Name() != "compat" {
t.Errorf("Name() = %q, want compat", p.Name())
}
resp := generate(t, mustModel(t, p, "claude-test"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
if resp.Model != "compat/claude-test" {
t.Errorf("resp.Model = %q, want compat/claude-test", resp.Model)
}
var ec capture
pe := newTestProvider(t, ec.handler(http.StatusTooManyRequests,
`{"type":"error","error":{"type":"rate_limit_error","message":"x"}}`), WithName("compat"))
_, err := mustModel(t, pe, "m").Generate(context.Background(),
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
apiErr, ok := errors.AsType[*llm.APIError](err)
if !ok || apiErr.Provider != "compat" {
t.Errorf("error provider = %v, want compat (err %v)", apiErr, err)
}
}
func TestCapabilitiesDefaultsAndOverrides(t *testing.T) {
p := New(WithAPIKey("k"))
m := mustModel(t, p, "m")
caps := m.Capabilities()
if !caps.SupportsTools || !caps.SupportsStructured || !caps.SupportsStreaming {
t.Errorf("default feature flags = %+v, want all true", caps)
}
if caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 10<<20 || caps.MaxImageDimension != 8000 {
t.Errorf("default image limits = %+v", caps)
}
wantMIME := []string{"image/jpeg", "image/png", "image/gif", "image/webp"}
if len(caps.AllowedImageMIME) != len(wantMIME) {
t.Fatalf("AllowedImageMIME = %v, want %v", caps.AllowedImageMIME, wantMIME)
}
for i, mime := range wantMIME {
if caps.AllowedImageMIME[i] != mime {
t.Errorf("AllowedImageMIME[%d] = %q, want %q", i, caps.AllowedImageMIME[i], mime)
}
}
custom := llm.Capabilities{SupportsStreaming: true, MaxImagesPerReq: 1}
p2 := New(WithAPIKey("k"), WithDefaultCapabilities(custom))
if got := mustModel(t, p2, "m").Capabilities(); got.MaxImagesPerReq != 1 || got.SupportsTools {
t.Errorf("WithDefaultCapabilities not applied: %+v", got)
}
perModel := llm.Capabilities{SupportsTools: true}
if got := mustModel(t, p2, "m", llm.WithCapabilities(perModel)).Capabilities(); !got.SupportsTools || got.MaxImagesPerReq != 0 {
t.Errorf("per-model capabilities not applied: %+v", got)
}
}
// TestTransportErrorNotAPIError checks that a connection failure is returned
// as a plain (non-APIError) error that llm.Classify treats as transient.
func TestTransportErrorNotAPIError(t *testing.T) {
	// Start a server only to obtain an address, then close it right away so
	// the request below cannot connect.
	noop := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	server := httptest.NewServer(noop)
	deadAddr := server.URL
	server.Close()

	provider := New(WithAPIKey("k"), WithBaseURL(deadAddr))
	model := mustModel(t, provider, "m")
	_, err := model.Generate(context.Background(),
		llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
	if err == nil {
		t.Fatal("Generate succeeded, want transport error")
	}

	if _, isAPIErr := errors.AsType[*llm.APIError](err); isAPIErr {
		t.Errorf("transport error wrapped in APIError: %v", err)
	}
	if class := llm.Classify(err); class != llm.ClassTransient {
		t.Errorf("connection failure must classify transient: %v", err)
	}
}