package anthropic

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"slices"
	"sync"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)

// okBody is a minimal successful Messages API response.
const okBody = `{
	"id": "msg_01",
	"type": "message",
	"role": "assistant",
	"model": "claude-test",
	"content": [{"type": "text", "text": "ok"}],
	"stop_reason": "end_turn",
	"usage": {"input_tokens": 3, "output_tokens": 5}
}`

// capture records the last request the test server received.
//
// The fields are written from the httptest server's handler goroutine under
// mu, so tests read them only through the locking accessors (bodyMap,
// request, hitCount) rather than touching the fields directly.
type capture struct {
	mu     sync.Mutex
	hits   int
	method string
	path   string
	header http.Header
	body   []byte
}

// handler returns an http.HandlerFunc that records each incoming request into
// c and replies with the given status code and body as application/json.
func (c *capture) handler(status int, respBody string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Best effort: a failed read leaves a truncated body, which bodyMap
		// then reports as a decode failure in the test that inspects it.
		body, _ := io.ReadAll(r.Body)

		c.mu.Lock()
		c.hits++
		c.method = r.Method
		c.path = r.URL.Path
		c.header = r.Header.Clone()
		c.body = body
		c.mu.Unlock()

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		_, _ = w.Write([]byte(respBody))
	}
}

// bodyMap decodes the captured request body for key-presence assertions.
func (c *capture) bodyMap(t *testing.T) map[string]any {
	t.Helper()
	c.mu.Lock()
	defer c.mu.Unlock()
	var m map[string]any
	if err := json.Unmarshal(c.body, &m); err != nil {
		t.Fatalf("decode captured body: %v\nbody: %s", err, c.body)
	}
	return m
}

// request returns the method, URL path and headers of the last captured
// request. The header map is a clone made at capture time and is replaced
// (not mutated) on each request, so returning it after unlocking is safe.
func (c *capture) request() (method, path string, header http.Header) {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.method, c.path, c.header
}

// hitCount reports how many requests the test server has received.
func (c *capture) hitCount() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.hits
}

// mustAs asserts that v holds a T and returns it. A mismatch, including a
// missing key (v == nil), fails the current test via t.Fatalf instead of
// panicking, since a panic would abort the whole test binary.
// what names the value in the failure message.
func mustAs[T any](t *testing.T, v any, what string) T {
	t.Helper()
	got, ok := v.(T)
	if !ok {
		var zero T
		t.Fatalf("%s = %#v (%T), want %T", what, v, v, zero)
	}
	return got
}

// newTestProvider spins up an httptest server and a provider pointed at it.
// The server is closed automatically when the test finishes.
func newTestProvider(t *testing.T, h http.Handler, opts ...Option) *Provider {
	t.Helper()
	srv := httptest.NewServer(h)
	t.Cleanup(srv.Close)
	return New(append([]Option{WithAPIKey("test-key"), WithBaseURL(srv.URL)}, opts...)...)
}

// mustModel resolves model id on p, failing the test on error.
func mustModel(t *testing.T, p *Provider, id string, opts ...llm.ModelOption) llm.Model {
	t.Helper()
	m, err := p.Model(id, opts...)
	if err != nil {
		t.Fatalf("Model(%q): %v", id, err)
	}
	return m
}

// generate calls m.Generate with a background context, failing the test on error.
func generate(t *testing.T, m llm.Model, req llm.Request, opts ...llm.Option) *llm.Response {
	t.Helper()
	resp, err := m.Generate(context.Background(), req, opts...)
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}
	return resp
}

// TestRequestHeadersAndPath checks the HTTP method, path and required headers.
func TestRequestHeadersAndPath(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	m := mustModel(t, p, "claude-test")
	generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	method, path, header := c.request()
	if method != http.MethodPost {
		t.Errorf("method = %q, want POST", method)
	}
	if path != "/v1/messages" {
		t.Errorf("path = %q, want /v1/messages", path)
	}
	for name, want := range map[string]string{
		"x-api-key":         "test-key",
		"anthropic-version": "2023-06-01",
		"content-type":      "application/json",
	} {
		if got := header.Get(name); got != want {
			t.Errorf("header %s = %q, want %q", name, got, want)
		}
	}
}

// TestSystemFold checks that Request.System and system-role messages are
// joined into the top-level system field and removed from messages.
func TestSystemFold(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	m := mustModel(t, p, "claude-test")
	generate(t, m, llm.Request{
		System: "base prompt",
		Messages: []llm.Message{
			llm.SystemText("first extra"),
			llm.UserText("hi"),
			llm.SystemText("second extra"),
		},
	})

	body := c.bodyMap(t)
	if got, want := body["system"], "base prompt\n\nfirst extra\n\nsecond extra"; got != want {
		t.Errorf("system = %q, want %q", got, want)
	}
	msgs := mustAs[[]any](t, body["messages"], "messages")
	if len(msgs) != 1 {
		t.Fatalf("messages length = %d, want 1 (system messages must be excluded)", len(msgs))
	}
	if role := mustAs[map[string]any](t, msgs[0], "messages[0]")["role"]; role != "user" {
		t.Errorf("remaining message role = %q, want user", role)
	}
}

// TestNoSystemOmitsField checks that an empty system prompt is not sent.
func TestNoSystemOmitsField(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	m := mustModel(t, p, "claude-test")
	generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	if _, ok := c.bodyMap(t)["system"]; ok {
		t.Error("system key present, want omitted when empty")
	}
}

// TestMaxTokens checks the default, explicit and provider-configured max_tokens.
func TestMaxTokens(t *testing.T) {
	t.Run("default 4096", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusOK, okBody))
		generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
		if got := mustAs[float64](t, c.bodyMap(t)["max_tokens"], "max_tokens"); got != 4096 {
			t.Errorf("max_tokens = %v, want 4096", got)
		}
	})

	t.Run("explicit wins", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusOK, okBody))
		generate(t, mustModel(t, p, "m"), llm.Request{
			Messages:  []llm.Message{llm.UserText("hi")},
			MaxTokens: 123,
		})
		if got := mustAs[float64](t, c.bodyMap(t)["max_tokens"], "max_tokens"); got != 123 {
			t.Errorf("max_tokens = %v, want 123", got)
		}
	})

	t.Run("WithDefaultMaxTokens overrides default", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithDefaultMaxTokens(99))
		generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
		if got := mustAs[float64](t, c.bodyMap(t)["max_tokens"], "max_tokens"); got != 99 {
			t.Errorf("max_tokens = %v, want 99", got)
		}
	})
}

// TestImageBlock checks that image parts are sent as base64 image blocks.
func TestImageBlock(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	m := mustModel(t, p, "claude-test")

	raw := []byte{0x01, 0x02, 0x03}
	generate(t, m, llm.Request{Messages: []llm.Message{
		llm.UserParts(llm.Text("look at this"), llm.Image("image/png", raw)),
	}})

	msgs := mustAs[[]any](t, c.bodyMap(t)["messages"], "messages")
	if len(msgs) == 0 {
		t.Fatal("messages empty, want 1")
	}
	first := mustAs[map[string]any](t, msgs[0], "messages[0]")
	content := mustAs[[]any](t, first["content"], "messages[0].content")
	if len(content) != 2 {
		t.Fatalf("content blocks = %d, want 2", len(content))
	}
	img := mustAs[map[string]any](t, content[1], "content[1]")
	if img["type"] != "image" {
		t.Fatalf("block type = %v, want image", img["type"])
	}
	src := mustAs[map[string]any](t, img["source"], "image source")
	if src["type"] != "base64" {
		t.Errorf("source type = %v, want base64", src["type"])
	}
	if src["media_type"] != "image/png" {
		t.Errorf("media_type = %v, want image/png", src["media_type"])
	}
	if want := base64.StdEncoding.EncodeToString(raw); src["data"] != want {
		t.Errorf("data = %v, want %q", src["data"], want)
	}
}

// TestToolUseToolResultRoundTrip checks the wire encoding of assistant tool
// calls (tool_use blocks) and tool results (tool_result blocks).
func TestToolUseToolResultRoundTrip(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	m := mustModel(t, p, "claude-test")
	generate(t, m, llm.Request{Messages: []llm.Message{
		llm.UserText("weather?"),
		{
			Role:  llm.RoleAssistant,
			Parts: []llm.Part{llm.Text("checking")},
			ToolCalls: []llm.ToolCall{
				{ID: "toolu_1", Name: "get_weather", Arguments: json.RawMessage(`{"location":"Paris"}`)},
				{ID: "toolu_2", Name: "noop"}, // empty args must become {}
			},
		},
		llm.ToolResultsMessage(
			llm.ToolResult{ID: "toolu_1", Name: "get_weather", Content: "72F and sunny"},
			llm.ToolResult{ID: "toolu_2", Name: "noop", Content: "boom", IsError: true},
		),
	}})

	msgs := mustAs[[]any](t, c.bodyMap(t)["messages"], "messages")
	if len(msgs) != 3 {
		t.Fatalf("messages = %d, want 3", len(msgs))
	}

	asst := mustAs[map[string]any](t, msgs[1], "messages[1]")
	if asst["role"] != "assistant" {
		t.Errorf("messages[1].role = %v, want assistant", asst["role"])
	}
	asstContent := mustAs[[]any](t, asst["content"], "messages[1].content")
	if len(asstContent) != 3 {
		t.Fatalf("assistant blocks = %d, want 3 (text + 2 tool_use)", len(asstContent))
	}
	tu := mustAs[map[string]any](t, asstContent[1], "assistant block[1]")
	if tu["type"] != "tool_use" || tu["id"] != "toolu_1" || tu["name"] != "get_weather" {
		t.Errorf("tool_use block = %v", tu)
	}
	if loc := mustAs[map[string]any](t, tu["input"], "tool_use input")["location"]; loc != "Paris" {
		t.Errorf("tool_use input.location = %v, want Paris", loc)
	}
	emptyBlock := mustAs[map[string]any](t, asstContent[2], "assistant block[2]")
	if input := mustAs[map[string]any](t, emptyBlock["input"], "empty-args tool_use input"); len(input) != 0 {
		t.Errorf("empty-args tool_use input = %v, want {}", input)
	}

	// RoleTool → ONE user message with one tool_result block per result.
	toolMsg := mustAs[map[string]any](t, msgs[2], "messages[2]")
	if toolMsg["role"] != "user" {
		t.Errorf("messages[2].role = %v, want user", toolMsg["role"])
	}
	results := mustAs[[]any](t, toolMsg["content"], "messages[2].content")
	if len(results) != 2 {
		t.Fatalf("tool_result blocks = %d, want 2", len(results))
	}
	first := mustAs[map[string]any](t, results[0], "tool_result[0]")
	if first["type"] != "tool_result" || first["tool_use_id"] != "toolu_1" || first["content"] != "72F and sunny" {
		t.Errorf("first tool_result = %v", first)
	}
	if _, ok := first["is_error"]; ok {
		t.Error("first tool_result has is_error, want omitted when false")
	}
	second := mustAs[map[string]any](t, results[1], "tool_result[1]")
	if second["tool_use_id"] != "toolu_2" || second["is_error"] != true {
		t.Errorf("second tool_result = %v, want is_error true", second)
	}
}

// TestToolDefinitions checks tool name/description/input_schema encoding,
// including the default object schema for nil Parameters.
func TestToolDefinitions(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	m := mustModel(t, p, "claude-test")
	schema := json.RawMessage(`{"type":"object","properties":{"q":{"type":"string"}},"required":["q"]}`)
	generate(t, m, llm.Request{
		Messages: []llm.Message{llm.UserText("hi")},
		Tools: []llm.Tool{
			{Name: "search", Description: "Search the web.", Parameters: schema},
			{Name: "ping"}, // nil Parameters → default empty object schema
		},
	})

	tools := mustAs[[]any](t, c.bodyMap(t)["tools"], "tools")
	if len(tools) != 2 {
		t.Fatalf("tools = %d, want 2", len(tools))
	}
	search := mustAs[map[string]any](t, tools[0], "tools[0]")
	if search["name"] != "search" || search["description"] != "Search the web." {
		t.Errorf("tool[0] = %v", search)
	}
	if typ := mustAs[map[string]any](t, search["input_schema"], "tools[0].input_schema")["type"]; typ != "object" {
		t.Errorf("input_schema.type = %v, want object", typ)
	}
	ping := mustAs[map[string]any](t, tools[1], "tools[1]")
	if typ := mustAs[map[string]any](t, ping["input_schema"], "tools[1].input_schema")["type"]; typ != "object" {
		t.Errorf("nil-Parameters input_schema.type = %v, want object", typ)
	}
}

// TestToolChoiceForms checks the mapping of Request.ToolChoice onto tool_choice.
func TestToolChoiceForms(t *testing.T) {
	cases := []struct {
		choice   string
		wantType string // "" means the field must be absent
		wantName string
	}{
		{choice: "", wantType: ""},
		{choice: "auto", wantType: "auto"},
		{choice: "required", wantType: "any"},
		{choice: "none", wantType: "none"},
		{choice: "get_weather", wantType: "tool", wantName: "get_weather"},
	}
	for _, tc := range cases {
		t.Run("choice="+tc.choice, func(t *testing.T) {
			var c capture
			p := newTestProvider(t, c.handler(http.StatusOK, okBody))
			generate(t, mustModel(t, p, "m"), llm.Request{
				Messages:   []llm.Message{llm.UserText("hi")},
				ToolChoice: tc.choice,
			})

			body := c.bodyMap(t)
			raw, present := body["tool_choice"]
			if tc.wantType == "" {
				if present {
					t.Fatalf("tool_choice present (%v), want omitted", raw)
				}
				return
			}
			choice := mustAs[map[string]any](t, raw, "tool_choice")
			if choice["type"] != tc.wantType {
				t.Errorf("tool_choice.type = %v, want %q", choice["type"], tc.wantType)
			}
			if tc.wantName != "" && choice["name"] != tc.wantName {
				t.Errorf("tool_choice.name = %v, want %q", choice["name"], tc.wantName)
			}
		})
	}
}

// TestOutputConfigFormat checks that WithSchema produces an
// output_config.format of type json_schema carrying the given schema.
func TestOutputConfigFormat(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	m := mustModel(t, p, "claude-test")
	schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"],"additionalProperties":false}`)
	generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}}, llm.WithSchema(schema, "person"))

	body := c.bodyMap(t)
	outputConfig := mustAs[map[string]any](t, body["output_config"], "output_config")
	format := mustAs[map[string]any](t, outputConfig["format"], "output_config.format")
	if format["type"] != "json_schema" {
		t.Errorf("output_config.format.type = %v, want json_schema", format["type"])
	}

	// Normalize both sides through any → Marshal (sorted keys) to compare.
	got, err := json.Marshal(format["schema"])
	if err != nil {
		t.Fatalf("re-encode captured schema: %v", err)
	}
	var want any
	if err := json.Unmarshal(schema, &want); err != nil {
		t.Fatalf("decode expected schema: %v", err)
	}
	wantJSON, err := json.Marshal(want)
	if err != nil {
		t.Fatalf("re-encode expected schema: %v", err)
	}
	if string(got) != string(wantJSON) {
		t.Errorf("schema = %s, want %s", got, wantJSON)
	}
}

// TestOutputConfigOmittedWithoutSchema checks output_config is absent by default.
func TestOutputConfigOmittedWithoutSchema(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	if _, ok := c.bodyMap(t)["output_config"]; ok {
		t.Error("output_config present, want omitted when Schema is nil")
	}
}

// TestSamplingKnobs checks temperature/top_p/stop_sequences are omitted when
// unset and sent (including an explicit zero temperature) when set.
func TestSamplingKnobs(t *testing.T) {
	t.Run("omitted when unset", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusOK, okBody))
		generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

		body := c.bodyMap(t)
		if _, ok := body["temperature"]; ok {
			t.Error("temperature present, want omitted when unset")
		}
		if _, ok := body["top_p"]; ok {
			t.Error("top_p present, want omitted when unset")
		}
		if _, ok := body["stop_sequences"]; ok {
			t.Error("stop_sequences present, want omitted when unset")
		}
	})

	t.Run("present when set", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusOK, okBody))
		generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}},
			llm.WithTemperature(0), // explicit zero must still be sent
			llm.WithTopP(0.9),
			llm.WithStopSequences("END"))

		body := c.bodyMap(t)
		rawTemp, present := body["temperature"]
		// A missing key or a null value both fail the float64 assertion.
		if temp, ok := rawTemp.(float64); !ok || temp != 0 {
			t.Errorf("temperature = %v (present=%v), want explicit 0", rawTemp, present)
		}
		if got := mustAs[float64](t, body["top_p"], "top_p"); got != 0.9 {
			t.Errorf("top_p = %v, want 0.9", got)
		}
		stops := mustAs[[]any](t, body["stop_sequences"], "stop_sequences")
		if len(stops) != 1 || stops[0] != "END" {
			t.Errorf("stop_sequences = %v, want [END]", stops)
		}
	})
}

// TestStreamFieldOmittedOnGenerate checks Generate does not send a stream key.
func TestStreamFieldOmittedOnGenerate(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody))
	generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	if _, ok := c.bodyMap(t)["stream"]; ok {
		t.Error("stream key present on Generate, want omitted")
	}
}

// TestResponseParse checks decoding of text, thinking and tool_use blocks,
// stop reason, usage accounting, model naming and Raw retention.
func TestResponseParse(t *testing.T) {
	const body = `{
		"id": "msg_02",
		"type": "message",
		"role": "assistant",
		"model": "claude-test",
		"content": [
			{"type": "thinking", "thinking": "pondering...", "signature": "sig"},
			{"type": "text", "text": "I'll check the weather."},
			{"type": "tool_use", "id": "toolu_9", "name": "get_weather", "input": {"location": "Paris"}}
		],
		"stop_reason": "tool_use",
		"usage": {
			"input_tokens": 3,
			"output_tokens": 7,
			"cache_creation_input_tokens": 10,
			"cache_read_input_tokens": 20
		}
	}`
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, body))
	resp := generate(t, mustModel(t, p, "claude-test"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})

	if len(resp.Parts) != 1 {
		t.Fatalf("parts = %d, want 1 (thinking blocks must be skipped)", len(resp.Parts))
	}
	if got := resp.Text(); got != "I'll check the weather." {
		t.Errorf("text = %q", got)
	}
	if len(resp.ToolCalls) != 1 {
		t.Fatalf("tool calls = %d, want 1", len(resp.ToolCalls))
	}
	call := resp.ToolCalls[0]
	if call.ID != "toolu_9" || call.Name != "get_weather" {
		t.Errorf("tool call = %+v", call)
	}
	var args map[string]any
	if err := json.Unmarshal(call.Arguments, &args); err != nil || args["location"] != "Paris" {
		t.Errorf("arguments = %s (err %v), want location Paris", call.Arguments, err)
	}
	if resp.FinishReason != llm.FinishToolCalls {
		t.Errorf("finish = %q, want %q", resp.FinishReason, llm.FinishToolCalls)
	}
	// Total real input = input + cache_creation + cache_read.
	if resp.Usage.InputTokens != 33 || resp.Usage.OutputTokens != 7 {
		t.Errorf("usage = %+v, want {33 7}", resp.Usage)
	}
	if resp.Model != "anthropic/claude-test" {
		t.Errorf("model = %q, want anthropic/claude-test", resp.Model)
	}
	if resp.Raw == nil {
		t.Error("Raw = nil, want wire response")
	}
}

// TestStopReasonMapping checks mapStopReason for known and unknown reasons.
func TestStopReasonMapping(t *testing.T) {
	cases := map[string]llm.FinishReason{
		"end_turn":                      llm.FinishStop,
		"stop_sequence":                 llm.FinishStop,
		"max_tokens":                    llm.FinishLength,
		"model_context_window_exceeded": llm.FinishLength,
		"tool_use":                      llm.FinishToolCalls,
		"refusal":                       llm.FinishContentFilter,
		"pause_turn":                    llm.FinishOther,
		"some_future_reason":            llm.FinishOther,
	}
	for stop, want := range cases {
		if got := mapStopReason(stop); got != want {
			t.Errorf("mapStopReason(%q) = %q, want %q", stop, got, want)
		}
	}
}

// TestHTTPErrorMapping checks that non-2xx responses become *llm.APIError
// with the expected status, code, message and classification.
func TestHTTPErrorMapping(t *testing.T) {
	cases := []struct {
		name      string
		status    int
		body      string
		wantCode  string
		wantClass llm.ErrorClass
	}{
		{
			name:      "429 rate limit is transient",
			status:    http.StatusTooManyRequests,
			body:      `{"type":"error","error":{"type":"rate_limit_error","message":"slow down"}}`,
			wantCode:  "rate_limit_error",
			wantClass: llm.ClassTransient,
		},
		{
			name:      "529 overloaded is transient",
			status:    529,
			body:      `{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`,
			wantCode:  "overloaded_error",
			wantClass: llm.ClassTransient,
		},
		{
			name:      "401 auth is permanent",
			status:    http.StatusUnauthorized,
			body:      `{"type":"error","error":{"type":"authentication_error","message":"invalid x-api-key"}}`,
			wantCode:  "authentication_error",
			wantClass: llm.ClassPermanent,
		},
		{
			name:      "404 is permanent",
			status:    http.StatusNotFound,
			body:      `{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`,
			wantCode:  "not_found_error",
			wantClass: llm.ClassPermanent,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var c capture
			p := newTestProvider(t, c.handler(tc.status, tc.body))
			_, err := mustModel(t, p, "claude-test").Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
			if err == nil {
				t.Fatal("Generate succeeded, want error")
			}
			apiErr, ok := errors.AsType[*llm.APIError](err)
			if !ok {
				t.Fatalf("error %T (%v), want *llm.APIError", err, err)
			}
			if apiErr.Provider != "anthropic" || apiErr.Model != "claude-test" {
				t.Errorf("provider/model = %s/%s", apiErr.Provider, apiErr.Model)
			}
			if apiErr.Status != tc.status {
				t.Errorf("status = %d, want %d", apiErr.Status, tc.status)
			}
			if apiErr.Code != tc.wantCode {
				t.Errorf("code = %q, want %q", apiErr.Code, tc.wantCode)
			}
			if apiErr.Message == "" {
				t.Error("message empty, want provider message")
			}
			if got := llm.Classify(err); got != tc.wantClass {
				t.Errorf("Classify = %v, want %v", got, tc.wantClass)
			}
		})
	}

	t.Run("404 unwraps to ErrModelNotFound", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusNotFound,
			`{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`))
		_, err := mustModel(t, p, "missing").Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
		if !errors.Is(err, llm.ErrModelNotFound) {
			t.Errorf("errors.Is(err, ErrModelNotFound) = false for %v", err)
		}
	})

	t.Run("non-JSON error body falls back to raw text", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusBadGateway, "upstream exploded"))
		_, err := mustModel(t, p, "m").Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
		apiErr, ok := errors.AsType[*llm.APIError](err)
		if !ok {
			t.Fatalf("error %T, want *llm.APIError", err)
		}
		if apiErr.Status != http.StatusBadGateway || apiErr.Message != "upstream exploded" {
			t.Errorf("apiErr = %+v", apiErr)
		}
	})
}

// TestMissingAPIKey checks that a provider without a key reports a permanent
// 401 authentication_error without contacting the server.
func TestMissingAPIKey(t *testing.T) {
	t.Setenv("ANTHROPIC_API_KEY", "") // isolate from any real environment
	var c capture
	srv := httptest.NewServer(c.handler(http.StatusOK, okBody))
	t.Cleanup(srv.Close)

	p := New(WithBaseURL(srv.URL)) // construction must not fail
	_, err := mustModel(t, p, "claude-test").Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
	apiErr, ok := errors.AsType[*llm.APIError](err)
	if !ok {
		t.Fatalf("error %T (%v), want *llm.APIError", err, err)
	}
	if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" {
		t.Errorf("apiErr = %+v, want 401 authentication_error", apiErr)
	}
	if llm.Classify(err) != llm.ClassPermanent {
		t.Error("missing key must classify permanent")
	}
	if hits := c.hitCount(); hits != 0 {
		t.Errorf("server hits = %d, want 0 (no request without a key)", hits)
	}
}

// TestAPIKeyFromEnv checks that ANTHROPIC_API_KEY is used when no key is set.
func TestAPIKeyFromEnv(t *testing.T) {
	t.Setenv("ANTHROPIC_API_KEY", "env-key")
	var c capture
	srv := httptest.NewServer(c.handler(http.StatusOK, okBody))
	t.Cleanup(srv.Close)

	p := New(WithBaseURL(srv.URL))
	generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
	_, _, header := c.request()
	if got := header.Get("x-api-key"); got != "env-key" {
		t.Errorf("x-api-key = %q, want env-key", got)
	}
}

// TestCapabilityEnforcement checks that requests violating the model's image
// capabilities are rejected with ErrUnsupported before any HTTP request.
func TestCapabilityEnforcement(t *testing.T) {
	img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) }

	cases := []struct {
		name string
		caps *llm.Capabilities // nil = provider defaults
		req  llm.Request
	}{
		{
			name: "images unsupported",
			caps: &llm.Capabilities{}, // MaxImagesPerReq 0 = no images
			req:  llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 4))}},
		},
		{
			name: "too many images",
			caps: &llm.Capabilities{MaxImagesPerReq: 1},
			req: llm.Request{Messages: []llm.Message{
				llm.UserParts(img("image/png", 4), img("image/png", 4)),
			}},
		},
		{
			name: "disallowed MIME",
			req:  llm.Request{Messages: []llm.Message{llm.UserParts(img("image/bmp", 4))}},
		},
		{
			name: "image too large",
			caps: &llm.Capabilities{MaxImagesPerReq: 1, MaxImageBytes: 2},
			req:  llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 3))}},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var c capture
			p := newTestProvider(t, c.handler(http.StatusOK, okBody))
			var opts []llm.ModelOption
			if tc.caps != nil {
				opts = append(opts, llm.WithCapabilities(*tc.caps))
			}
			m := mustModel(t, p, "claude-test", opts...)

			_, err := m.Generate(context.Background(), tc.req)
			if !errors.Is(err, llm.ErrUnsupported) {
				t.Errorf("Generate err = %v, want ErrUnsupported", err)
			}
			_, err = m.Stream(context.Background(), tc.req)
			if !errors.Is(err, llm.ErrUnsupported) {
				t.Errorf("Stream err = %v, want ErrUnsupported", err)
			}
			if hits := c.hitCount(); hits != 0 {
				t.Errorf("server hits = %d, want 0 (rejected before sending)", hits)
			}
		})
	}

	t.Run("within limits passes", func(t *testing.T) {
		var c capture
		p := newTestProvider(t, c.handler(http.StatusOK, okBody))
		generate(t, mustModel(t, p, "m"), llm.Request{
			Messages: []llm.Message{llm.UserParts(llm.Text("ok"), img("image/jpeg", 16))},
		})
		if hits := c.hitCount(); hits != 1 {
			t.Errorf("server hits = %d, want 1", hits)
		}
	})
}

// TestCompatEndpointWithNameAndBaseURL checks that WithName changes the
// provider name reported in responses and API errors.
func TestCompatEndpointWithNameAndBaseURL(t *testing.T) {
	var c capture
	p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithName("compat"))
	if p.Name() != "compat" {
		t.Errorf("Name() = %q, want compat", p.Name())
	}
	resp := generate(t, mustModel(t, p, "claude-test"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
	if resp.Model != "compat/claude-test" {
		t.Errorf("resp.Model = %q, want compat/claude-test", resp.Model)
	}

	var ec capture
	pe := newTestProvider(t, ec.handler(http.StatusTooManyRequests,
		`{"type":"error","error":{"type":"rate_limit_error","message":"x"}}`), WithName("compat"))
	_, err := mustModel(t, pe, "m").Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
	apiErr, ok := errors.AsType[*llm.APIError](err)
	if !ok || apiErr.Provider != "compat" {
		t.Errorf("error provider = %v, want compat (err %v)", apiErr, err)
	}
}

// TestCapabilitiesDefaultsAndOverrides checks the provider default
// capabilities and their provider-level and per-model overrides.
func TestCapabilitiesDefaultsAndOverrides(t *testing.T) {
	p := New(WithAPIKey("k"))
	m := mustModel(t, p, "m")
	caps := m.Capabilities()
	if !caps.SupportsTools || !caps.SupportsStructured || !caps.SupportsStreaming {
		t.Errorf("default feature flags = %+v, want all true", caps)
	}
	if caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 10<<20 || caps.MaxImageDimension != 8000 {
		t.Errorf("default image limits = %+v", caps)
	}
	wantMIME := []string{"image/jpeg", "image/png", "image/gif", "image/webp"}
	if !slices.Equal(caps.AllowedImageMIME, wantMIME) {
		t.Errorf("AllowedImageMIME = %v, want %v", caps.AllowedImageMIME, wantMIME)
	}

	custom := llm.Capabilities{SupportsStreaming: true, MaxImagesPerReq: 1}
	p2 := New(WithAPIKey("k"), WithDefaultCapabilities(custom))
	if got := mustModel(t, p2, "m").Capabilities(); got.MaxImagesPerReq != 1 || got.SupportsTools {
		t.Errorf("WithDefaultCapabilities not applied: %+v", got)
	}

	perModel := llm.Capabilities{SupportsTools: true}
	if got := mustModel(t, p2, "m", llm.WithCapabilities(perModel)).Capabilities(); !got.SupportsTools || got.MaxImagesPerReq != 0 {
		t.Errorf("per-model capabilities not applied: %+v", got)
	}
}

// TestTransportErrorNotAPIError checks that connection failures are returned
// as transient transport errors rather than *llm.APIError.
func TestTransportErrorNotAPIError(t *testing.T) {
	// Point at a server that is immediately closed: the connection failure
	// must surface as a wrapped transport error, not *llm.APIError.
	srv := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
	url := srv.URL
	srv.Close()

	p := New(WithAPIKey("k"), WithBaseURL(url))
	_, err := mustModel(t, p, "m").Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
	if err == nil {
		t.Fatal("Generate succeeded, want transport error")
	}
	if _, ok := errors.AsType[*llm.APIError](err); ok {
		t.Errorf("transport error wrapped in APIError: %v", err)
	}
	if llm.Classify(err) != llm.ClassTransient {
		t.Errorf("connection failure must classify transient: %v", err)
	}
}