feat: OpenAI, Anthropic, and native-Ollama providers + media pipeline

Phase 3:
- provider/openai: Chat Completions for OpenAI + compat endpoints (SSE
  streaming with by-index tool-call assembly, response_format json_schema,
  legacy max_tokens option, reasoning_effort)
- provider/anthropic: Messages API (tool_use/tool_result, GA structured
  output via output_config.format, full SSE event parser, 529 transient)
- provider/ollama: one native /api/chat client behind the ollama,
  ollama-cloud, and foreman built-ins (presets; NDJSON streaming tolerant
  of foreman's buffered single-object responses; object tool arguments;
  format-schema structured output; think mapping)
- media/: capability normalization (sniff, downscale, transcode, byte
  ladder, ErrUnsupported), wired into the chain executor per target with
  penalty-free advance past incapable elements
- registry: real provider + scheme wiring, WithHTTPClient option, required
  env-foreman TLS chat round-trip test
- ADR-0009 multimodal strategy, ADR-0010 tools/structured mapping; README
  matrix + CLAUDE.md synced

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-10 12:58:08 +02:00
parent 323558ed72
commit 043249e0e1
31 changed files with 6194 additions and 74 deletions
+222
View File
@@ -0,0 +1,222 @@
package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
// model is one provider-bound target: a model id paired with the Provider
// that serves it and the capabilities enforced on requests sent to it.
type model struct {
// p is the owning provider; its name, API key, base URL, and HTTP client
// are read when requests are built and sent.
p *Provider
// id is the model identifier reported in errors and in Response.Model.
id string
// caps is the capability set checkRequest enforces before anything is
// sent.
caps llm.Capabilities
}
// Capabilities implements llm.Model by reporting the capability set this
// target was constructed with.
func (m *model) Capabilities() llm.Capabilities {
	return m.caps
}
// Generate implements llm.Model. It applies opts to req, refuses requests
// the target's capabilities cannot serve, performs one non-streaming call,
// and maps the decoded reply onto the canonical response. Non-2xx replies
// are returned as the error produced by apiError.
func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) {
	req = req.Apply(opts...)
	if err := checkRequest(m.caps, req); err != nil {
		return nil, err
	}

	res, err := m.do(ctx, req, false)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	if res.StatusCode < 200 || res.StatusCode > 299 {
		return nil, m.apiError(res)
	}

	decoded := new(chatResponse)
	if err := json.NewDecoder(res.Body).Decode(decoded); err != nil {
		return nil, fmt.Errorf("openai: decode response: %w", err)
	}
	return m.toResponse(decoded), nil
}
// Stream implements llm.Model. It rejects targets without streaming
// support and requests the capabilities cannot serve, then opens the
// streaming call and hands the still-open response body to the returned
// stream.
func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) {
	// Why these sizes: a single SSE data line carries a whole JSON chunk,
	// and tool-call argument fragments can push a line far beyond
	// Scanner's 64 KiB default cap.
	const (
		initialLineBuf = 64 * 1024
		maxLineBuf     = 16 << 20
	)

	req = req.Apply(opts...)
	if !m.caps.SupportsStreaming {
		return nil, fmt.Errorf("%w: streaming not supported by %s/%s", llm.ErrUnsupported, m.p.name, m.id)
	}
	if err := checkRequest(m.caps, req); err != nil {
		return nil, err
	}

	res, err := m.do(ctx, req, true)
	if err != nil {
		return nil, err
	}
	if res.StatusCode < 200 || res.StatusCode > 299 {
		apiErr := m.apiError(res)
		res.Body.Close()
		return nil, apiErr
	}

	lines := bufio.NewScanner(res.Body)
	lines.Buffer(make([]byte, 0, initialLineBuf), maxLineBuf)
	return &stream{m: m, body: res.Body, sc: lines}, nil
}
// do encodes req, POSTs it to {baseURL}/chat/completions with bearer
// authentication, and returns the raw HTTP response for the caller to
// inspect and close. Transport failures are wrapped with %w (never turned
// into *llm.APIError) so llm.Classify can still see net.Error, syscall
// errnos, and context errors underneath.
func (m *model) do(ctx context.Context, req llm.Request, stream bool) (*http.Response, error) {
	key := m.p.apiKey
	if key == "" {
		// Why a synthetic 401: the constructor never fails, so a missing
		// key has to surface at request time as the auth failure it is,
		// permanent under llm.Classify just like a real 401.
		return nil, &llm.APIError{
			Provider: m.p.name,
			Model:    m.id,
			Status:   http.StatusUnauthorized,
			Code:     "missing_api_key",
			Message:  "no API key configured: set OPENAI_API_KEY or use WithAPIKey",
		}
	}

	payload, err := json.Marshal(m.buildRequest(req, stream))
	if err != nil {
		return nil, fmt.Errorf("openai: encode request: %w", err)
	}

	endpoint := m.p.baseURL + "/chat/completions"
	outbound, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("openai: build request: %w", err)
	}

	hdr := outbound.Header
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Authorization", "Bearer "+key)
	if stream {
		hdr.Set("Accept", "text/event-stream")
	}

	res, err := m.p.client.Do(outbound)
	if err != nil {
		return nil, fmt.Errorf("openai: do request: %w", err)
	}
	return res, nil
}
// apiError turns a non-2xx response into *llm.APIError. At most 1 MiB of
// the body is read. When the body parses as the {"error":{...}} envelope
// and carries a message, type, or code, those populate the error (code
// falls back to type when empty); otherwise the trimmed raw body becomes
// the message.
func (m *model) apiError(httpResp *http.Response) error {
	// Best effort: a failed read still leaves whatever bytes arrived.
	raw, _ := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20))

	out := &llm.APIError{Provider: m.p.name, Model: m.id, Status: httpResp.StatusCode}

	var envelope errorEnvelope
	parseErr := json.Unmarshal(raw, &envelope)
	if parseErr != nil ||
		(envelope.Error.Message == "" && envelope.Error.Type == "" && envelope.Error.Code == "") {
		// Why: compat servers emit all sorts of error bodies; a raw snippet
		// is more useful than silence when the canonical envelope is absent.
		out.Message = strings.TrimSpace(string(raw))
		return out
	}

	out.Message = envelope.Error.Message
	out.Code = envelope.Error.Code
	if out.Code == "" {
		out.Code = envelope.Error.Type
	}
	return out
}
// toResponse maps the wire response onto the canonical llm.Response.
//
// Only the first choice is consulted. Usage is copied when present. A
// response with no choices yields FinishOther and no parts. Non-empty
// message content becomes a single TextPart, and each wire tool call
// becomes an llm.ToolCall. The finish reason comes from mapFinish, where
// the presence of tool calls wins over the reported reason.
func (m *model) toResponse(wire *chatResponse) *llm.Response {
	resp := &llm.Response{Model: m.p.name + "/" + m.id, Raw: wire}
	if wire.Usage != nil {
		resp.Usage = llm.Usage{
			InputTokens:  wire.Usage.PromptTokens,
			OutputTokens: wire.Usage.CompletionTokens,
		}
	}
	if len(wire.Choices) == 0 {
		resp.FinishReason = llm.FinishOther
		return resp
	}

	choice := wire.Choices[0]
	if choice.Message.Content != "" {
		resp.Parts = append(resp.Parts, llm.TextPart{Text: choice.Message.Content})
	}
	for i, tc := range choice.Message.ToolCalls {
		id := tc.ID
		if id == "" {
			// Why: ToolResult.ID must echo ToolCall.ID, so calls from compat
			// servers that omit ids get synthesized ones.
			id = fmt.Sprintf("call_%d", i)
		}
		args := string(tc.Function.Arguments)
		if strings.TrimSpace(args) == "" {
			// Why: an empty arguments string would become an empty, non-nil
			// json.RawMessage, which is not valid JSON: it fails to
			// unmarshal and makes json.Marshal error when the call is
			// re-encoded. A call with no arguments is the empty object.
			args = "{}"
		}
		resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{
			ID:        id,
			Name:      tc.Function.Name,
			Arguments: json.RawMessage(args),
		})
	}
	resp.FinishReason = mapFinish(choice.FinishReason, len(resp.ToolCalls) > 0)
	return resp
}
// mapFinish translates a wire finish_reason into the canonical enum.
// Tool-call presence takes priority over the reported reason: a forced
// (named tool_choice) call can finish with "stop" while still carrying
// tool_calls. Unrecognized reasons map to FinishOther.
func mapFinish(reason string, hasToolCalls bool) llm.FinishReason {
	switch {
	case hasToolCalls, reason == "tool_calls":
		return llm.FinishToolCalls
	case reason == "stop":
		return llm.FinishStop
	case reason == "length":
		return llm.FinishLength
	case reason == "content_filter":
		return llm.FinishContentFilter
	}
	return llm.FinishOther
}
// checkRequest enforces the model's effective capabilities and returns an
// error wrapping llm.ErrUnsupported for anything the target declares it
// cannot serve: tools, structured output, image input at all, a disallowed
// image MIME type, an oversized image (only when MaxImageBytes is
// positive), or more images than MaxImagesPerReq.
//
// Why enforcement rather than normalization: a separate media layer
// resizes/transcodes images BEFORE requests reach the provider; this check
// is the honest backstop (chains advance past a refusal penalty-free).
func checkRequest(caps llm.Capabilities, req llm.Request) error {
	if !caps.SupportsTools && len(req.Tools) > 0 {
		return fmt.Errorf("%w: tools not supported", llm.ErrUnsupported)
	}
	if !caps.SupportsStructured && len(req.Schema) > 0 {
		return fmt.Errorf("%w: structured output not supported", llm.ErrUnsupported)
	}

	var seen int
	for _, msg := range req.Messages {
		for _, part := range msg.Parts {
			img, isImage := part.(llm.ImagePart)
			if !isImage {
				continue
			}
			seen++
			switch {
			case !caps.SupportsImages():
				return fmt.Errorf("%w: image input not supported", llm.ErrUnsupported)
			case !caps.MIMEAllowed(img.MIME):
				return fmt.Errorf("%w: image MIME type %q not allowed (allowed: %s)",
					llm.ErrUnsupported, img.MIME, strings.Join(caps.AllowedImageMIME, ", "))
			case caps.MaxImageBytes > 0 && len(img.Data) > caps.MaxImageBytes:
				return fmt.Errorf("%w: image is %d bytes, limit is %d",
					llm.ErrUnsupported, len(img.Data), caps.MaxImageBytes)
			}
		}
	}

	if seen > caps.MaxImagesPerReq {
		return fmt.Errorf("%w: request carries %d images, limit is %d",
			llm.ErrUnsupported, seen, caps.MaxImagesPerReq)
	}
	return nil
}
+133
View File
@@ -0,0 +1,133 @@
// Package openai implements llm.Provider for the OpenAI Chat Completions
// API and, via WithBaseURL/WithName, any OpenAI-compatible endpoint
// (vLLM, Groq, Together, LM Studio, Ollama's /v1 shim, ...).
//
// Targeted API surface (verified against developers.openai.com, June 2026):
// POST {base}/chat/completions with
// - messages: plain-string content for text-only turns, part arrays with
// base64 data-URL image_url entries for multimodal turns, assistant
// tool_calls history, and {"role":"tool","tool_call_id",...} results;
// - tools as {"type":"function","function":{...}} with tool_choice
// "auto"/"none"/"required" or a named-function object;
// - response_format {"type":"json_schema",...} structured output;
// - max_completion_tokens (or legacy max_tokens via WithLegacyMaxTokens
// for compat servers), temperature, top_p, stop, reasoning_effort;
// - data-only SSE streaming with stream_options.include_usage, the
// "data: [DONE]" sentinel, and tool-call deltas accumulated by index.
//
// Newer response fields (refusal, annotations, usage *_details, delta
// obfuscation) are tolerated and ignored so both api.openai.com and older
// compat servers decode cleanly.
package openai
import (
"net/http"
"os"
"strings"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
// defaultBaseURL is the endpoint root New uses unless WithBaseURL
// overrides it.
const defaultBaseURL = "https://api.openai.com/v1"
// Provider is an llm.Provider backed by an OpenAI Chat Completions endpoint.
type Provider struct {
// name is the registry name returned by Name ("openai" unless WithName
// overrides it).
name string
// apiKey is taken from OPENAI_API_KEY at construction unless WithAPIKey
// overrides it.
apiKey string
// baseURL is the endpoint root; New trims trailing slashes from it.
baseURL string
// client performs the HTTP requests (http.DefaultClient unless
// WithHTTPClient substitutes another).
client *http.Client
// caps is the capability set given to models created without a
// per-model override.
caps llm.Capabilities
// legacyMaxTokens is set by WithLegacyMaxTokens.
legacyMaxTokens bool
}
// Option configures the provider at construction. New applies options in
// the order given, after seeding the defaults.
type Option func(*Provider)
// WithAPIKey supplies the API key explicitly. Without this option, New
// uses the value of OPENAI_API_KEY read from the environment at
// construction time.
func WithAPIKey(key string) Option {
	return func(target *Provider) {
		target.apiKey = key
	}
}
// WithBaseURL directs the client at another endpoint root (compat
// servers). Requests go to this root with "/chat/completions" appended;
// New trims any trailing slash.
func WithBaseURL(u string) Option {
	return func(target *Provider) {
		target.baseURL = u
	}
}
// WithHTTPClient swaps in a caller-supplied HTTP client (timeouts,
// proxies, tests). A nil client is ignored and the existing one is kept.
func WithHTTPClient(c *http.Client) Option {
	return func(target *Provider) {
		if c == nil {
			return
		}
		target.client = c
	}
}
// WithName replaces the registry name ("openai" by default). Why: one
// client implementation serves many OpenAI-compatible endpoints, and each
// of them needs its own name in "provider/model" specs and in error
// reporting.
func WithName(name string) Option {
	return func(target *Provider) {
		target.name = name
	}
}
// WithDefaultCapabilities swaps out the provider-default capability set.
// A per-model override passed via llm.WithCapabilities still wins.
func WithDefaultCapabilities(caps llm.Capabilities) Option {
	return func(target *Provider) {
		target.caps = caps
	}
}
// WithLegacyMaxTokens makes the provider send Request.MaxTokens as
// "max_tokens" rather than "max_completion_tokens". Why: OpenAI deprecated
// max_tokens, yet many third-party compat servers still honor only the
// legacy field.
func WithLegacyMaxTokens() Option {
	return func(target *Provider) {
		target.legacyMaxTokens = true
	}
}
// defaultCapabilities describes OpenAI's current vision-capable chat
// models. Why these limits: the published per-request caps (1500 images,
// 512 MB) are far beyond what compat servers accept, so 100 images at
// 20 MB each is a conservative envelope; the MIME list is the documented
// set (PNG, JPEG, WEBP, non-animated GIF).
func defaultCapabilities() llm.Capabilities {
	var caps llm.Capabilities
	caps.SupportsTools = true
	caps.SupportsStructured = true
	caps.SupportsStreaming = true
	caps.MaxImagesPerReq = 100
	caps.MaxImageBytes = 20 << 20
	caps.AllowedImageMIME = []string{"image/jpeg", "image/png", "image/webp", "image/gif"}
	return caps
}
// New builds a Provider from defaults (name "openai", key from
// OPENAI_API_KEY, the public OpenAI base URL, http.DefaultClient, and
// defaultCapabilities), then applies opts in order and trims trailing
// slashes from the base URL. It never fails: a missing API key surfaces as
// a 401-style *llm.APIError at request time, not at construction.
func New(opts ...Option) *Provider {
	p := new(Provider)
	p.name = "openai"
	p.apiKey = os.Getenv("OPENAI_API_KEY")
	p.baseURL = defaultBaseURL
	p.client = http.DefaultClient
	p.caps = defaultCapabilities()

	for i := range opts {
		opts[i](p)
	}

	p.baseURL = strings.TrimRight(p.baseURL, "/")
	return p
}
// Name implements llm.Provider by returning the configured registry name.
func (p *Provider) Name() string {
	return p.name
}
// Model implements llm.Provider. The id is used verbatim with no catalog
// validation, so unknown models fail at request time with the backend's
// own error. The model gets the provider's default capabilities unless the
// options carry a capabilities override. The returned error is always nil.
func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) {
	target := &model{p: p, id: id, caps: p.caps}
	if override := llm.ApplyModelOptions(opts).Capabilities; override != nil {
		target.caps = *override
	}
	return target, nil
}
+614
View File
@@ -0,0 +1,614 @@
package openai
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
// Compile-time checks that the package's types satisfy the llm interfaces.
var (
_ llm.Provider = (*Provider)(nil)
_ llm.Model = (*model)(nil)
_ llm.Stream = (*stream)(nil)
)
// textResponse is a canned successful chat.completion body: one assistant
// choice with content "hello" and finish_reason "stop", usage of 19 prompt
// and 10 completion tokens, plus extra fields (refusal, annotations,
// *_details, service_tier, system_fingerprint).
const textResponse = `{
"id": "chatcmpl-1", "object": "chat.completion", "created": 1741570283, "model": "gpt-test",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": "hello", "refusal": null, "annotations": []},
"logprobs": null,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 19, "completion_tokens": 10, "total_tokens": 29,
"prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
"completion_tokens_details": {"reasoning_tokens": 0}
},
"service_tier": "default", "system_fingerprint": "fp_x"
}`
// recorded captures the last request a test server received.
type recorded struct {
// body is the request payload decoded as JSON; it is left untouched when
// a request arrives with an empty body.
body map[string]any
// header is a clone of the request headers.
header http.Header
// path is the request URL path.
path string
// hits counts the requests the server handled.
hits int
}
// newServer starts an httptest server that records each incoming request
// (hit count, headers, path, JSON body) into the returned *recorded and
// answers with the given status and body as application/json. The server
// is closed through t.Cleanup.
func newServer(t *testing.T, status int, respBody string) (*httptest.Server, *recorded) {
	t.Helper()
	rec := new(recorded)

	handler := func(w http.ResponseWriter, r *http.Request) {
		rec.hits++
		rec.header = r.Header.Clone()
		rec.path = r.URL.Path

		payload, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read request body: %v", err)
		}
		if len(payload) != 0 {
			if err := json.Unmarshal(payload, &rec.body); err != nil {
				t.Errorf("request body is not JSON: %v\n%s", err, payload)
			}
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		io.WriteString(w, respBody)
	}

	srv := httptest.NewServer(http.HandlerFunc(handler))
	t.Cleanup(srv.Close)
	return srv, rec
}
// testModel builds a "gpt-test" model on a provider pointed at srv with
// the key "test-key". popts are applied after those two defaults, so they
// can override them; mopts are passed through to Provider.Model.
func testModel(t *testing.T, srv *httptest.Server, popts []Option, mopts ...llm.ModelOption) llm.Model {
	t.Helper()

	all := make([]Option, 0, 2+len(popts))
	all = append(all, WithAPIKey("test-key"), WithBaseURL(srv.URL))
	all = append(all, popts...)

	built, err := New(all...).Model("gpt-test", mopts...)
	if err != nil {
		t.Fatalf("Model: %v", err)
	}
	return built
}
// fptr returns a pointer to a fresh float64 holding f, for filling
// optional pointer-typed request fields in tests.
func fptr(f float64) *float64 {
	v := f
	return &v
}
// TestGenerateRequestShape sends one request that sets every mapped field
// and compares the JSON body the server received against the expected
// wire shape, key for key.
func TestGenerateRequestShape(t *testing.T) {
	srv, rec := newServer(t, http.StatusOK, textResponse)
	m := testModel(t, srv, nil)

	history := []llm.Message{
		llm.SystemText("folded system"),
		llm.UserParts(llm.Text("look:"), llm.Image("image/png", []byte{1, 2, 3})),
		{
			Role:  llm.RoleAssistant,
			Parts: []llm.Part{llm.Text("checking")},
			ToolCalls: []llm.ToolCall{
				{ID: "call_1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Boston"}`)},
			},
		},
		llm.ToolResultsMessage(
			llm.ToolResult{ID: "call_1", Name: "get_weather", Content: "72F"},
			llm.ToolResult{ID: "call_2", Name: "get_weather", Content: "boom", IsError: true},
		),
		llm.UserText("thanks"),
	}
	weatherTool := llm.Tool{
		Name:        "get_weather",
		Description: "Get current weather",
		Parameters:  json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
	}
	req := llm.Request{
		System:          "base system",
		Messages:        history,
		Tools:           []llm.Tool{weatherTool},
		ToolChoice:      "auto",
		Temperature:     fptr(0.5),
		TopP:            fptr(0.9),
		MaxTokens:       256,
		StopSequences:   []string{"END"},
		ReasoningEffort: "high",
		Schema:          json.RawMessage(`{"type":"object","properties":{"ok":{"type":"boolean"}}}`),
		SchemaName:      "verdict",
	}
	if _, err := m.Generate(context.Background(), req); err != nil {
		t.Fatalf("Generate: %v", err)
	}

	wantMessages := []any{
		map[string]any{"role": "system", "content": "base system\n\nfolded system"},
		map[string]any{"role": "user", "content": []any{
			map[string]any{"type": "text", "text": "look:"},
			map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,AQID"}},
		}},
		map[string]any{"role": "assistant", "content": "checking", "tool_calls": []any{
			map[string]any{"id": "call_1", "type": "function", "function": map[string]any{
				"name": "get_weather", "arguments": `{"city":"Boston"}`,
			}},
		}},
		map[string]any{"role": "tool", "content": "72F", "tool_call_id": "call_1"},
		map[string]any{"role": "tool", "content": "ERROR: boom", "tool_call_id": "call_2"},
		map[string]any{"role": "user", "content": "thanks"},
	}
	wantTools := []any{
		map[string]any{"type": "function", "function": map[string]any{
			"name":        "get_weather",
			"description": "Get current weather",
			"parameters":  map[string]any{"type": "object", "properties": map[string]any{"city": map[string]any{"type": "string"}}},
		}},
	}
	wantFormat := map[string]any{"type": "json_schema", "json_schema": map[string]any{
		"name":   "verdict",
		"schema": map[string]any{"type": "object", "properties": map[string]any{"ok": map[string]any{"type": "boolean"}}},
	}}
	want := map[string]any{
		"model":                 "gpt-test",
		"messages":              wantMessages,
		"tools":                 wantTools,
		"tool_choice":           "auto",
		"temperature":           0.5,
		"top_p":                 0.9,
		"max_completion_tokens": float64(256),
		"stop":                  []any{"END"},
		"reasoning_effort":      "high",
		"response_format":       wantFormat,
	}

	if reflect.DeepEqual(rec.body, want) {
		return
	}
	got, _ := json.MarshalIndent(rec.body, "", " ")
	exp, _ := json.MarshalIndent(want, "", " ")
	t.Errorf("request body mismatch\ngot:\n%s\nwant:\n%s", got, exp)
}
// TestToolChoiceForms checks how each Request.ToolChoice value appears on
// the wire: empty omits the key, the three keywords pass through as
// strings, and any other value becomes a named-function object.
func TestToolChoiceForms(t *testing.T) {
	cases := []struct {
		choice string
		want   any // nil = key absent
	}{
		{choice: "", want: nil},
		{choice: "auto", want: "auto"},
		{choice: "none", want: "none"},
		{choice: "required", want: "required"},
		{
			choice: "get_weather",
			want:   map[string]any{"type": "function", "function": map[string]any{"name": "get_weather"}},
		},
	}
	for _, tc := range cases {
		t.Run("choice="+tc.choice, func(t *testing.T) {
			srv, rec := newServer(t, http.StatusOK, textResponse)
			m := testModel(t, srv, nil)

			req := llm.Request{
				Messages:   []llm.Message{llm.UserText("hi")},
				Tools:      []llm.Tool{{Name: "get_weather"}},
				ToolChoice: tc.choice,
			}
			if _, err := m.Generate(context.Background(), req); err != nil {
				t.Fatalf("Generate: %v", err)
			}

			got, present := rec.body["tool_choice"]
			switch {
			case tc.want == nil && present:
				t.Errorf("tool_choice present, want omitted: %v", got)
			case tc.want != nil && !reflect.DeepEqual(got, tc.want):
				t.Errorf("tool_choice = %v, want %v", got, tc.want)
			}
		})
	}
}
func TestMaxTokensField(t *testing.T) {
t.Run("default uses max_completion_tokens", func(t *testing.T) {
srv, rec := newServer(t, http.StatusOK, textResponse)
m := testModel(t, srv, nil)
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64}
if _, err := m.Generate(context.Background(), req); err != nil {
t.Fatalf("Generate: %v", err)
}
if got := rec.body["max_completion_tokens"]; got != float64(64) {
t.Errorf("max_completion_tokens = %v, want 64", got)
}
if _, present := rec.body["max_tokens"]; present {
t.Error("max_tokens present, want omitted")
}
})
t.Run("WithLegacyMaxTokens uses max_tokens", func(t *testing.T) {
srv, rec := newServer(t, http.StatusOK, textResponse)
m := testModel(t, srv, []Option{WithLegacyMaxTokens()})
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64}
if _, err := m.Generate(context.Background(), req); err != nil {
t.Fatalf("Generate: %v", err)
}
if got := rec.body["max_tokens"]; got != float64(64) {
t.Errorf("max_tokens = %v, want 64", got)
}
if _, present := rec.body["max_completion_tokens"]; present {
t.Error("max_completion_tokens present, want omitted")
}
})
t.Run("zero omits both", func(t *testing.T) {
srv, rec := newServer(t, http.StatusOK, textResponse)
m := testModel(t, srv, nil)
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
if _, err := m.Generate(context.Background(), req); err != nil {
t.Fatalf("Generate: %v", err)
}
if _, present := rec.body["max_tokens"]; present {
t.Error("max_tokens present, want omitted")
}
if _, present := rec.body["max_completion_tokens"]; present {
t.Error("max_completion_tokens present, want omitted")
}
})
}
// TestSchemaNameDefault checks that a request with a Schema but no
// SchemaName is sent with the json_schema name "response".
func TestSchemaNameDefault(t *testing.T) {
	srv, rec := newServer(t, http.StatusOK, textResponse)
	m := testModel(t, srv, nil)

	_, err := m.Generate(context.Background(), llm.Request{
		Messages: []llm.Message{llm.UserText("hi")},
		Schema:   json.RawMessage(`{"type":"object"}`),
	})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	format, found := rec.body["response_format"].(map[string]any)
	if !found {
		t.Fatalf("response_format missing: %v", rec.body)
	}
	schema, found := format["json_schema"].(map[string]any)
	if !found {
		t.Fatalf("json_schema missing: %v", format)
	}
	if name := schema["name"]; name != "response" {
		t.Errorf("schema name = %v, want %q", name, "response")
	}
}
// TestGenerateTextResponse decodes the canned text reply and checks the
// text, finish reason, usage, qualified model name, absence of tool calls,
// and that Raw is populated.
func TestGenerateTextResponse(t *testing.T) {
	srv, _ := newServer(t, http.StatusOK, textResponse)
	m := testModel(t, srv, nil)

	req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	resp, err := m.Generate(context.Background(), req)
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	if text := resp.Text(); text != "hello" {
		t.Errorf("Text = %q, want %q", text, "hello")
	}
	if resp.FinishReason != llm.FinishStop {
		t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishStop)
	}
	wantUsage := llm.Usage{InputTokens: 19, OutputTokens: 10}
	if resp.Usage != wantUsage {
		t.Errorf("Usage = %+v, want {19 10}", resp.Usage)
	}
	if resp.Model != "openai/gpt-test" {
		t.Errorf("Model = %q, want %q", resp.Model, "openai/gpt-test")
	}
	if len(resp.ToolCalls) != 0 {
		t.Errorf("ToolCalls = %v, want none", resp.ToolCalls)
	}
	if resp.Raw == nil {
		t.Error("Raw is nil, want wire response")
	}
}
// TestGenerateToolCallResponse decodes a reply with two tool calls (the
// second with an empty id) and checks call mapping, the synthesized id,
// that tool-call presence overrides finish_reason "stop", and that null
// content yields no parts.
func TestGenerateToolCallResponse(t *testing.T) {
	const body = `{
"id": "chatcmpl-2", "object": "chat.completion", "created": 1, "model": "gpt-test",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": null, "tool_calls": [
{"id": "call_9", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Boston\"}"}},
{"id": "", "type": "function", "function": {"name": "get_time", "arguments": "{}"}}
]},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}
}`
	srv, _ := newServer(t, http.StatusOK, body)
	m := testModel(t, srv, nil)

	req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	resp, err := m.Generate(context.Background(), req)
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	if n := len(resp.ToolCalls); n != 2 {
		t.Fatalf("ToolCalls = %d, want 2", n)
	}
	first, second := resp.ToolCalls[0], resp.ToolCalls[1]
	firstOK := first.ID == "call_9" &&
		first.Name == "get_weather" &&
		string(first.Arguments) == `{"city":"Boston"}`
	if !firstOK {
		t.Errorf("ToolCalls[0] = %+v", first)
	}
	if second.ID != "call_1" {
		t.Errorf("synthesized ID = %q, want %q", second.ID, "call_1")
	}
	// finish_reason "stop" with tool_calls present: presence wins.
	if resp.FinishReason != llm.FinishToolCalls {
		t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishToolCalls)
	}
	if len(resp.Parts) != 0 {
		t.Errorf("Parts = %v, want none", resp.Parts)
	}
}
// TestFinishReasonMapping feeds each wire finish_reason through a full
// Generate round trip (no tool calls in the reply) and checks the
// canonical reason it maps to; unknown and legacy values map to
// FinishOther.
func TestFinishReasonMapping(t *testing.T) {
	const (
		bodyPrefix = `{"choices":[{"index":0,"message":{"role":"assistant","content":"x"},"finish_reason":"`
		bodySuffix = `"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`
	)
	cases := []struct {
		wire string
		want llm.FinishReason
	}{
		{wire: "stop", want: llm.FinishStop},
		{wire: "length", want: llm.FinishLength},
		{wire: "tool_calls", want: llm.FinishToolCalls},
		{wire: "content_filter", want: llm.FinishContentFilter},
		{wire: "function_call", want: llm.FinishOther},
		{wire: "weird_new_reason", want: llm.FinishOther},
	}
	for _, tc := range cases {
		t.Run(tc.wire, func(t *testing.T) {
			srv, _ := newServer(t, http.StatusOK, bodyPrefix+tc.wire+bodySuffix)
			m := testModel(t, srv, nil)

			req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
			resp, err := m.Generate(context.Background(), req)
			if err != nil {
				t.Fatalf("Generate: %v", err)
			}
			if got := resp.FinishReason; got != tc.want {
				t.Errorf("FinishReason = %v, want %v", got, tc.want)
			}
		})
	}
}
func TestAPIErrorMapping(t *testing.T) {
t.Run("429 rate limit is transient", func(t *testing.T) {
const body = `{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}`
srv, _ := newServer(t, http.StatusTooManyRequests, body)
m := testModel(t, srv, nil)
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
apiErr, ok := errors.AsType[*llm.APIError](err)
if !ok {
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
}
if apiErr.Status != http.StatusTooManyRequests {
t.Errorf("Status = %d, want 429", apiErr.Status)
}
if apiErr.Code != "rate_limit_exceeded" {
t.Errorf("Code = %q, want %q", apiErr.Code, "rate_limit_exceeded")
}
if apiErr.Message != "Rate limit reached" {
t.Errorf("Message = %q", apiErr.Message)
}
if apiErr.Provider != "openai" || apiErr.Model != "gpt-test" {
t.Errorf("Provider/Model = %q/%q", apiErr.Provider, apiErr.Model)
}
if got := llm.Classify(err); got != llm.ClassTransient {
t.Errorf("Classify = %v, want transient", got)
}
})
t.Run("401 code null falls back to type, permanent", func(t *testing.T) {
const body = `{"error":{"message":"Incorrect API key provided","type":"authentication_error","param":null,"code":null}}`
srv, _ := newServer(t, http.StatusUnauthorized, body)
m := testModel(t, srv, nil)
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
apiErr, ok := errors.AsType[*llm.APIError](err)
if !ok {
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
}
if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" {
t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code)
}
if got := llm.Classify(err); got != llm.ClassPermanent {
t.Errorf("Classify = %v, want permanent", got)
}
})
t.Run("non-JSON body becomes message", func(t *testing.T) {
srv, _ := newServer(t, http.StatusServiceUnavailable, "upstream exploded\n")
m := testModel(t, srv, nil)
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
apiErr, ok := errors.AsType[*llm.APIError](err)
if !ok {
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
}
if apiErr.Status != http.StatusServiceUnavailable || apiErr.Message != "upstream exploded" {
t.Errorf("Status/Message = %d/%q", apiErr.Status, apiErr.Message)
}
})
}
// TestMissingAPIKey checks that a provider built with no key fails
// Generate with a synthetic 401 / "missing_api_key" *llm.APIError and
// never contacts the server.
func TestMissingAPIKey(t *testing.T) {
	t.Setenv("OPENAI_API_KEY", "")
	srv, rec := newServer(t, http.StatusOK, textResponse)

	keyless, err := New(WithBaseURL(srv.URL)).Model("gpt-test")
	if err != nil {
		t.Fatalf("Model: %v", err)
	}

	req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	_, err = keyless.Generate(context.Background(), req)
	apiErr, ok := errors.AsType[*llm.APIError](err)
	if !ok {
		t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
	}
	if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "missing_api_key" {
		t.Errorf("Status/Code = %d/%q, want 401/missing_api_key", apiErr.Status, apiErr.Code)
	}
	if rec.hits != 0 {
		t.Errorf("server hit %d times, want 0", rec.hits)
	}
}
// TestEnvAPIKeyReadAtConstruction checks that OPENAI_API_KEY is captured
// when New runs: changing the variable afterwards must not change the key
// sent in the Authorization header.
func TestEnvAPIKeyReadAtConstruction(t *testing.T) {
	srv, rec := newServer(t, http.StatusOK, textResponse)

	t.Setenv("OPENAI_API_KEY", "env-secret")
	provider := New(WithBaseURL(srv.URL))
	t.Setenv("OPENAI_API_KEY", "changed-later") // must not affect provider

	m, err := provider.Model("gpt-test")
	if err != nil {
		t.Fatalf("Model: %v", err)
	}
	req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	if _, err := m.Generate(context.Background(), req); err != nil {
		t.Fatalf("Generate: %v", err)
	}

	const wantAuth = "Bearer env-secret"
	if auth := rec.header.Get("Authorization"); auth != wantAuth {
		t.Errorf("Authorization = %q, want %q", auth, wantAuth)
	}
}
// TestAuthAndContentTypeHeaders checks the bearer Authorization header,
// the JSON Content-Type, and the request path of a plain Generate call.
func TestAuthAndContentTypeHeaders(t *testing.T) {
	srv, rec := newServer(t, http.StatusOK, textResponse)
	m := testModel(t, srv, nil)

	req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	if _, err := m.Generate(context.Background(), req); err != nil {
		t.Fatalf("Generate: %v", err)
	}

	if auth := rec.header.Get("Authorization"); auth != "Bearer test-key" {
		t.Errorf("Authorization = %q, want %q", auth, "Bearer test-key")
	}
	if ctype := rec.header.Get("Content-Type"); ctype != "application/json" {
		t.Errorf("Content-Type = %q, want application/json", ctype)
	}
	if rec.path != "/chat/completions" {
		t.Errorf("path = %q, want /chat/completions", rec.path)
	}
}
// TestCompatEndpointNameAndBaseURL configures the provider as a compat
// endpoint (custom name, base URL with a path prefix and trailing slash)
// and checks the provider name, the request path, the qualified response
// model, and that the model id goes on the wire verbatim.
func TestCompatEndpointNameAndBaseURL(t *testing.T) {
	srv, rec := newServer(t, http.StatusOK, textResponse)

	provider := New(WithName("groq"), WithAPIKey("k"), WithBaseURL(srv.URL+"/openai/v1/"))
	if name := provider.Name(); name != "groq" {
		t.Errorf("Name = %q, want groq", name)
	}

	m, err := provider.Model("llama-3.3-70b")
	if err != nil {
		t.Fatalf("Model: %v", err)
	}
	req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	resp, err := m.Generate(context.Background(), req)
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	if rec.path != "/openai/v1/chat/completions" {
		t.Errorf("path = %q, want /openai/v1/chat/completions (trailing slash trimmed)", rec.path)
	}
	if resp.Model != "groq/llama-3.3-70b" {
		t.Errorf("Model = %q, want groq/llama-3.3-70b", resp.Model)
	}
	if wireModel := rec.body["model"]; wireModel != "llama-3.3-70b" {
		t.Errorf("wire model = %v, want llama-3.3-70b (verbatim)", wireModel)
	}
}
// TestCapabilityEnforcement checks that requests exceeding the model's
// capabilities are refused with llm.ErrUnsupported before any HTTP request
// is made, for image limits on Generate and for streaming on Stream.
func TestCapabilityEnforcement(t *testing.T) {
	// blank builds an image part of the given MIME type with size zero bytes.
	blank := func(mime string, size int) llm.Part {
		return llm.Image(mime, make([]byte, size))
	}

	cases := []struct {
		name string
		caps *llm.Capabilities // nil = provider defaults
		msg  llm.Message
	}{
		{
			name: "images unsupported",
			caps: &llm.Capabilities{SupportsTools: true, SupportsStreaming: true},
			msg:  llm.UserParts(blank("image/png", 4)),
		},
		{
			name: "too many images",
			caps: &llm.Capabilities{MaxImagesPerReq: 1},
			msg:  llm.UserParts(blank("image/png", 4), blank("image/png", 4)),
		},
		{
			name: "disallowed MIME under defaults",
			msg:  llm.UserParts(blank("image/bmp", 4)),
		},
		{
			name: "image too large",
			caps: &llm.Capabilities{MaxImagesPerReq: 4, MaxImageBytes: 2},
			msg:  llm.UserParts(blank("image/png", 3)),
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			srv, rec := newServer(t, http.StatusOK, textResponse)

			var mopts []llm.ModelOption
			if tc.caps != nil {
				mopts = append(mopts, llm.WithCapabilities(*tc.caps))
			}
			m := testModel(t, srv, nil, mopts...)

			_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{tc.msg}})
			if !errors.Is(err, llm.ErrUnsupported) {
				t.Fatalf("err = %v, want ErrUnsupported", err)
			}
			if class := llm.Classify(err); class != llm.ClassPermanent {
				t.Errorf("Classify = %v, want permanent", class)
			}
			if rec.hits != 0 {
				t.Errorf("server hit %d times, want 0 (must refuse before sending)", rec.hits)
			}
		})
	}

	t.Run("streaming unsupported", func(t *testing.T) {
		srv, rec := newServer(t, http.StatusOK, textResponse)
		noStream := llm.WithCapabilities(llm.Capabilities{SupportsTools: true})
		m := testModel(t, srv, nil, noStream)

		_, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
		if !errors.Is(err, llm.ErrUnsupported) {
			t.Fatalf("err = %v, want ErrUnsupported", err)
		}
		if rec.hits != 0 {
			t.Errorf("server hit %d times, want 0", rec.hits)
		}
	})
}
// TestModelCapabilitiesOverride checks that a model built without options
// reports the provider defaults, and that llm.WithCapabilities replaces
// them wholesale.
func TestModelCapabilitiesOverride(t *testing.T) {
	provider := New(WithAPIKey("k"))

	plain, err := provider.Model("a")
	if err != nil {
		t.Fatalf("Model: %v", err)
	}
	defaults := plain.Capabilities()
	defaultsOK := defaults.SupportsTools &&
		defaults.MaxImagesPerReq == 100 &&
		defaults.MaxImageBytes == 20<<20
	if !defaultsOK {
		t.Errorf("default caps = %+v", defaults)
	}

	custom := llm.Capabilities{SupportsStreaming: true, ContextWindow: 8192}
	overridden, err := provider.Model("b", llm.WithCapabilities(custom))
	if err != nil {
		t.Fatalf("Model: %v", err)
	}
	if got := overridden.Capabilities(); !reflect.DeepEqual(got, custom) {
		t.Errorf("override caps = %+v, want %+v", got, custom)
	}
}
// TestTransportErrorIsNotAPIError verifies that a failure to reach the server
// at all is reported as a plain wrapped error with "openai: do request"
// context — not as an *llm.APIError — and is classified as transient.
func TestTransportErrorIsNotAPIError(t *testing.T) {
	// Start a server only to obtain an address, then close it so the request
	// has nothing to connect to.
	dead := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
	deadURL := dead.URL
	dead.Close()

	target, err := New(WithAPIKey("k"), WithBaseURL(deadURL)).Model("gpt-test")
	if err != nil {
		t.Fatalf("Model: %v", err)
	}

	request := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	_, genErr := target.Generate(context.Background(), request)
	if genErr == nil {
		t.Fatal("Generate succeeded against closed server")
	}

	_, isAPIErr := errors.AsType[*llm.APIError](genErr)
	if isAPIErr {
		t.Errorf("transport error wrapped in APIError: %v", genErr)
	}
	if msg := genErr.Error(); !strings.Contains(msg, "openai: do request") {
		t.Errorf("err = %v, want openai: do request context", genErr)
	}
	if class := llm.Classify(genErr); class != llm.ClassTransient {
		t.Errorf("Classify = %v, want transient (net error must stay visible)", class)
	}
}
// TestDecodeErrorWrapped feeds Generate a 200 response whose body is not
// valid JSON and expects an error carrying "openai: decode response" context
// that is not an *llm.APIError.
func TestDecodeErrorWrapped(t *testing.T) {
	server, _ := newServer(t, http.StatusOK, "{not json")
	target := testModel(t, server, nil)

	request := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	_, genErr := target.Generate(context.Background(), request)

	hasContext := genErr != nil && strings.Contains(genErr.Error(), "openai: decode response")
	if !hasContext {
		t.Errorf("err = %v, want decode response context", genErr)
	}
	_, isAPIErr := errors.AsType[*llm.APIError](genErr)
	if isAPIErr {
		t.Errorf("decode error wrapped in APIError: %v", genErr)
	}
}
+183
View File
@@ -0,0 +1,183 @@
package openai
import (
"bufio"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
// stream consumes the data-only SSE stream of chat.completion.chunk events.
//
// Delivery contract: TextDelta events as content fragments arrive; ToolCall
// events only once fully assembled (fragments are buffered internally and
// flushed at stream end — simplest correct handling of interleaved parallel
// calls); exactly one final Response event; then io.EOF.
//
// The fields fall into two groups: those Next drives directly (sc, queue,
// done) and the accumulators that handleChunk fills and finalize turns into
// the final Response (text, calls, byIndex, finish, usage).
type stream struct {
// m is the owning model; its provider name and id label mid-stream API
// errors and the final Response.
m *model
// body is the response body; Close closes it.
body io.ReadCloser
// sc yields the stream line by line for Next (presumably wrapping body —
// the constructor is outside this part of the file; confirm there).
sc *bufio.Scanner
// closeOnce and closeErr make Close idempotent: body is closed once and
// that result is returned on every call.
closeOnce sync.Once
closeErr error
// queue holds events already produced but not yet returned; Next pops
// from the front.
queue []llm.StreamEvent
done bool // finalize ran; drain queue then io.EOF
// text is every content delta concatenated, for the final Response.
text strings.Builder
calls []*toolCallAcc // first-appearance order
// byIndex maps a tool-call delta's index to its accumulator; handleChunk
// creates the map on first use.
byIndex map[int]*toolCallAcc
// finish is the last non-empty finish_reason seen.
finish string
// usage comes from the most recent chunk that carried a usage object.
usage llm.Usage
}
// toolCallAcc accumulates one tool call's fragments. The id and name arrive
// on the first fragment for an index; arguments arrive as string pieces to
// concatenate.
type toolCallAcc struct {
// id and name keep the last non-empty value seen for this call's index.
id string
name string
// args is every arguments fragment appended in arrival order.
args strings.Builder
}
// Next implements llm.Stream. It returns queued events first; when the queue
// is empty it reads SSE lines until one produces an event, the stream ends,
// or an error occurs. Once finalize has run and the queue is drained it
// returns io.EOF on every call.
func (s *stream) Next() (llm.StreamEvent, error) {
	for len(s.queue) == 0 {
		if s.done {
			return llm.StreamEvent{}, io.EOF
		}

		if !s.sc.Scan() {
			if scanErr := s.sc.Err(); scanErr != nil {
				return llm.StreamEvent{}, fmt.Errorf("openai: read stream: %w", scanErr)
			}
			// Why: some compat servers close the body without a [DONE]
			// sentinel; a clean EOF still finalizes with what arrived.
			s.finalize()
			continue
		}

		// Only "data:" fields matter; SSE comments, event:/id: fields and
		// blank separators are skipped.
		rest, isData := strings.CutPrefix(strings.TrimSpace(s.sc.Text()), "data:")
		payload := strings.TrimSpace(rest)
		switch {
		case !isData, payload == "":
			// Nothing to process on this line.
		case payload == "[DONE]":
			s.finalize()
		default:
			if err := s.handleChunk([]byte(payload)); err != nil {
				return llm.StreamEvent{}, err
			}
		}
	}

	head := s.queue[0]
	s.queue = s.queue[1:]
	return head, nil
}
// handleChunk decodes one chat.completion.chunk payload and folds it into
// the stream state, queueing a TextDelta event for any content it carries.
//
// It returns a wrapped decode error for malformed JSON and an *llm.APIError
// when the chunk is a mid-stream error event. Only the first choice of a
// chunk is examined.
func (s *stream) handleChunk(data []byte) error {
	var chunk streamChunk
	if err := json.Unmarshal(data, &chunk); err != nil {
		return fmt.Errorf("openai: decode stream chunk: %w", err)
	}

	if wireErr := chunk.Error; wireErr != nil {
		// Mid-stream error event on an otherwise-200 stream. Status stays 0:
		// there is no failing HTTP status to report.
		code := wireErr.Code
		if code == "" {
			code = wireErr.Type
		}
		return &llm.APIError{
			Provider: s.m.p.name,
			Model:    s.m.id,
			Code:     code,
			Message:  wireErr.Message,
		}
	}

	if u := chunk.Usage; u != nil {
		s.usage = llm.Usage{InputTokens: u.PromptTokens, OutputTokens: u.CompletionTokens}
	}

	// Why the guard: the include_usage chunk arrives with an EMPTY choices
	// array; indexing choices[0] unconditionally would panic on it.
	if len(chunk.Choices) == 0 {
		return nil
	}
	first := chunk.Choices[0]

	if reason := first.FinishReason; reason != "" {
		s.finish = reason
	}
	if fragment := first.Delta.Content; fragment != "" {
		s.text.WriteString(fragment)
		s.queue = append(s.queue, llm.StreamEvent{TextDelta: fragment})
	}

	for _, frag := range first.Delta.ToolCalls {
		if s.byIndex == nil {
			s.byIndex = make(map[int]*toolCallAcc)
		}
		acc := s.byIndex[frag.Index]
		if acc == nil {
			// First fragment for this index: register it in both the lookup
			// map and the ordered list.
			acc = new(toolCallAcc)
			s.byIndex[frag.Index] = acc
			s.calls = append(s.calls, acc)
		}
		if frag.ID != "" {
			acc.id = frag.ID
		}
		if frag.Function.Name != "" {
			acc.name = frag.Function.Name
		}
		acc.args.WriteString(frag.Function.Arguments)
	}
	return nil
}
// finalize assembles the buffered tool calls and the final Response, queues
// them (ToolCall events first, Response last), and marks the stream done.
// It is idempotent: only the first call queues anything.
//
// Tool calls are emitted in first-appearance order. Two normalizations keep
// every emitted call usable:
//   - a call streamed without an id gets a synthesized "call_<n>" id;
//   - a call whose argument fragments add up to nothing (or only whitespace)
//     gets "{}" as its arguments.
func (s *stream) finalize() {
	if s.done {
		return
	}
	s.done = true

	resp := &llm.Response{Model: s.m.p.name + "/" + s.m.id, Usage: s.usage}
	if s.text.Len() > 0 {
		resp.Parts = []llm.Part{llm.TextPart{Text: s.text.String()}}
	}

	for i, acc := range s.calls {
		id := acc.id
		if id == "" {
			// Why: ToolResult.ID must echo ToolCall.ID; synthesize for
			// compat servers that stream calls without ids.
			id = fmt.Sprintf("call_%d", i)
		}
		args := acc.args.String()
		if strings.TrimSpace(args) == "" {
			// Why: an empty json.RawMessage is not a valid JSON document —
			// it fails both json.Unmarshal and json.Marshal. A call with no
			// arguments is the empty object, matching what buildRequest
			// sends back for such a call.
			args = "{}"
		}
		resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{
			ID:        id,
			Name:      acc.name,
			Arguments: json.RawMessage(args),
		})
	}
	resp.FinishReason = mapFinish(s.finish, len(resp.ToolCalls) > 0)

	for i := range resp.ToolCalls {
		tc := resp.ToolCalls[i] // copy so the event doesn't alias the slice
		s.queue = append(s.queue, llm.StreamEvent{ToolCall: &tc})
	}
	s.queue = append(s.queue, llm.StreamEvent{Response: resp})
}
// Close implements llm.Stream. The body is closed at most once; that first
// result is remembered and returned by every call, so Close may be invoked
// at any point and repeatedly. Closing the body unblocks any in-flight read
// and aborts the HTTP stream.
func (s *stream) Close() error {
	closeBody := func() {
		s.closeErr = s.body.Close()
	}
	s.closeOnce.Do(closeBody)
	return s.closeErr
}
+267
View File
@@ -0,0 +1,267 @@
package openai
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
// sseServer starts a test server that answers every request with the given
// payloads, each written as one "data: <payload>" SSE event followed by a
// blank line. Like newServer it records the request: hit count, headers,
// path, and the JSON-decoded body. The server is closed via t.Cleanup.
func sseServer(t *testing.T, payloads ...string) (*httptest.Server, *recorded) {
	t.Helper()
	seen := &recorded{}

	handle := func(w http.ResponseWriter, r *http.Request) {
		seen.hits++
		seen.header = r.Header.Clone()
		seen.path = r.URL.Path

		reqBody, readErr := io.ReadAll(r.Body)
		if readErr != nil {
			t.Errorf("read request body: %v", readErr)
		}
		if len(reqBody) > 0 {
			if decodeErr := json.Unmarshal(reqBody, &seen.body); decodeErr != nil {
				t.Errorf("request body is not JSON: %v\n%s", decodeErr, reqBody)
			}
		}

		w.Header().Set("Content-Type", "text/event-stream")
		for i := range payloads {
			io.WriteString(w, "data: "+payloads[i]+"\n\n")
		}
	}

	server := httptest.NewServer(http.HandlerFunc(handle))
	t.Cleanup(server.Close)
	return server, seen
}
// collect reads events from s until Next reports io.EOF and returns them in
// order. Any other error fails the test immediately.
func collect(t *testing.T, s llm.Stream) []llm.StreamEvent {
	t.Helper()
	var got []llm.StreamEvent
	for {
		ev, err := s.Next()
		switch {
		case err == io.EOF:
			return got
		case err != nil:
			t.Fatalf("Next: %v", err)
		}
		got = append(got, ev)
	}
}
// TestStreamText drives a plain text stream end to end: it checks the
// request shape sent for streaming, the TextDelta events, the assembled
// final Response (text, finish reason, usage, model label), and that Next
// and Close stay well-behaved after the stream is exhausted.
func TestStreamText(t *testing.T) {
	payloads := []string{
		`{"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-test","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}],"obfuscation":"xK9q"}`,
		`{"choices":[{"index":0,"delta":{"content":"lo"},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
		`{"choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
		`[DONE]`,
	}
	server, seen := sseServer(t, payloads...)
	target := testModel(t, server, nil)

	request := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	s, err := target.Stream(context.Background(), request)
	if err != nil {
		t.Fatalf("Stream: %v", err)
	}
	defer s.Close()
	events := collect(t, s)

	// Request shape: stream flag, usage opt-in, SSE accept header.
	if seen.body["stream"] != true {
		t.Errorf("stream = %v, want true", seen.body["stream"])
	}
	streamOpts, _ := seen.body["stream_options"].(map[string]any)
	if streamOpts == nil || streamOpts["include_usage"] != true {
		t.Errorf("stream_options = %v, want include_usage true", seen.body["stream_options"])
	}
	if accept := seen.header.Get("Accept"); accept != "text/event-stream" {
		t.Errorf("Accept = %q, want text/event-stream", accept)
	}

	// Event sequence: two deltas, then the final Response.
	if len(events) != 3 {
		t.Fatalf("got %d events, want 3: %+v", len(events), events)
	}
	firstDelta, secondDelta := events[0].TextDelta, events[1].TextDelta
	if !(firstDelta == "Hel" && secondDelta == "lo") {
		t.Errorf("deltas = %q, %q, want Hel, lo", firstDelta, secondDelta)
	}

	final := events[2].Response
	if final == nil {
		t.Fatal("last event has no Response")
	}
	if text := final.Text(); text != "Hello" {
		t.Errorf("final text = %q, want Hello", text)
	}
	if final.FinishReason != llm.FinishStop {
		t.Errorf("FinishReason = %v, want stop", final.FinishReason)
	}
	wantUsage := llm.Usage{InputTokens: 5, OutputTokens: 2}
	if final.Usage != wantUsage {
		t.Errorf("Usage = %+v, want {5 2}", final.Usage)
	}
	if final.Model != "openai/gpt-test" {
		t.Errorf("Model = %q, want openai/gpt-test", final.Model)
	}

	// Next after EOF keeps returning EOF; Close is idempotent.
	if _, nextErr := s.Next(); nextErr != io.EOF {
		t.Errorf("Next after EOF = %v, want io.EOF", nextErr)
	}
	if closeErr := s.Close(); closeErr != nil {
		t.Errorf("first Close: %v", closeErr)
	}
	if closeErr := s.Close(); closeErr != nil {
		t.Errorf("second Close: %v", closeErr)
	}
}
// TestStreamParallelToolCalls streams two tool calls whose fragments are
// interleaved and keyed by distinct indexes — id and name appear only on each
// call's first fragment, arguments are split across fragments — and expects
// two fully assembled ToolCall events followed by the final Response.
func TestStreamParallelToolCalls(t *testing.T) {
	payloads := []string{
		`{"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_a","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"city\":"}}]},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_b","type":"function","function":{"name":"get_time","arguments":"{\"tz\":"}}]},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Boston\"}"}}]},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"\"EST\"}"}}]},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
		`{"choices":[],"usage":{"prompt_tokens":11,"completion_tokens":9,"total_tokens":20}}`,
		`[DONE]`,
	}
	server, _ := sseServer(t, payloads...)
	target := testModel(t, server, nil)

	request := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	s, err := target.Stream(context.Background(), request)
	if err != nil {
		t.Fatalf("Stream: %v", err)
	}
	defer s.Close()
	events := collect(t, s)

	if len(events) != 3 {
		t.Fatalf("got %d events, want 3 (two tool calls + response): %+v", len(events), events)
	}

	weather, clock := events[0].ToolCall, events[1].ToolCall
	if weather == nil || clock == nil {
		t.Fatalf("events 0/1 are not tool calls: %+v", events)
	}
	weatherOK := weather.ID == "call_a" && weather.Name == "get_weather" && string(weather.Arguments) == `{"city":"Boston"}`
	if !weatherOK {
		t.Errorf("first call = %+v", weather)
	}
	clockOK := clock.ID == "call_b" && clock.Name == "get_time" && string(clock.Arguments) == `{"tz":"EST"}`
	if !clockOK {
		t.Errorf("second call = %+v", clock)
	}

	final := events[2].Response
	if final == nil {
		t.Fatal("last event has no Response")
	}
	if n := len(final.ToolCalls); n != 2 {
		t.Fatalf("final ToolCalls = %d, want 2", n)
	}
	firstID, secondID := final.ToolCalls[0].ID, final.ToolCalls[1].ID
	if !(firstID == "call_a" && secondID == "call_b") {
		t.Errorf("final ToolCalls order = %q, %q", firstID, secondID)
	}
	if final.FinishReason != llm.FinishToolCalls {
		t.Errorf("FinishReason = %v, want tool_calls", final.FinishReason)
	}
	wantUsage := llm.Usage{InputTokens: 11, OutputTokens: 9}
	if final.Usage != wantUsage {
		t.Errorf("Usage = %+v, want {11 9}", final.Usage)
	}
}
// TestStreamMidStreamError sends one content chunk followed by an error
// event on a 200 stream. The delta must be delivered first; the following
// Next must return an *llm.APIError whose Code falls back to the error type,
// whose Message is passed through, and whose Status is 0.
func TestStreamMidStreamError(t *testing.T) {
	server, _ := sseServer(t,
		`{"choices":[{"index":0,"delta":{"content":"par"},"finish_reason":null}]}`,
		`{"error":{"message":"The server had an error while processing your request","type":"server_error","param":null,"code":null}}`,
	)
	target := testModel(t, server, nil)

	request := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	s, err := target.Stream(context.Background(), request)
	if err != nil {
		t.Fatalf("Stream: %v", err)
	}
	defer s.Close()

	first, firstErr := s.Next()
	if !(firstErr == nil && first.TextDelta == "par") {
		t.Fatalf("first event = %+v, %v; want TextDelta par", first, firstErr)
	}

	_, streamErr := s.Next()
	apiErr, isAPIErr := errors.AsType[*llm.APIError](streamErr)
	if !isAPIErr {
		t.Fatalf("err = %v (%T), want *llm.APIError", streamErr, streamErr)
	}
	if apiErr.Code != "server_error" {
		t.Errorf("Code = %q, want server_error", apiErr.Code)
	}
	const wantMessage = "The server had an error while processing your request"
	if apiErr.Message != wantMessage {
		t.Errorf("Message = %q", apiErr.Message)
	}
	if apiErr.Status != 0 {
		t.Errorf("Status = %d, want 0 (the HTTP stream was 200)", apiErr.Status)
	}
}
// TestStreamHTTPError checks that a non-2xx response surfaces from Stream
// itself as an *llm.APIError carrying the HTTP status and the body's error
// code, and that a 429 is classified as transient.
func TestStreamHTTPError(t *testing.T) {
	const body = `{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}`
	server, _ := newServer(t, http.StatusTooManyRequests, body)
	target := testModel(t, server, nil)

	request := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	_, streamErr := target.Stream(context.Background(), request)

	apiErr, isAPIErr := errors.AsType[*llm.APIError](streamErr)
	if !isAPIErr {
		t.Fatalf("err = %v (%T), want *llm.APIError from Stream itself", streamErr, streamErr)
	}
	matches := apiErr.Status == http.StatusTooManyRequests && apiErr.Code == "rate_limit_exceeded"
	if !matches {
		t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code)
	}
	if class := llm.Classify(streamErr); class != llm.ClassTransient {
		t.Errorf("Classify = %v, want transient", class)
	}
}
// TestStreamWithoutDoneSentinel covers servers that end the body without a
// "data: [DONE]" line: a clean end of stream must still yield the final
// Response after the delta.
func TestStreamWithoutDoneSentinel(t *testing.T) {
	// Why: some compat servers close the connection without "data: [DONE]";
	// a clean EOF must still produce the final Response.
	server, _ := sseServer(t,
		`{"choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
	)
	target := testModel(t, server, nil)

	request := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	s, err := target.Stream(context.Background(), request)
	if err != nil {
		t.Fatalf("Stream: %v", err)
	}
	defer s.Close()

	events := collect(t, s)
	if len(events) != 2 {
		t.Fatalf("got %d events, want 2: %+v", len(events), events)
	}

	final := events[1].Response
	finalOK := final != nil && final.Text() == "ok" && final.FinishReason == llm.FinishStop
	if !finalOK {
		t.Errorf("final = %+v", final)
	}
}
// TestStreamCloseEarly reads a single event and then closes the stream
// before it is exhausted; both that Close and a repeated Close must succeed.
func TestStreamCloseEarly(t *testing.T) {
	server, _ := sseServer(t,
		`{"choices":[{"index":0,"delta":{"content":"a"},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{"content":"b"},"finish_reason":null}]}`,
		`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
		`[DONE]`,
	)
	target := testModel(t, server, nil)

	request := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
	s, err := target.Stream(context.Background(), request)
	if err != nil {
		t.Fatalf("Stream: %v", err)
	}

	_, nextErr := s.Next()
	if nextErr != nil {
		t.Fatalf("Next: %v", nextErr)
	}

	if closeErr := s.Close(); closeErr != nil {
		t.Errorf("Close mid-stream: %v", closeErr)
	}
	if closeErr := s.Close(); closeErr != nil {
		t.Errorf("Close again: %v", closeErr)
	}
}
+321
View File
@@ -0,0 +1,321 @@
package openai
import (
"encoding/base64"
"encoding/json"
"strings"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
// --- request wire shapes ---
// chatRequest is the request body that buildRequest produces from an
// llm.Request. Optional fields carry omitempty so unset values are left off
// the wire.
type chatRequest struct {
Model string `json:"model"`
Messages []wireMessage `json:"messages"`
Tools []wireTool `json:"tools,omitempty"`
// ToolChoice is "auto"/"none"/"required" (string) or a named-function
// object; any avoids two fields for one wire key.
ToolChoice any `json:"tool_choice,omitempty"`
// Temperature and TopP are pointers so that nil (omitted) stays distinct
// from an explicit zero.
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
// MaxCompletionTokens and MaxTokens are alternatives for the same limit:
// buildRequest fills MaxTokens when the provider's legacyMaxTokens flag is
// set and MaxCompletionTokens otherwise, never both.
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stop []string `json:"stop,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
// ResponseFormat is set only when the request carries a schema.
ResponseFormat *wireRespFormat `json:"response_format,omitempty"`
// Stream and StreamOptions are set together by buildRequest for streaming
// requests.
Stream bool `json:"stream,omitempty"`
StreamOptions *wireStreamOptions `json:"stream_options,omitempty"`
}
// wireMessage is one entry of chatRequest.Messages. buildRequest emits the
// roles "system", "user", "assistant" and "tool".
type wireMessage struct {
Role string `json:"role"`
// Content is a string for text-only turns, a part array for multimodal
// turns, or nil (wire null) for assistant turns that only call tools.
Content any `json:"content"`
// ToolCalls is filled on assistant messages that call tools.
ToolCalls []wireToolCall `json:"tool_calls,omitempty"`
// ToolCallID is filled on "tool" messages with the id of the result being
// reported.
ToolCallID string `json:"tool_call_id,omitempty"`
}
// wireTextPart is a text element of a multimodal content array; contentValue
// sets Type to "text".
type wireTextPart struct {
Type string `json:"type"`
Text string `json:"text"`
}
// wireImagePart is an image element of a multimodal content array;
// contentValue sets Type to "image_url".
type wireImagePart struct {
Type string `json:"type"`
ImageURL wireImageURL `json:"image_url"`
}
// wireImageURL wraps an image reference. contentValue fills URL with a
// base64 "data:" URL built from the image's MIME type and bytes.
type wireImageURL struct {
URL string `json:"url"`
}
// wireToolCall is one complete tool call. It appears on outgoing assistant
// messages (where buildRequest sets Type to "function") and in
// wireRespMessage.ToolCalls.
type wireToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function wireFunctionCall `json:"function"`
}
// wireFunctionCall names the function being called and carries its
// arguments. It is shared by wireToolCall and streamToolCallDelta.
type wireFunctionCall struct {
Name string `json:"name"`
// Arguments is a JSON-encoded STRING per the wire format, not an object.
Arguments string `json:"arguments"`
}
// wireTool is one entry of chatRequest.Tools; buildRequest sets Type to
// "function".
type wireTool struct {
Type string `json:"type"`
Function wireToolFunction `json:"function"`
}
// wireToolFunction describes a callable tool. buildRequest copies all three
// fields straight from the request's tool definition.
type wireToolFunction struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
// Parameters is passed through as raw JSON, not re-encoded.
Parameters json.RawMessage `json:"parameters,omitempty"`
}
// wireNamedToolChoice is the object form of tool_choice. buildRequest uses
// it, with Type "function", for any ToolChoice value other than "", "auto",
// "none" or "required".
type wireNamedToolChoice struct {
Type string `json:"type"`
Function wireToolName `json:"function"`
}
// wireToolName identifies a tool by name inside wireNamedToolChoice.
type wireToolName struct {
Name string `json:"name"`
}
// wireRespFormat is the response_format object. buildRequest emits it with
// Type "json_schema" and a populated JSONSchema when the request has a schema.
type wireRespFormat struct {
Type string `json:"type"`
JSONSchema *wireJSONSchema `json:"json_schema,omitempty"`
}
// wireJSONSchema omits the strict flag on purpose: strict mode imposes
// schema rewrites (every property required, additionalProperties:false at
// every level) that belong to the caller, not the transport.
type wireJSONSchema struct {
// Name is the request's SchemaName, or "response" when that is empty.
Name string `json:"name"`
// Schema is the caller's schema, passed through as raw JSON.
Schema json.RawMessage `json:"schema"`
}
// wireStreamOptions is the stream_options object. buildRequest always sends
// IncludeUsage true on streaming requests so token counts are reported.
type wireStreamOptions struct {
IncludeUsage bool `json:"include_usage"`
}
// --- response wire shapes (loose: unknown fields ignored) ---
// chatResponse is the top-level response body shape (presumably for the
// non-streaming path; the decoding code is outside this part of the file).
type chatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []chatChoice `json:"choices"`
// Usage is a pointer so an absent usage object decodes to nil.
Usage *wireUsage `json:"usage"`
}
// chatChoice is one entry of chatResponse.Choices: a complete message plus
// the reason generation stopped.
type chatChoice struct {
Index int `json:"index"`
Message wireRespMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
// wireRespMessage is the message inside a chatChoice. Content is a plain
// string here, unlike the request-side wireMessage where it may also be a
// part array.
type wireRespMessage struct {
Role string `json:"role"`
Content string `json:"content"` // null decodes to ""
Refusal string `json:"refusal"` // tolerated, unused
ToolCalls []wireToolCall `json:"tool_calls"`
}
// wireUsage is the token accounting object. The streaming path maps
// PromptTokens to llm.Usage.InputTokens and CompletionTokens to
// llm.Usage.OutputTokens; it does not read TotalTokens.
type wireUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// errorEnvelope is the {"error": {...}} wrapper around a wireError.
type errorEnvelope struct {
Error wireError `json:"error"`
}
// wireError is the error object, used both inside errorEnvelope and as the
// mid-stream error on streamChunk. When Code is empty the streaming path
// reports Type as the error code instead.
type wireError struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code"` // null decodes to ""
}
// --- streaming wire shapes ---
// streamChunk is one decoded SSE data payload. Any of its three fields may
// be absent: the usage chunk has an empty Choices array, and an error event
// carries only Error.
type streamChunk struct {
Choices []streamChoice `json:"choices"`
Usage *wireUsage `json:"usage"`
Error *wireError `json:"error"` // mid-stream error event
}
// streamChoice is one entry of streamChunk.Choices: an incremental delta
// plus, on the closing chunk, the finish reason.
type streamChoice struct {
Index int `json:"index"`
Delta streamDelta `json:"delta"`
FinishReason string `json:"finish_reason"` // null decodes to ""
}
// streamDelta is the incremental part of a streamChoice: a content fragment
// and/or tool-call fragments.
type streamDelta struct {
Content string `json:"content"` // null decodes to ""
ToolCalls []streamToolCallDelta `json:"tool_calls"`
}
// streamToolCallDelta is one tool-call fragment. The id and name appear only
// on a call's first fragment; later fragments carry just index + an
// arguments substring. Accumulation keys on Index, never ID.
type streamToolCallDelta struct {
// Index identifies which call this fragment belongs to.
Index int `json:"index"`
// ID is empty on every fragment after a call's first.
ID string `json:"id"`
// Function carries the name (first fragment) and an arguments substring.
Function wireFunctionCall `json:"function"`
}
// --- mapping: llm.Request -> chatRequest ---
// buildRequest maps a canonical llm.Request onto the wire body. It performs
// no validation of its own: the capability check has already passed by the
// time this runs.
//
// stream selects the streaming variant of the body (the stream flag plus the
// usage opt-in).
func (m *model) buildRequest(req llm.Request, stream bool) *chatRequest {
	wire := &chatRequest{
		Model:           m.id,
		Temperature:     req.Temperature,
		TopP:            req.TopP,
		Stop:            req.StopSequences,
		ReasoningEffort: req.ReasoningEffort,
	}

	// System content may arrive both as Request.System and as RoleSystem
	// messages. Why one message: the canonical contract allows both places,
	// but OpenAI wants one system mechanism — so everything is merged into a
	// single leading system message, System field first.
	var systemTexts []string
	if req.System != "" {
		systemTexts = append(systemTexts, req.System)
	}
	for _, msg := range req.Messages {
		if msg.Role != llm.RoleSystem {
			continue
		}
		if text := msg.Text(); text != "" {
			systemTexts = append(systemTexts, text)
		}
	}
	if len(systemTexts) > 0 {
		wire.Messages = append(wire.Messages, wireMessage{
			Role:    "system",
			Content: strings.Join(systemTexts, "\n\n"),
		})
	}

	// Conversation turns. System messages were merged above and are skipped.
	for _, msg := range req.Messages {
		switch msg.Role {
		case llm.RoleUser:
			user := wireMessage{Role: "user", Content: contentValue(msg.Parts)}
			wire.Messages = append(wire.Messages, user)

		case llm.RoleAssistant:
			// Content stays nil (wire null) when the turn has no text.
			assistant := wireMessage{Role: "assistant"}
			if body := msg.Text(); body != "" {
				assistant.Content = body
			}
			for _, call := range msg.ToolCalls {
				fn := wireFunctionCall{Name: call.Name, Arguments: string(call.Arguments)}
				if fn.Arguments == "" {
					// Why: arguments must be a valid JSON document string;
					// an empty string is not one.
					fn.Arguments = "{}"
				}
				assistant.ToolCalls = append(assistant.ToolCalls, wireToolCall{
					ID:       call.ID,
					Type:     "function",
					Function: fn,
				})
			}
			wire.Messages = append(wire.Messages, assistant)

		case llm.RoleTool:
			// One wire message per result: the API pairs each tool output
			// with its call via tool_call_id, one message each.
			for _, result := range msg.ToolResults {
				text := result.Content
				if result.IsError {
					text = "ERROR: " + text
				}
				wire.Messages = append(wire.Messages, wireMessage{
					Role:       "tool",
					Content:    text,
					ToolCallID: result.ID,
				})
			}
		}
	}

	for _, tool := range req.Tools {
		fn := wireToolFunction{
			Name:        tool.Name,
			Description: tool.Description,
			Parameters:  tool.Parameters,
		}
		wire.Tools = append(wire.Tools, wireTool{Type: "function", Function: fn})
	}

	switch choice := req.ToolChoice; choice {
	case "":
		// Omit: provider default ("auto" when tools are present).
	case "auto", "none", "required":
		wire.ToolChoice = choice
	default:
		// Any other value names the one tool the model must call.
		wire.ToolChoice = wireNamedToolChoice{
			Type:     "function",
			Function: wireToolName{Name: choice},
		}
	}

	// The token limit goes to exactly one of two wire fields, chosen by the
	// provider's legacy flag.
	if limit := req.MaxTokens; limit > 0 {
		field := &wire.MaxCompletionTokens
		if m.p.legacyMaxTokens {
			field = &wire.MaxTokens
		}
		*field = limit
	}

	if len(req.Schema) > 0 {
		schemaName := req.SchemaName
		if schemaName == "" {
			schemaName = "response"
		}
		wire.ResponseFormat = &wireRespFormat{
			Type:       "json_schema",
			JSONSchema: &wireJSONSchema{Name: schemaName, Schema: req.Schema},
		}
	}

	if stream {
		wire.Stream = true
		// Why: without include_usage the stream never reports token counts;
		// the usage arrives in one extra chunk with an empty choices array.
		wire.StreamOptions = &wireStreamOptions{IncludeUsage: true}
	}
	return wire
}
// contentValue renders message parts as the wire content value. If any part
// is an image the result is a part array ([]any of wireTextPart and
// wireImagePart, in input order); otherwise it is a plain string holding the
// text parts concatenated, which is the most widely accepted form. Parts of
// any other type are dropped in both modes.
func contentValue(parts []llm.Part) any {
	hasImage := false
	for i := range parts {
		if _, isImage := parts[i].(llm.ImagePart); isImage {
			hasImage = true
			break
		}
	}

	if hasImage {
		elems := make([]any, 0, len(parts))
		for _, part := range parts {
			switch p := part.(type) {
			case llm.TextPart:
				elems = append(elems, wireTextPart{Type: "text", Text: p.Text})
			case llm.ImagePart:
				// Images travel inline as a base64 data URL.
				encoded := base64.StdEncoding.EncodeToString(p.Data)
				dataURL := "data:" + p.MIME + ";base64," + encoded
				elems = append(elems, wireImagePart{
					Type:     "image_url",
					ImageURL: wireImageURL{URL: dataURL},
				})
			}
		}
		return elems
	}

	var text strings.Builder
	for _, part := range parts {
		if tp, isText := part.(llm.TextPart); isText {
			text.WriteString(tp.Text)
		}
	}
	return text.String()
}