feat: OpenAI, Anthropic, and native-Ollama providers + media pipeline
Phase 3: - provider/openai: Chat Completions for OpenAI + compat endpoints (SSE streaming with by-index tool-call assembly, response_format json_schema, legacy max_tokens option, reasoning_effort) - provider/anthropic: Messages API (tool_use/tool_result, GA structured output via output_config.format, full SSE event parser, 529 transient) - provider/ollama: one native /api/chat client behind the ollama, ollama-cloud, and foreman built-ins (presets; NDJSON streaming tolerant of foreman's buffered single-object responses; object tool arguments; format-schema structured output; think mapping) - media/: capability normalization (sniff, downscale, transcode, byte ladder, ErrUnsupported), wired into the chain executor per target with penalty-free advance past incapable elements - registry: real provider + scheme wiring, WithHTTPClient option, required env-foreman TLS chat round-trip test - ADR-0009 multimodal strategy, ADR-0010 tools/structured mapping; README matrix + CLAUDE.md synced Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,319 @@
|
||||
// Package anthropic implements llm.Provider for the Anthropic Messages API
|
||||
// and Anthropic-compatible endpoints.
|
||||
//
|
||||
// API surface targeted: POST {base}/v1/messages with headers x-api-key,
|
||||
// anthropic-version: 2023-06-01, and content-type: application/json, per the
|
||||
// platform.claude.com Messages API reference as of June 2026. Streaming uses
|
||||
// the documented SSE event sequence (message_start, content_block_start,
|
||||
// content_block_delta, content_block_stop, message_delta, message_stop).
|
||||
// Structured output uses the GA output_config.format mechanism with
|
||||
// {"type":"json_schema"}; the result arrives as JSON text in the first text
|
||||
// content block.
|
||||
//
|
||||
// Why a hand-rolled client (no SDK): ADR-0007 — majordomo is stdlib-first,
|
||||
// and the canonical llm contract needs only a narrow slice of the API.
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultName = "anthropic"
|
||||
defaultBaseURL = "https://api.anthropic.com"
|
||||
|
||||
// apiVersion is the anthropic-version header value. 2023-06-01 remains
|
||||
// the current (and only) stable version string as of June 2026.
|
||||
apiVersion = "2023-06-01"
|
||||
|
||||
// defaultMaxTokens is used when Request.MaxTokens is 0, because the
|
||||
// Messages API requires max_tokens on every request.
|
||||
defaultMaxTokens = 4096
|
||||
)
|
||||
|
||||
// defaultCapabilities reflects the documented first-party API image limits:
|
||||
// 100 images per request (200K-context models), 10 MB per image, 8000 px per
|
||||
// side, and the four supported media types.
|
||||
func defaultCapabilities() llm.Capabilities {
|
||||
return llm.Capabilities{
|
||||
SupportsTools: true,
|
||||
SupportsStructured: true,
|
||||
SupportsStreaming: true,
|
||||
MaxImagesPerReq: 100,
|
||||
MaxImageBytes: 10 << 20,
|
||||
MaxImageDimension: 8000,
|
||||
AllowedImageMIME: []string{
|
||||
"image/jpeg", "image/png", "image/gif", "image/webp",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Provider is an llm.Provider backed by the Anthropic Messages API.
|
||||
type Provider struct {
|
||||
name string
|
||||
apiKey string
|
||||
baseURL string
|
||||
client *http.Client
|
||||
caps llm.Capabilities
|
||||
maxTokens int
|
||||
}
|
||||
|
||||
// Option configures the provider at construction.
|
||||
type Option func(*Provider)
|
||||
|
||||
// WithAPIKey sets the API key explicitly, bypassing the ANTHROPIC_API_KEY
|
||||
// environment default.
|
||||
func WithAPIKey(key string) Option {
|
||||
return func(p *Provider) { p.apiKey = key }
|
||||
}
|
||||
|
||||
// WithBaseURL points the provider at an Anthropic-compatible endpoint. A
|
||||
// trailing slash is trimmed; "/v1/messages" is appended per request.
|
||||
func WithBaseURL(u string) Option {
|
||||
return func(p *Provider) { p.baseURL = strings.TrimRight(u, "/") }
|
||||
}
|
||||
|
||||
// WithHTTPClient replaces the HTTP client (timeouts, proxies, test doubles).
|
||||
func WithHTTPClient(c *http.Client) Option {
|
||||
return func(p *Provider) { p.client = c }
|
||||
}
|
||||
|
||||
// WithName overrides the registry name. Why: an Anthropic-compatible
|
||||
// endpoint registered under its own name must surface that name in
|
||||
// Response.Model and errors, not "anthropic".
|
||||
func WithName(name string) Option {
|
||||
return func(p *Provider) { p.name = name }
|
||||
}
|
||||
|
||||
// WithDefaultCapabilities replaces the provider-default capabilities.
|
||||
func WithDefaultCapabilities(caps llm.Capabilities) Option {
|
||||
return func(p *Provider) { p.caps = caps }
|
||||
}
|
||||
|
||||
// WithDefaultMaxTokens overrides the max_tokens value used when
|
||||
// Request.MaxTokens is 0. Why: the Messages API rejects requests without
|
||||
// max_tokens, so the provider must always send something.
|
||||
func WithDefaultMaxTokens(n int) Option {
|
||||
return func(p *Provider) { p.maxTokens = n }
|
||||
}
|
||||
|
||||
// New creates an Anthropic provider. It never fails: a missing API key
|
||||
// (no WithAPIKey and no ANTHROPIC_API_KEY in the environment) surfaces as a
|
||||
// 401-style *llm.APIError at request time, not at construction.
|
||||
func New(opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
name: defaultName,
|
||||
baseURL: defaultBaseURL,
|
||||
client: http.DefaultClient,
|
||||
caps: defaultCapabilities(),
|
||||
maxTokens: defaultMaxTokens,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
if p.apiKey == "" {
|
||||
p.apiKey = os.Getenv("ANTHROPIC_API_KEY")
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Name implements llm.Provider.
|
||||
func (p *Provider) Name() string { return p.name }
|
||||
|
||||
// Model implements llm.Provider. The id is passed through verbatim — it is
|
||||
// never validated against a catalog.
|
||||
func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) {
|
||||
cfg := llm.ApplyModelOptions(opts)
|
||||
caps := p.caps
|
||||
if cfg.Capabilities != nil {
|
||||
caps = *cfg.Capabilities
|
||||
}
|
||||
return &model{provider: p, id: id, caps: caps}, nil
|
||||
}
|
||||
|
||||
type model struct {
|
||||
provider *Provider
|
||||
id string
|
||||
caps llm.Capabilities
|
||||
}
|
||||
|
||||
// Capabilities implements llm.Model.
|
||||
func (m *model) Capabilities() llm.Capabilities { return m.caps }
|
||||
|
||||
// fullName is the "provider/model" identifier used in Response.Model.
|
||||
func (m *model) fullName() string { return m.provider.name + "/" + m.id }
|
||||
|
||||
// Generate implements llm.Model.
|
||||
func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) {
|
||||
req = req.Apply(opts...)
|
||||
if err := m.enforceCapabilities(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpResp, err := m.do(ctx, req, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
if httpResp.StatusCode/100 != 2 {
|
||||
return nil, m.apiError(httpResp)
|
||||
}
|
||||
var wr wireResponse
|
||||
if err := json.NewDecoder(httpResp.Body).Decode(&wr); err != nil {
|
||||
return nil, fmt.Errorf("%s: decode response: %w", m.provider.name, err)
|
||||
}
|
||||
return m.toResponse(&wr), nil
|
||||
}
|
||||
|
||||
// Stream implements llm.Model. A non-2xx status is returned as an error from
|
||||
// Stream itself, before any events are delivered.
|
||||
func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) {
|
||||
req = req.Apply(opts...)
|
||||
if err := m.enforceCapabilities(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpResp, err := m.do(ctx, req, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if httpResp.StatusCode/100 != 2 {
|
||||
defer httpResp.Body.Close()
|
||||
return nil, m.apiError(httpResp)
|
||||
}
|
||||
return newStream(m, httpResp.Body), nil
|
||||
}
|
||||
|
||||
// enforceCapabilities is the honest backstop behind the media layer: it
|
||||
// rejects (rather than silently mutates) requests the target cannot serve.
|
||||
// Why: a separate media layer resizes/transcodes images BEFORE requests
|
||||
// reach the provider, so anything still out of bounds here is a real error.
|
||||
func (m *model) enforceCapabilities(req llm.Request) error {
|
||||
images := 0
|
||||
for _, msg := range req.Messages {
|
||||
for _, part := range msg.Parts {
|
||||
img, ok := part.(llm.ImagePart)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
images++
|
||||
if !m.caps.SupportsImages() {
|
||||
return fmt.Errorf("%w: %s does not accept image input", llm.ErrUnsupported, m.fullName())
|
||||
}
|
||||
if !m.caps.MIMEAllowed(img.MIME) {
|
||||
return fmt.Errorf("%w: %s does not accept image MIME %q", llm.ErrUnsupported, m.fullName(), img.MIME)
|
||||
}
|
||||
if m.caps.MaxImageBytes > 0 && len(img.Data) > m.caps.MaxImageBytes {
|
||||
return fmt.Errorf("%w: image of %d bytes exceeds %s limit of %d bytes",
|
||||
llm.ErrUnsupported, len(img.Data), m.fullName(), m.caps.MaxImageBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
if m.caps.MaxImagesPerReq > 0 && images > m.caps.MaxImagesPerReq {
|
||||
return fmt.Errorf("%w: request carries %d images, %s allows at most %d",
|
||||
llm.ErrUnsupported, images, m.fullName(), m.caps.MaxImagesPerReq)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// do builds and executes one Messages API call. Transport errors are wrapped
|
||||
// with context but NOT converted to *llm.APIError, so llm.Classify still
|
||||
// sees the underlying net.Error / syscall errno.
|
||||
func (m *model) do(ctx context.Context, req llm.Request, streaming bool) (*http.Response, error) {
|
||||
p := m.provider
|
||||
if p.apiKey == "" {
|
||||
// Why request-time, not construction-time: New never fails by
|
||||
// convention, and a 401-shaped APIError classifies permanent so
|
||||
// chains fail fast past a misconfigured target.
|
||||
return nil, &llm.APIError{
|
||||
Provider: p.name,
|
||||
Model: m.id,
|
||||
Status: http.StatusUnauthorized,
|
||||
Code: "authentication_error",
|
||||
Message: "no API key configured: set ANTHROPIC_API_KEY or use WithAPIKey",
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(buildWireRequest(m.id, req, p.maxTokens, streaming))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: encode request: %w", p.name, err)
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/messages", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: build request: %w", p.name, err)
|
||||
}
|
||||
httpReq.Header.Set("x-api-key", p.apiKey)
|
||||
httpReq.Header.Set("anthropic-version", apiVersion)
|
||||
httpReq.Header.Set("content-type", "application/json")
|
||||
if streaming {
|
||||
httpReq.Header.Set("accept", "text/event-stream")
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: do request: %w", p.name, err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// apiError converts a non-2xx response into *llm.APIError, filling Code and
|
||||
// Message from the documented {"type":"error","error":{...}} body when it
|
||||
// parses, and falling back to the raw body text when it does not.
|
||||
func (m *model) apiError(resp *http.Response) error {
|
||||
apiErr := &llm.APIError{
|
||||
Provider: m.provider.name,
|
||||
Model: m.id,
|
||||
Status: resp.StatusCode,
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return apiErr
|
||||
}
|
||||
var we wireErrorEnvelope
|
||||
if json.Unmarshal(body, &we) == nil && we.Error.Type != "" {
|
||||
apiErr.Code = we.Error.Type
|
||||
apiErr.Message = we.Error.Message
|
||||
} else {
|
||||
apiErr.Message = strings.TrimSpace(string(body))
|
||||
}
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// toResponse maps a wire response onto the canonical llm.Response. Thinking
|
||||
// and other unrecognized block types are tolerated and skipped — they are
|
||||
// not part of the canonical content vocabulary.
|
||||
func (m *model) toResponse(wr *wireResponse) *llm.Response {
|
||||
resp := &llm.Response{
|
||||
FinishReason: mapStopReason(wr.StopReason),
|
||||
Usage: wr.Usage.toUsage(),
|
||||
Model: m.fullName(),
|
||||
Raw: wr,
|
||||
}
|
||||
for _, block := range wr.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
resp.Parts = append(resp.Parts, llm.TextPart{Text: block.Text})
|
||||
case "tool_use":
|
||||
args := block.Input
|
||||
if len(args) == 0 {
|
||||
args = json.RawMessage("{}")
|
||||
}
|
||||
resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{
|
||||
ID: block.ID,
|
||||
Name: block.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
default:
|
||||
// thinking, redacted_thinking, server-tool blocks, and any
|
||||
// future types are skipped, not surfaced as parts.
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@@ -0,0 +1,774 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// okBody is a minimal successful Messages API response.
|
||||
const okBody = `{
|
||||
"id": "msg_01",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"content": [{"type": "text", "text": "ok"}],
|
||||
"stop_reason": "end_turn",
|
||||
"usage": {"input_tokens": 3, "output_tokens": 5}
|
||||
}`
|
||||
|
||||
// capture records the last request the test server received.
|
||||
type capture struct {
|
||||
mu sync.Mutex
|
||||
hits int
|
||||
method string
|
||||
path string
|
||||
header http.Header
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (c *capture) handler(status int, respBody string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
c.mu.Lock()
|
||||
c.hits++
|
||||
c.method = r.Method
|
||||
c.path = r.URL.Path
|
||||
c.header = r.Header.Clone()
|
||||
c.body = body
|
||||
c.mu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_, _ = w.Write([]byte(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
// bodyMap decodes the captured request body for key-presence assertions.
|
||||
func (c *capture) bodyMap(t *testing.T) map[string]any {
|
||||
t.Helper()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(c.body, &m); err != nil {
|
||||
t.Fatalf("decode captured body: %v\nbody: %s", err, c.body)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// newTestProvider spins up an httptest server and a provider pointed at it.
|
||||
func newTestProvider(t *testing.T, h http.Handler, opts ...Option) *Provider {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(h)
|
||||
t.Cleanup(srv.Close)
|
||||
return New(append([]Option{WithAPIKey("test-key"), WithBaseURL(srv.URL)}, opts...)...)
|
||||
}
|
||||
|
||||
func mustModel(t *testing.T, p *Provider, id string, opts ...llm.ModelOption) llm.Model {
|
||||
t.Helper()
|
||||
m, err := p.Model(id, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Model(%q): %v", id, err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func generate(t *testing.T, m llm.Model, req llm.Request, opts ...llm.Option) *llm.Response {
|
||||
t.Helper()
|
||||
resp, err := m.Generate(context.Background(), req, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func TestRequestHeadersAndPath(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
|
||||
if c.method != http.MethodPost {
|
||||
t.Errorf("method = %q, want POST", c.method)
|
||||
}
|
||||
if c.path != "/v1/messages" {
|
||||
t.Errorf("path = %q, want /v1/messages", c.path)
|
||||
}
|
||||
for header, want := range map[string]string{
|
||||
"x-api-key": "test-key",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json",
|
||||
} {
|
||||
if got := c.header.Get(header); got != want {
|
||||
t.Errorf("header %s = %q, want %q", header, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemFold(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
generate(t, m, llm.Request{
|
||||
System: "base prompt",
|
||||
Messages: []llm.Message{
|
||||
llm.SystemText("first extra"),
|
||||
llm.UserText("hi"),
|
||||
llm.SystemText("second extra"),
|
||||
},
|
||||
})
|
||||
|
||||
body := c.bodyMap(t)
|
||||
if got, want := body["system"], "base prompt\n\nfirst extra\n\nsecond extra"; got != want {
|
||||
t.Errorf("system = %q, want %q", got, want)
|
||||
}
|
||||
msgs := body["messages"].([]any)
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("messages length = %d, want 1 (system messages must be excluded)", len(msgs))
|
||||
}
|
||||
if role := msgs[0].(map[string]any)["role"]; role != "user" {
|
||||
t.Errorf("remaining message role = %q, want user", role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoSystemOmitsField(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
|
||||
if _, ok := c.bodyMap(t)["system"]; ok {
|
||||
t.Error("system key present, want omitted when empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxTokens(t *testing.T) {
|
||||
t.Run("default 4096", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if got := c.bodyMap(t)["max_tokens"].(float64); got != 4096 {
|
||||
t.Errorf("max_tokens = %v, want 4096", got)
|
||||
}
|
||||
})
|
||||
t.Run("explicit wins", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
MaxTokens: 123,
|
||||
})
|
||||
if got := c.bodyMap(t)["max_tokens"].(float64); got != 123 {
|
||||
t.Errorf("max_tokens = %v, want 123", got)
|
||||
}
|
||||
})
|
||||
t.Run("WithDefaultMaxTokens overrides default", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithDefaultMaxTokens(99))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if got := c.bodyMap(t)["max_tokens"].(float64); got != 99 {
|
||||
t.Errorf("max_tokens = %v, want 99", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestImageBlock(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
raw := []byte{0x01, 0x02, 0x03}
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{
|
||||
llm.UserParts(llm.Text("look at this"), llm.Image("image/png", raw)),
|
||||
}})
|
||||
|
||||
msgs := c.bodyMap(t)["messages"].([]any)
|
||||
content := msgs[0].(map[string]any)["content"].([]any)
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("content blocks = %d, want 2", len(content))
|
||||
}
|
||||
img := content[1].(map[string]any)
|
||||
if img["type"] != "image" {
|
||||
t.Fatalf("block type = %v, want image", img["type"])
|
||||
}
|
||||
src := img["source"].(map[string]any)
|
||||
if src["type"] != "base64" {
|
||||
t.Errorf("source type = %v, want base64", src["type"])
|
||||
}
|
||||
if src["media_type"] != "image/png" {
|
||||
t.Errorf("media_type = %v, want image/png", src["media_type"])
|
||||
}
|
||||
if want := base64.StdEncoding.EncodeToString(raw); src["data"] != want {
|
||||
t.Errorf("data = %v, want %q", src["data"], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolUseToolResultRoundTrip(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{
|
||||
llm.UserText("weather?"),
|
||||
{
|
||||
Role: llm.RoleAssistant,
|
||||
Parts: []llm.Part{llm.Text("checking")},
|
||||
ToolCalls: []llm.ToolCall{
|
||||
{ID: "toolu_1", Name: "get_weather", Arguments: json.RawMessage(`{"location":"Paris"}`)},
|
||||
{ID: "toolu_2", Name: "noop"}, // empty args must become {}
|
||||
},
|
||||
},
|
||||
llm.ToolResultsMessage(
|
||||
llm.ToolResult{ID: "toolu_1", Name: "get_weather", Content: "72F and sunny"},
|
||||
llm.ToolResult{ID: "toolu_2", Name: "noop", Content: "boom", IsError: true},
|
||||
),
|
||||
}})
|
||||
|
||||
msgs := c.bodyMap(t)["messages"].([]any)
|
||||
if len(msgs) != 3 {
|
||||
t.Fatalf("messages = %d, want 3", len(msgs))
|
||||
}
|
||||
|
||||
asst := msgs[1].(map[string]any)
|
||||
if asst["role"] != "assistant" {
|
||||
t.Errorf("messages[1].role = %v, want assistant", asst["role"])
|
||||
}
|
||||
asstContent := asst["content"].([]any)
|
||||
if len(asstContent) != 3 {
|
||||
t.Fatalf("assistant blocks = %d, want 3 (text + 2 tool_use)", len(asstContent))
|
||||
}
|
||||
tu := asstContent[1].(map[string]any)
|
||||
if tu["type"] != "tool_use" || tu["id"] != "toolu_1" || tu["name"] != "get_weather" {
|
||||
t.Errorf("tool_use block = %v", tu)
|
||||
}
|
||||
if loc := tu["input"].(map[string]any)["location"]; loc != "Paris" {
|
||||
t.Errorf("tool_use input.location = %v, want Paris", loc)
|
||||
}
|
||||
if input := asstContent[2].(map[string]any)["input"].(map[string]any); len(input) != 0 {
|
||||
t.Errorf("empty-args tool_use input = %v, want {}", input)
|
||||
}
|
||||
|
||||
// RoleTool → ONE user message with one tool_result block per result.
|
||||
toolMsg := msgs[2].(map[string]any)
|
||||
if toolMsg["role"] != "user" {
|
||||
t.Errorf("messages[2].role = %v, want user", toolMsg["role"])
|
||||
}
|
||||
results := toolMsg["content"].([]any)
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("tool_result blocks = %d, want 2", len(results))
|
||||
}
|
||||
first := results[0].(map[string]any)
|
||||
if first["type"] != "tool_result" || first["tool_use_id"] != "toolu_1" || first["content"] != "72F and sunny" {
|
||||
t.Errorf("first tool_result = %v", first)
|
||||
}
|
||||
if _, ok := first["is_error"]; ok {
|
||||
t.Error("first tool_result has is_error, want omitted when false")
|
||||
}
|
||||
second := results[1].(map[string]any)
|
||||
if second["tool_use_id"] != "toolu_2" || second["is_error"] != true {
|
||||
t.Errorf("second tool_result = %v, want is_error true", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolDefinitions(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"q":{"type":"string"}},"required":["q"]}`)
|
||||
generate(t, m, llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
Tools: []llm.Tool{
|
||||
{Name: "search", Description: "Search the web.", Parameters: schema},
|
||||
{Name: "ping"}, // nil Parameters → default empty object schema
|
||||
},
|
||||
})
|
||||
|
||||
tools := c.bodyMap(t)["tools"].([]any)
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("tools = %d, want 2", len(tools))
|
||||
}
|
||||
search := tools[0].(map[string]any)
|
||||
if search["name"] != "search" || search["description"] != "Search the web." {
|
||||
t.Errorf("tool[0] = %v", search)
|
||||
}
|
||||
if typ := search["input_schema"].(map[string]any)["type"]; typ != "object" {
|
||||
t.Errorf("input_schema.type = %v, want object", typ)
|
||||
}
|
||||
ping := tools[1].(map[string]any)
|
||||
if typ := ping["input_schema"].(map[string]any)["type"]; typ != "object" {
|
||||
t.Errorf("nil-Parameters input_schema.type = %v, want object", typ)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolChoiceForms(t *testing.T) {
|
||||
cases := []struct {
|
||||
choice string
|
||||
wantType string // "" means the field must be absent
|
||||
wantName string
|
||||
}{
|
||||
{choice: "", wantType: ""},
|
||||
{choice: "auto", wantType: "auto"},
|
||||
{choice: "required", wantType: "any"},
|
||||
{choice: "none", wantType: "none"},
|
||||
{choice: "get_weather", wantType: "tool", wantName: "get_weather"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run("choice="+tc.choice, func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
ToolChoice: tc.choice,
|
||||
})
|
||||
body := c.bodyMap(t)
|
||||
raw, present := body["tool_choice"]
|
||||
if tc.wantType == "" {
|
||||
if present {
|
||||
t.Fatalf("tool_choice present (%v), want omitted", raw)
|
||||
}
|
||||
return
|
||||
}
|
||||
choice := raw.(map[string]any)
|
||||
if choice["type"] != tc.wantType {
|
||||
t.Errorf("tool_choice.type = %v, want %q", choice["type"], tc.wantType)
|
||||
}
|
||||
if tc.wantName != "" && choice["name"] != tc.wantName {
|
||||
t.Errorf("tool_choice.name = %v, want %q", choice["name"], tc.wantName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputConfigFormat(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
m := mustModel(t, p, "claude-test")
|
||||
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"],"additionalProperties":false}`)
|
||||
generate(t, m, llm.Request{Messages: []llm.Message{llm.UserText("hi")}},
|
||||
llm.WithSchema(schema, "person"))
|
||||
|
||||
body := c.bodyMap(t)
|
||||
format := body["output_config"].(map[string]any)["format"].(map[string]any)
|
||||
if format["type"] != "json_schema" {
|
||||
t.Errorf("output_config.format.type = %v, want json_schema", format["type"])
|
||||
}
|
||||
// Normalize both sides through any → Marshal (sorted keys) to compare.
|
||||
got, _ := json.Marshal(format["schema"])
|
||||
var want any
|
||||
_ = json.Unmarshal(schema, &want)
|
||||
wantJSON, _ := json.Marshal(want)
|
||||
if string(got) != string(wantJSON) {
|
||||
t.Errorf("schema = %s, want %s", got, wantJSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputConfigOmittedWithoutSchema(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if _, ok := c.bodyMap(t)["output_config"]; ok {
|
||||
t.Error("output_config present, want omitted when Schema is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSamplingKnobs(t *testing.T) {
|
||||
t.Run("omitted when unset", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
body := c.bodyMap(t)
|
||||
if _, ok := body["temperature"]; ok {
|
||||
t.Error("temperature present, want omitted when unset")
|
||||
}
|
||||
if _, ok := body["top_p"]; ok {
|
||||
t.Error("top_p present, want omitted when unset")
|
||||
}
|
||||
if _, ok := body["stop_sequences"]; ok {
|
||||
t.Error("stop_sequences present, want omitted when unset")
|
||||
}
|
||||
})
|
||||
t.Run("present when set", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}},
|
||||
llm.WithTemperature(0), // explicit zero must still be sent
|
||||
llm.WithTopP(0.9),
|
||||
llm.WithStopSequences("END"))
|
||||
body := c.bodyMap(t)
|
||||
if got, ok := body["temperature"]; !ok || got.(float64) != 0 {
|
||||
t.Errorf("temperature = %v (present=%v), want explicit 0", got, ok)
|
||||
}
|
||||
if got := body["top_p"].(float64); got != 0.9 {
|
||||
t.Errorf("top_p = %v, want 0.9", got)
|
||||
}
|
||||
stops := body["stop_sequences"].([]any)
|
||||
if len(stops) != 1 || stops[0] != "END" {
|
||||
t.Errorf("stop_sequences = %v, want [END]", stops)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamFieldOmittedOnGenerate(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if _, ok := c.bodyMap(t)["stream"]; ok {
|
||||
t.Error("stream key present on Generate, want omitted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseParse(t *testing.T) {
|
||||
const body = `{
|
||||
"id": "msg_02",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-test",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "pondering...", "signature": "sig"},
|
||||
{"type": "text", "text": "I'll check the weather."},
|
||||
{"type": "tool_use", "id": "toolu_9", "name": "get_weather", "input": {"location": "Paris"}}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"usage": {
|
||||
"input_tokens": 3,
|
||||
"output_tokens": 7,
|
||||
"cache_creation_input_tokens": 10,
|
||||
"cache_read_input_tokens": 20
|
||||
}
|
||||
}`
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, body))
|
||||
resp := generate(t, mustModel(t, p, "claude-test"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
|
||||
if len(resp.Parts) != 1 {
|
||||
t.Fatalf("parts = %d, want 1 (thinking blocks must be skipped)", len(resp.Parts))
|
||||
}
|
||||
if got := resp.Text(); got != "I'll check the weather." {
|
||||
t.Errorf("text = %q", got)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("tool calls = %d, want 1", len(resp.ToolCalls))
|
||||
}
|
||||
call := resp.ToolCalls[0]
|
||||
if call.ID != "toolu_9" || call.Name != "get_weather" {
|
||||
t.Errorf("tool call = %+v", call)
|
||||
}
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal(call.Arguments, &args); err != nil || args["location"] != "Paris" {
|
||||
t.Errorf("arguments = %s (err %v), want location Paris", call.Arguments, err)
|
||||
}
|
||||
if resp.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("finish = %q, want %q", resp.FinishReason, llm.FinishToolCalls)
|
||||
}
|
||||
// Total real input = input + cache_creation + cache_read.
|
||||
if resp.Usage.InputTokens != 33 || resp.Usage.OutputTokens != 7 {
|
||||
t.Errorf("usage = %+v, want {33 7}", resp.Usage)
|
||||
}
|
||||
if resp.Model != "anthropic/claude-test" {
|
||||
t.Errorf("model = %q, want anthropic/claude-test", resp.Model)
|
||||
}
|
||||
if resp.Raw == nil {
|
||||
t.Error("Raw = nil, want wire response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopReasonMapping(t *testing.T) {
|
||||
cases := map[string]llm.FinishReason{
|
||||
"end_turn": llm.FinishStop,
|
||||
"stop_sequence": llm.FinishStop,
|
||||
"max_tokens": llm.FinishLength,
|
||||
"model_context_window_exceeded": llm.FinishLength,
|
||||
"tool_use": llm.FinishToolCalls,
|
||||
"refusal": llm.FinishContentFilter,
|
||||
"pause_turn": llm.FinishOther,
|
||||
"some_future_reason": llm.FinishOther,
|
||||
}
|
||||
for stop, want := range cases {
|
||||
if got := mapStopReason(stop); got != want {
|
||||
t.Errorf("mapStopReason(%q) = %q, want %q", stop, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPErrorMapping(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
status int
|
||||
body string
|
||||
wantCode string
|
||||
wantClass llm.ErrorClass
|
||||
}{
|
||||
{
|
||||
name: "429 rate limit is transient",
|
||||
status: http.StatusTooManyRequests,
|
||||
body: `{"type":"error","error":{"type":"rate_limit_error","message":"slow down"}}`,
|
||||
wantCode: "rate_limit_error", wantClass: llm.ClassTransient,
|
||||
},
|
||||
{
|
||||
name: "529 overloaded is transient",
|
||||
status: 529,
|
||||
body: `{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`,
|
||||
wantCode: "overloaded_error", wantClass: llm.ClassTransient,
|
||||
},
|
||||
{
|
||||
name: "401 auth is permanent",
|
||||
status: http.StatusUnauthorized,
|
||||
body: `{"type":"error","error":{"type":"authentication_error","message":"invalid x-api-key"}}`,
|
||||
wantCode: "authentication_error", wantClass: llm.ClassPermanent,
|
||||
},
|
||||
{
|
||||
name: "404 is permanent",
|
||||
status: http.StatusNotFound,
|
||||
body: `{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`,
|
||||
wantCode: "not_found_error", wantClass: llm.ClassPermanent,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(tc.status, tc.body))
|
||||
_, err := mustModel(t, p, "claude-test").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil {
|
||||
t.Fatal("Generate succeeded, want error")
|
||||
}
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T (%v), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Provider != "anthropic" || apiErr.Model != "claude-test" {
|
||||
t.Errorf("provider/model = %s/%s", apiErr.Provider, apiErr.Model)
|
||||
}
|
||||
if apiErr.Status != tc.status {
|
||||
t.Errorf("status = %d, want %d", apiErr.Status, tc.status)
|
||||
}
|
||||
if apiErr.Code != tc.wantCode {
|
||||
t.Errorf("code = %q, want %q", apiErr.Code, tc.wantCode)
|
||||
}
|
||||
if apiErr.Message == "" {
|
||||
t.Error("message empty, want provider message")
|
||||
}
|
||||
if got := llm.Classify(err); got != tc.wantClass {
|
||||
t.Errorf("Classify = %v, want %v", got, tc.wantClass)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("404 unwraps to ErrModelNotFound", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusNotFound,
|
||||
`{"type":"error","error":{"type":"not_found_error","message":"model: nope"}}`))
|
||||
_, err := mustModel(t, p, "missing").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if !errors.Is(err, llm.ErrModelNotFound) {
|
||||
t.Errorf("errors.Is(err, ErrModelNotFound) = false for %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-JSON error body falls back to raw text", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusBadGateway, "upstream exploded"))
|
||||
_, err := mustModel(t, p, "m").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T, want *llm.APIError", err)
|
||||
}
|
||||
if apiErr.Status != http.StatusBadGateway || apiErr.Message != "upstream exploded" {
|
||||
t.Errorf("apiErr = %+v", apiErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMissingAPIKey(t *testing.T) {
|
||||
t.Setenv("ANTHROPIC_API_KEY", "") // isolate from any real environment
|
||||
|
||||
var c capture
|
||||
srv := httptest.NewServer(c.handler(http.StatusOK, okBody))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := New(WithBaseURL(srv.URL)) // construction must not fail
|
||||
_, err := mustModel(t, p, "claude-test").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T (%v), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" {
|
||||
t.Errorf("apiErr = %+v, want 401 authentication_error", apiErr)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassPermanent {
|
||||
t.Error("missing key must classify permanent")
|
||||
}
|
||||
if c.hits != 0 {
|
||||
t.Errorf("server hits = %d, want 0 (no request without a key)", c.hits)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyFromEnv(t *testing.T) {
|
||||
t.Setenv("ANTHROPIC_API_KEY", "env-key")
|
||||
|
||||
var c capture
|
||||
srv := httptest.NewServer(c.handler(http.StatusOK, okBody))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p := New(WithBaseURL(srv.URL))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if got := c.header.Get("x-api-key"); got != "env-key" {
|
||||
t.Errorf("x-api-key = %q, want env-key", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapabilityEnforcement(t *testing.T) {
|
||||
img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) }
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
caps *llm.Capabilities // nil = provider defaults
|
||||
req llm.Request
|
||||
}{
|
||||
{
|
||||
name: "images unsupported",
|
||||
caps: &llm.Capabilities{}, // MaxImagesPerReq 0 = no images
|
||||
req: llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 4))}},
|
||||
},
|
||||
{
|
||||
name: "too many images",
|
||||
caps: &llm.Capabilities{MaxImagesPerReq: 1},
|
||||
req: llm.Request{Messages: []llm.Message{
|
||||
llm.UserParts(img("image/png", 4), img("image/png", 4)),
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "disallowed MIME",
|
||||
req: llm.Request{Messages: []llm.Message{llm.UserParts(img("image/bmp", 4))}},
|
||||
},
|
||||
{
|
||||
name: "image too large",
|
||||
caps: &llm.Capabilities{MaxImagesPerReq: 1, MaxImageBytes: 2},
|
||||
req: llm.Request{Messages: []llm.Message{llm.UserParts(img("image/png", 3))}},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
var opts []llm.ModelOption
|
||||
if tc.caps != nil {
|
||||
opts = append(opts, llm.WithCapabilities(*tc.caps))
|
||||
}
|
||||
m := mustModel(t, p, "claude-test", opts...)
|
||||
|
||||
_, err := m.Generate(context.Background(), tc.req)
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Errorf("Generate err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
_, err = m.Stream(context.Background(), tc.req)
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Errorf("Stream err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if c.hits != 0 {
|
||||
t.Errorf("server hits = %d, want 0 (rejected before sending)", c.hits)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("within limits passes", func(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody))
|
||||
generate(t, mustModel(t, p, "m"), llm.Request{
|
||||
Messages: []llm.Message{llm.UserParts(llm.Text("ok"), img("image/jpeg", 16))},
|
||||
})
|
||||
if c.hits != 1 {
|
||||
t.Errorf("server hits = %d, want 1", c.hits)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCompatEndpointWithNameAndBaseURL(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(http.StatusOK, okBody), WithName("compat"))
|
||||
if p.Name() != "compat" {
|
||||
t.Errorf("Name() = %q, want compat", p.Name())
|
||||
}
|
||||
resp := generate(t, mustModel(t, p, "claude-test"), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if resp.Model != "compat/claude-test" {
|
||||
t.Errorf("resp.Model = %q, want compat/claude-test", resp.Model)
|
||||
}
|
||||
|
||||
var ec capture
|
||||
pe := newTestProvider(t, ec.handler(http.StatusTooManyRequests,
|
||||
`{"type":"error","error":{"type":"rate_limit_error","message":"x"}}`), WithName("compat"))
|
||||
_, err := mustModel(t, pe, "m").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok || apiErr.Provider != "compat" {
|
||||
t.Errorf("error provider = %v, want compat (err %v)", apiErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapabilitiesDefaultsAndOverrides(t *testing.T) {
|
||||
p := New(WithAPIKey("k"))
|
||||
m := mustModel(t, p, "m")
|
||||
caps := m.Capabilities()
|
||||
if !caps.SupportsTools || !caps.SupportsStructured || !caps.SupportsStreaming {
|
||||
t.Errorf("default feature flags = %+v, want all true", caps)
|
||||
}
|
||||
if caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 10<<20 || caps.MaxImageDimension != 8000 {
|
||||
t.Errorf("default image limits = %+v", caps)
|
||||
}
|
||||
wantMIME := []string{"image/jpeg", "image/png", "image/gif", "image/webp"}
|
||||
if len(caps.AllowedImageMIME) != len(wantMIME) {
|
||||
t.Fatalf("AllowedImageMIME = %v, want %v", caps.AllowedImageMIME, wantMIME)
|
||||
}
|
||||
for i, mime := range wantMIME {
|
||||
if caps.AllowedImageMIME[i] != mime {
|
||||
t.Errorf("AllowedImageMIME[%d] = %q, want %q", i, caps.AllowedImageMIME[i], mime)
|
||||
}
|
||||
}
|
||||
|
||||
custom := llm.Capabilities{SupportsStreaming: true, MaxImagesPerReq: 1}
|
||||
p2 := New(WithAPIKey("k"), WithDefaultCapabilities(custom))
|
||||
if got := mustModel(t, p2, "m").Capabilities(); got.MaxImagesPerReq != 1 || got.SupportsTools {
|
||||
t.Errorf("WithDefaultCapabilities not applied: %+v", got)
|
||||
}
|
||||
|
||||
perModel := llm.Capabilities{SupportsTools: true}
|
||||
if got := mustModel(t, p2, "m", llm.WithCapabilities(perModel)).Capabilities(); !got.SupportsTools || got.MaxImagesPerReq != 0 {
|
||||
t.Errorf("per-model capabilities not applied: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportErrorNotAPIError(t *testing.T) {
|
||||
// Point at a server that is immediately closed: the connection failure
|
||||
// must surface as a wrapped transport error, not *llm.APIError.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
url := srv.URL
|
||||
srv.Close()
|
||||
|
||||
p := New(WithAPIKey("k"), WithBaseURL(url))
|
||||
_, err := mustModel(t, p, "m").Generate(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil {
|
||||
t.Fatal("Generate succeeded, want transport error")
|
||||
}
|
||||
if _, ok := errors.AsType[*llm.APIError](err); ok {
|
||||
t.Errorf("transport error wrapped in APIError: %v", err)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassTransient {
|
||||
t.Errorf("connection failure must classify transient: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// wireStreamEvent is the union of all SSE data payloads the Messages API
|
||||
// emits. Dispatch is on Type (the data always carries one), so the SSE
|
||||
// "event:" line is informational only.
|
||||
type wireStreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Index int `json:"index"`
|
||||
|
||||
// message_start
|
||||
Message *struct {
|
||||
Usage wireUsage `json:"usage"`
|
||||
} `json:"message"`
|
||||
|
||||
// content_block_start
|
||||
ContentBlock *struct {
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"content_block"`
|
||||
|
||||
// content_block_delta / message_delta
|
||||
Delta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
PartialJSON string `json:"partial_json"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
} `json:"delta"`
|
||||
|
||||
// message_delta
|
||||
Usage *wireUsage `json:"usage"`
|
||||
|
||||
// error
|
||||
Error *struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// stream adapts the Messages API SSE stream to llm.Stream.
|
||||
//
|
||||
// Why single-threaded pull (no reader goroutine): Next is already the
|
||||
// consumer's pull point, so parsing lazily inside Next keeps cancellation,
|
||||
// buffering, and error propagation trivial — Close just closes the body and
|
||||
// the next read fails.
|
||||
type stream struct {
|
||||
provider string
|
||||
model string
|
||||
full string // provider/model
|
||||
body io.ReadCloser
|
||||
scanner *bufio.Scanner
|
||||
|
||||
// accumulated response
|
||||
parts []llm.Part
|
||||
toolCalls []llm.ToolCall
|
||||
usage llm.Usage
|
||||
finish llm.FinishReason
|
||||
|
||||
// current content block state
|
||||
blockType string
|
||||
textBuf strings.Builder
|
||||
toolID string
|
||||
toolName string
|
||||
argsBuf strings.Builder
|
||||
|
||||
done bool // final Response event emitted
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func newStream(m *model, body io.ReadCloser) *stream {
|
||||
sc := bufio.NewScanner(body)
|
||||
// Why a large limit: one SSE line carries one whole delta; default 64K
|
||||
// can be exceeded by large structured-output or tool-argument deltas.
|
||||
sc.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)
|
||||
return &stream{
|
||||
provider: m.provider.name,
|
||||
model: m.id,
|
||||
full: m.fullName(),
|
||||
body: body,
|
||||
scanner: sc,
|
||||
finish: llm.FinishOther,
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements llm.Stream. Safe to call at any time and more than once.
|
||||
func (s *stream) Close() error {
|
||||
s.closeOnce.Do(func() { s.closeErr = s.body.Close() })
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
// Next implements llm.Stream. It emits TextDelta fragments as they arrive,
|
||||
// fully-assembled ToolCalls at content_block_stop, exactly one final
|
||||
// Response event at message_stop, then io.EOF.
|
||||
func (s *stream) Next() (llm.StreamEvent, error) {
|
||||
if s.done {
|
||||
return llm.StreamEvent{}, io.EOF
|
||||
}
|
||||
for {
|
||||
data, err := s.nextData()
|
||||
if err != nil {
|
||||
return llm.StreamEvent{}, err
|
||||
}
|
||||
var ev wireStreamEvent
|
||||
if err := json.Unmarshal([]byte(data), &ev); err != nil {
|
||||
return llm.StreamEvent{}, fmt.Errorf("%s: decode stream event: %w", s.provider, err)
|
||||
}
|
||||
|
||||
switch ev.Type {
|
||||
case "message_start":
|
||||
if ev.Message != nil {
|
||||
s.usage = ev.Message.Usage.toUsage()
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
s.blockType = ""
|
||||
s.textBuf.Reset()
|
||||
s.argsBuf.Reset()
|
||||
if ev.ContentBlock != nil {
|
||||
s.blockType = ev.ContentBlock.Type
|
||||
if s.blockType == "tool_use" {
|
||||
s.toolID = ev.ContentBlock.ID
|
||||
s.toolName = ev.ContentBlock.Name
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
switch ev.Delta.Type {
|
||||
case "text_delta":
|
||||
s.textBuf.WriteString(ev.Delta.Text)
|
||||
return llm.StreamEvent{TextDelta: ev.Delta.Text}, nil
|
||||
case "input_json_delta":
|
||||
// Buffer partial JSON internally; consumers never see it.
|
||||
s.argsBuf.WriteString(ev.Delta.PartialJSON)
|
||||
default:
|
||||
// thinking_delta / signature_delta: tolerated, skipped.
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
if event, ok := s.finishBlock(); ok {
|
||||
return event, nil
|
||||
}
|
||||
|
||||
case "message_delta":
|
||||
if ev.Delta.StopReason != "" {
|
||||
s.finish = mapStopReason(ev.Delta.StopReason)
|
||||
}
|
||||
if ev.Usage != nil {
|
||||
// Output tokens arrive cumulatively in the final delta;
|
||||
// input tokens were reported in message_start.
|
||||
s.usage.OutputTokens = ev.Usage.OutputTokens
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
s.done = true
|
||||
return llm.StreamEvent{Response: &llm.Response{
|
||||
Parts: s.parts,
|
||||
ToolCalls: s.toolCalls,
|
||||
FinishReason: s.finish,
|
||||
Usage: s.usage,
|
||||
Model: s.full,
|
||||
}}, nil
|
||||
|
||||
case "error":
|
||||
// Mid-stream failure after the 200 (e.g. overloaded_error).
|
||||
// Status stays 0: there is no HTTP status for it, and the
|
||||
// default Classify treats it as transient, which fits overload.
|
||||
apiErr := &llm.APIError{Provider: s.provider, Model: s.model}
|
||||
if ev.Error != nil {
|
||||
apiErr.Code = ev.Error.Type
|
||||
apiErr.Message = ev.Error.Message
|
||||
}
|
||||
return llm.StreamEvent{}, apiErr
|
||||
|
||||
default:
|
||||
// ping and unknown event types: ignored.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// finishBlock closes out the current content block, appending its result to
|
||||
// the accumulated response. Tool-use blocks produce a stream event.
|
||||
func (s *stream) finishBlock() (llm.StreamEvent, bool) {
|
||||
defer func() {
|
||||
s.blockType = ""
|
||||
s.textBuf.Reset()
|
||||
s.argsBuf.Reset()
|
||||
}()
|
||||
switch s.blockType {
|
||||
case "text":
|
||||
if s.textBuf.Len() > 0 {
|
||||
s.parts = append(s.parts, llm.TextPart{Text: s.textBuf.String()})
|
||||
}
|
||||
case "tool_use":
|
||||
args := s.argsBuf.String()
|
||||
if args == "" {
|
||||
// A tool called with no arguments streams zero (or empty)
|
||||
// input_json_delta fragments; the canonical form is "{}".
|
||||
args = "{}"
|
||||
}
|
||||
call := llm.ToolCall{ID: s.toolID, Name: s.toolName, Arguments: json.RawMessage(args)}
|
||||
s.toolCalls = append(s.toolCalls, call)
|
||||
return llm.StreamEvent{ToolCall: &call}, true
|
||||
}
|
||||
return llm.StreamEvent{}, false
|
||||
}
|
||||
|
||||
// nextData reads SSE lines until one complete event's data is assembled
|
||||
// (multi-line data fields are joined with "\n" per the SSE spec). "event:"
|
||||
// lines and comments are ignored; dispatch keys off the JSON "type" field.
|
||||
func (s *stream) nextData() (string, error) {
|
||||
var data strings.Builder
|
||||
for s.scanner.Scan() {
|
||||
line := s.scanner.Text()
|
||||
if line == "" {
|
||||
if data.Len() > 0 {
|
||||
return data.String(), nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
if rest, ok := strings.CutPrefix(line, "data:"); ok {
|
||||
if data.Len() > 0 {
|
||||
data.WriteByte('\n')
|
||||
}
|
||||
data.WriteString(strings.TrimPrefix(rest, " "))
|
||||
}
|
||||
}
|
||||
if err := s.scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("%s: read stream: %w", s.provider, err)
|
||||
}
|
||||
if data.Len() > 0 {
|
||||
return data.String(), nil
|
||||
}
|
||||
// EOF before message_stop: the connection dropped mid-response.
|
||||
return "", fmt.Errorf("%s: stream ended before message_stop: %w", s.provider, io.ErrUnexpectedEOF)
|
||||
}
|
||||
@@ -0,0 +1,324 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// sse joins data payloads into an SSE body. Each payload becomes one event
|
||||
// ("event:" name derived from the JSON type field is what the real API
|
||||
// sends, but the client dispatches on the data, so a generic name is fine).
|
||||
func sse(payloads ...string) string {
|
||||
var b strings.Builder
|
||||
for _, p := range payloads {
|
||||
b.WriteString("event: event\n")
|
||||
b.WriteString("data: ")
|
||||
b.WriteString(p)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func sseServer(t *testing.T, c *capture, body string) *Provider {
|
||||
t.Helper()
|
||||
return newTestProvider(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
raw, _ := io.ReadAll(r.Body)
|
||||
c.mu.Lock()
|
||||
c.hits++
|
||||
c.header = r.Header.Clone()
|
||||
c.body = raw
|
||||
c.mu.Unlock()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
}
|
||||
|
||||
// drain collects all events until io.EOF, failing the test on any error.
|
||||
func drain(t *testing.T, s llm.Stream) []llm.StreamEvent {
|
||||
t.Helper()
|
||||
var events []llm.StreamEvent
|
||||
for {
|
||||
ev, err := s.Next()
|
||||
if err == io.EOF {
|
||||
return events
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Next: %v", err)
|
||||
}
|
||||
events = append(events, ev)
|
||||
}
|
||||
}
|
||||
|
||||
func openStream(t *testing.T, p *Provider, modelID string) llm.Stream {
|
||||
t.Helper()
|
||||
s, err := mustModel(t, p, modelID).Stream(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func TestStreamTextDeltas(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"m","usage":{"input_tokens":10,"cache_creation_input_tokens":2,"cache_read_input_tokens":3,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"ping"}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hel"}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"lo"}}`,
|
||||
`{"type":"content_block_stop","index":0}`,
|
||||
`{"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":" world"}}`,
|
||||
`{"type":"content_block_stop","index":1}`,
|
||||
`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":12}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
s := openStream(t, p, "claude-test")
|
||||
events := drain(t, s)
|
||||
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("events = %d, want 4 (3 deltas + final response)", len(events))
|
||||
}
|
||||
for i, want := range []string{"Hel", "lo", " world"} {
|
||||
if events[i].TextDelta != want {
|
||||
t.Errorf("event[%d].TextDelta = %q, want %q", i, events[i].TextDelta, want)
|
||||
}
|
||||
}
|
||||
|
||||
final := events[3].Response
|
||||
if final == nil {
|
||||
t.Fatal("last event has no Response")
|
||||
}
|
||||
if len(final.Parts) != 2 {
|
||||
t.Fatalf("final parts = %d, want 2 (one per text block)", len(final.Parts))
|
||||
}
|
||||
if final.Text() != "Hello world" {
|
||||
t.Errorf("final text = %q, want %q", final.Text(), "Hello world")
|
||||
}
|
||||
if final.FinishReason != llm.FinishStop {
|
||||
t.Errorf("finish = %q, want stop", final.FinishReason)
|
||||
}
|
||||
// Input = 10+2+3 from message_start; output = 12 from message_delta.
|
||||
if final.Usage.InputTokens != 15 || final.Usage.OutputTokens != 12 {
|
||||
t.Errorf("usage = %+v, want {15 12}", final.Usage)
|
||||
}
|
||||
if final.Model != "anthropic/claude-test" {
|
||||
t.Errorf("model = %q, want anthropic/claude-test", final.Model)
|
||||
}
|
||||
|
||||
// Past EOF, Next keeps returning io.EOF.
|
||||
if _, err := s.Next(); err != io.EOF {
|
||||
t.Errorf("Next after EOF = %v, want io.EOF", err)
|
||||
}
|
||||
|
||||
// The request must carry "stream": true.
|
||||
if streamFlag := c.bodyMap(t)["stream"]; streamFlag != true {
|
||||
t.Errorf("request stream = %v, want true", streamFlag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamToolCallAssembly(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":8,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Checking."}}`,
|
||||
`{"type":"content_block_stop","index":0}`,
|
||||
`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_9","name":"get_weather","input":{}}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":""}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"location\":"}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" \"San Francisco, CA\"}"}}`,
|
||||
`{"type":"content_block_stop","index":1}`,
|
||||
`{"type":"content_block_start","index":2,"content_block":{"type":"tool_use","id":"toolu_10","name":"noop","input":{}}}`,
|
||||
`{"type":"content_block_stop","index":2}`,
|
||||
`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":21}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
events := drain(t, openStream(t, p, "claude-test"))
|
||||
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("events = %d, want 4 (text, 2 tool calls, final)", len(events))
|
||||
}
|
||||
if events[0].TextDelta != "Checking." {
|
||||
t.Errorf("event[0] = %+v, want text delta", events[0])
|
||||
}
|
||||
|
||||
call := events[1].ToolCall
|
||||
if call == nil {
|
||||
t.Fatal("event[1] has no ToolCall")
|
||||
}
|
||||
if call.ID != "toolu_9" || call.Name != "get_weather" {
|
||||
t.Errorf("tool call = %+v", call)
|
||||
}
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal(call.Arguments, &args); err != nil {
|
||||
t.Fatalf("assembled arguments invalid JSON: %v (%s)", err, call.Arguments)
|
||||
}
|
||||
if args["location"] != "San Francisco, CA" {
|
||||
t.Errorf("arguments = %v", args)
|
||||
}
|
||||
|
||||
empty := events[2].ToolCall
|
||||
if empty == nil || empty.ID != "toolu_10" {
|
||||
t.Fatalf("event[2] = %+v, want second tool call", events[2])
|
||||
}
|
||||
if string(empty.Arguments) != "{}" {
|
||||
t.Errorf("empty tool call arguments = %s, want {}", empty.Arguments)
|
||||
}
|
||||
|
||||
final := events[3].Response
|
||||
if final == nil {
|
||||
t.Fatal("last event has no Response")
|
||||
}
|
||||
if len(final.ToolCalls) != 2 {
|
||||
t.Errorf("final tool calls = %d, want 2", len(final.ToolCalls))
|
||||
}
|
||||
if final.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("finish = %q, want tool_calls", final.FinishReason)
|
||||
}
|
||||
if final.Text() != "Checking." {
|
||||
t.Errorf("final text = %q", final.Text())
|
||||
}
|
||||
if final.Usage.InputTokens != 8 || final.Usage.OutputTokens != 21 {
|
||||
t.Errorf("usage = %+v, want {8 21}", final.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamThinkingSkipped(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":"hmm"}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":"sig"}}`,
|
||||
`{"type":"content_block_stop","index":0}`,
|
||||
`{"type":"content_block_start","index":1,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":1,"delta":{"type":"text_delta","text":"hi"}}`,
|
||||
`{"type":"content_block_stop","index":1}`,
|
||||
`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":2}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
events := drain(t, openStream(t, p, "claude-test"))
|
||||
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("events = %d, want 2 (thinking produces none)", len(events))
|
||||
}
|
||||
if events[0].TextDelta != "hi" {
|
||||
t.Errorf("event[0] = %+v, want TextDelta hi", events[0])
|
||||
}
|
||||
final := events[1].Response
|
||||
if final == nil || len(final.Parts) != 1 || final.Text() != "hi" {
|
||||
t.Errorf("final = %+v, want single text part %q", final, "hi")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamMidStreamError(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"par"}}`,
|
||||
`{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
s := openStream(t, p, "claude-test")
|
||||
|
||||
ev, err := s.Next()
|
||||
if err != nil || ev.TextDelta != "par" {
|
||||
t.Fatalf("first Next = (%+v, %v), want text delta", ev, err)
|
||||
}
|
||||
_, err = s.Next()
|
||||
if err == nil {
|
||||
t.Fatal("second Next succeeded, want mid-stream error")
|
||||
}
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T (%v), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Code != "overloaded_error" || apiErr.Message != "Overloaded" || apiErr.Status != 0 {
|
||||
t.Errorf("apiErr = %+v", apiErr)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassTransient {
|
||||
t.Error("overloaded_error must classify transient")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHTTPErrorBeforeEvents(t *testing.T) {
|
||||
var c capture
|
||||
p := newTestProvider(t, c.handler(529,
|
||||
`{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}`))
|
||||
_, err := mustModel(t, p, "claude-test").Stream(context.Background(),
|
||||
llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil {
|
||||
t.Fatal("Stream succeeded, want APIError before any events")
|
||||
}
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("error %T (%v), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != 529 || apiErr.Code != "overloaded_error" {
|
||||
t.Errorf("apiErr = %+v, want 529 overloaded_error", apiErr)
|
||||
}
|
||||
if llm.Classify(err) != llm.ClassTransient {
|
||||
t.Error("529 must classify transient")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTruncatedBody(t *testing.T) {
|
||||
// Stream ends without message_stop: Next must surface unexpected EOF.
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
s := openStream(t, p, "claude-test")
|
||||
|
||||
if ev, err := s.Next(); err != nil || ev.TextDelta != "hi" {
|
||||
t.Fatalf("first Next = (%+v, %v)", ev, err)
|
||||
}
|
||||
if _, err := s.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Errorf("Next on truncated stream = %v, want io.ErrUnexpectedEOF", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamCloseIsSafe(t *testing.T) {
|
||||
body := sse(
|
||||
`{"type":"message_start","message":{"id":"msg_1","usage":{"input_tokens":5,"output_tokens":1}}}`,
|
||||
`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}`,
|
||||
`{"type":"content_block_stop","index":0}`,
|
||||
`{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":2}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
)
|
||||
var c capture
|
||||
p := sseServer(t, &c, body)
|
||||
s := openStream(t, p, "claude-test")
|
||||
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("first Close: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("second Close: %v", err)
|
||||
}
|
||||
|
||||
// After EOF, Close is still fine.
|
||||
s2 := openStream(t, p, "claude-test")
|
||||
drain(t, s2)
|
||||
if err := s2.Close(); err != nil {
|
||||
t.Errorf("Close after EOF: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// Wire types mirror the Messages API JSON shapes (June 2026 docs). Only the
|
||||
// fields majordomo uses are modeled; unknown response fields are ignored by
|
||||
// encoding/json.
|
||||
|
||||
type wireRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []wireMessage `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []wireTool `json:"tools,omitempty"`
|
||||
ToolChoice *wireToolChoice `json:"tool_choice,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
OutputConfig *wireOutputConfig `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
type wireMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []wireBlock `json:"content"`
|
||||
}
|
||||
|
||||
// wireBlock is a request-side content block. Exactly one shape is populated
|
||||
// per block, keyed by Type: text, image, tool_use, or tool_result.
|
||||
type wireBlock struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
// text
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// image
|
||||
Source *wireImageSource `json:"source,omitempty"`
|
||||
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
|
||||
// tool_result
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
type wireImageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type wireTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"`
|
||||
}
|
||||
|
||||
type wireToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type wireOutputConfig struct {
|
||||
Format *wireOutputFormat `json:"format,omitempty"`
|
||||
}
|
||||
|
||||
type wireOutputFormat struct {
|
||||
Type string `json:"type"`
|
||||
Schema json.RawMessage `json:"schema"`
|
||||
}
|
||||
|
||||
type wireResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Content []wireRespBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Usage wireUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type wireRespBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Input json.RawMessage `json:"input"`
|
||||
}
|
||||
|
||||
type wireUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
}
|
||||
|
||||
// toUsage maps API token accounting onto the canonical Usage. Why the sum:
|
||||
// the API's input_tokens counts only tokens after the last cache breakpoint;
|
||||
// real total input is input + cache_creation + cache_read.
|
||||
func (u wireUsage) toUsage() llm.Usage {
|
||||
return llm.Usage{
|
||||
InputTokens: u.InputTokens + u.CacheCreationInputTokens + u.CacheReadInputTokens,
|
||||
OutputTokens: u.OutputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
type wireErrorEnvelope struct {
|
||||
Type string `json:"type"`
|
||||
Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// buildWireRequest translates the canonical request into the Messages API
|
||||
// shape.
|
||||
//
|
||||
// Request.ReasoningEffort is intentionally ignored: the current Messages API
|
||||
// has no low/medium/high reasoning knob — thinking is adaptive on current
|
||||
// models, and the legacy budget/disable parameters 400 on them. The llm
|
||||
// contract says providers ignore ReasoningEffort where no mapping exists.
|
||||
//
|
||||
// Request.SchemaName is likewise ignored: output_config.format takes a bare
|
||||
// schema with no name field.
|
||||
func buildWireRequest(modelID string, req llm.Request, defaultMax int, stream bool) wireRequest {
|
||||
maxTokens := req.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
// max_tokens is required by the API; 0 means "provider default".
|
||||
maxTokens = defaultMax
|
||||
}
|
||||
|
||||
wr := wireRequest{
|
||||
Model: modelID,
|
||||
MaxTokens: maxTokens,
|
||||
System: foldSystem(req),
|
||||
Messages: toWireMessages(req.Messages),
|
||||
Stream: stream,
|
||||
Tools: toWireTools(req.Tools),
|
||||
ToolChoice: toWireToolChoice(req.ToolChoice),
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
StopSequences: req.StopSequences,
|
||||
}
|
||||
if req.Schema != nil {
|
||||
wr.OutputConfig = &wireOutputConfig{Format: &wireOutputFormat{
|
||||
Type: "json_schema",
|
||||
Schema: req.Schema,
|
||||
}}
|
||||
}
|
||||
return wr
|
||||
}
|
||||
|
||||
// foldSystem joins Request.System with the text of every RoleSystem message
|
||||
// (System field first, original order, "\n\n" separators). Why: the API
|
||||
// takes the system prompt as a top-level field and rejects system roles
|
||||
// inside messages, so canonical RoleSystem messages must fold in here.
|
||||
func foldSystem(req llm.Request) string {
|
||||
parts := make([]string, 0, 2)
|
||||
if req.System != "" {
|
||||
parts = append(parts, req.System)
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
if msg.Role != llm.RoleSystem {
|
||||
continue
|
||||
}
|
||||
if text := msg.Text(); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
func toWireMessages(msgs []llm.Message) []wireMessage {
|
||||
out := make([]wireMessage, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
switch msg.Role {
|
||||
case llm.RoleSystem:
|
||||
// Folded into the top-level system field by foldSystem.
|
||||
continue
|
||||
|
||||
case llm.RoleTool:
|
||||
// One user message carrying one tool_result block per result.
|
||||
blocks := make([]wireBlock, 0, len(msg.ToolResults))
|
||||
for _, res := range msg.ToolResults {
|
||||
blocks = append(blocks, wireBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: res.ID,
|
||||
Content: res.Content,
|
||||
IsError: res.IsError,
|
||||
})
|
||||
}
|
||||
out = append(out, wireMessage{Role: "user", Content: blocks})
|
||||
|
||||
case llm.RoleAssistant:
|
||||
blocks := toWireBlocks(msg.Parts)
|
||||
for _, call := range msg.ToolCalls {
|
||||
args := call.Arguments
|
||||
if len(args) == 0 {
|
||||
// The API requires input to be a JSON object.
|
||||
args = json.RawMessage("{}")
|
||||
}
|
||||
blocks = append(blocks, wireBlock{
|
||||
Type: "tool_use",
|
||||
ID: call.ID,
|
||||
Name: call.Name,
|
||||
Input: args,
|
||||
})
|
||||
}
|
||||
out = append(out, wireMessage{Role: "assistant", Content: blocks})
|
||||
|
||||
default: // llm.RoleUser and anything unrecognized
|
||||
out = append(out, wireMessage{Role: "user", Content: toWireBlocks(msg.Parts)})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toWireBlocks(parts []llm.Part) []wireBlock {
|
||||
blocks := make([]wireBlock, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
switch p := part.(type) {
|
||||
case llm.TextPart:
|
||||
blocks = append(blocks, wireBlock{Type: "text", Text: p.Text})
|
||||
case llm.ImagePart:
|
||||
blocks = append(blocks, wireBlock{Type: "image", Source: &wireImageSource{
|
||||
Type: "base64",
|
||||
MediaType: p.MIME,
|
||||
Data: base64.StdEncoding.EncodeToString(p.Data),
|
||||
}})
|
||||
}
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
func toWireTools(tools []llm.Tool) []wireTool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]wireTool, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
schema := t.Parameters
|
||||
if len(schema) == 0 {
|
||||
// Why: input_schema is required by the API; a tool with no
|
||||
// arguments still needs an (empty) object schema.
|
||||
schema = json.RawMessage(`{"type":"object","properties":{}}`)
|
||||
}
|
||||
out = append(out, wireTool{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// toWireToolChoice maps the canonical tool-choice policy. "" omits the field
|
||||
// (API default is auto); any value other than the three keywords names the
|
||||
// one tool the model must call.
|
||||
func toWireToolChoice(choice string) *wireToolChoice {
|
||||
switch choice {
|
||||
case "":
|
||||
return nil
|
||||
case "auto":
|
||||
return &wireToolChoice{Type: "auto"}
|
||||
case "required":
|
||||
return &wireToolChoice{Type: "any"}
|
||||
case "none":
|
||||
return &wireToolChoice{Type: "none"}
|
||||
default:
|
||||
return &wireToolChoice{Type: "tool", Name: choice}
|
||||
}
|
||||
}
|
||||
|
||||
// mapStopReason maps the API stop_reason onto the canonical FinishReason.
|
||||
func mapStopReason(stop string) llm.FinishReason {
|
||||
switch stop {
|
||||
case "end_turn", "stop_sequence":
|
||||
return llm.FinishStop
|
||||
case "max_tokens", "model_context_window_exceeded":
|
||||
return llm.FinishLength
|
||||
case "tool_use":
|
||||
return llm.FinishToolCalls
|
||||
case "refusal":
|
||||
return llm.FinishContentFilter
|
||||
default:
|
||||
// pause_turn and any future provider-specific reasons.
|
||||
return llm.FinishOther
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user