feat: OpenAI, Anthropic, and native-Ollama providers + media pipeline
Phase 3: - provider/openai: Chat Completions for OpenAI + compat endpoints (SSE streaming with by-index tool-call assembly, response_format json_schema, legacy max_tokens option, reasoning_effort) - provider/anthropic: Messages API (tool_use/tool_result, GA structured output via output_config.format, full SSE event parser, 529 transient) - provider/ollama: one native /api/chat client behind the ollama, ollama-cloud, and foreman built-ins (presets; NDJSON streaming tolerant of foreman's buffered single-object responses; object tool arguments; format-schema structured output; think mapping) - media/: capability normalization (sniff, downscale, transcode, byte ladder, ErrUnsupported), wired into the chain executor per target with penalty-free advance past incapable elements - registry: real provider + scheme wiring, WithHTTPClient option, required env-foreman TLS chat round-trip test - ADR-0009 multimodal strategy, ADR-0010 tools/structured mapping; README matrix + CLAUDE.md synced Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,222 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// model is one provider-bound target.
|
||||
type model struct {
|
||||
p *Provider
|
||||
id string
|
||||
caps llm.Capabilities
|
||||
}
|
||||
|
||||
// Capabilities implements llm.Model.
|
||||
func (m *model) Capabilities() llm.Capabilities { return m.caps }
|
||||
|
||||
// Generate implements llm.Model.
|
||||
func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) {
|
||||
req = req.Apply(opts...)
|
||||
if err := checkRequest(m.caps, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpResp, err := m.do(ctx, req, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
if httpResp.StatusCode/100 != 2 {
|
||||
return nil, m.apiError(httpResp)
|
||||
}
|
||||
var wire chatResponse
|
||||
if err := json.NewDecoder(httpResp.Body).Decode(&wire); err != nil {
|
||||
return nil, fmt.Errorf("openai: decode response: %w", err)
|
||||
}
|
||||
return m.toResponse(&wire), nil
|
||||
}
|
||||
|
||||
// Stream implements llm.Model.
|
||||
func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) {
|
||||
req = req.Apply(opts...)
|
||||
if !m.caps.SupportsStreaming {
|
||||
return nil, fmt.Errorf("%w: streaming not supported by %s/%s", llm.ErrUnsupported, m.p.name, m.id)
|
||||
}
|
||||
if err := checkRequest(m.caps, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpResp, err := m.do(ctx, req, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if httpResp.StatusCode/100 != 2 {
|
||||
defer httpResp.Body.Close()
|
||||
return nil, m.apiError(httpResp)
|
||||
}
|
||||
sc := bufio.NewScanner(httpResp.Body)
|
||||
// Why: a single SSE data line carries a whole JSON chunk; tool-call
|
||||
// argument fragments can make lines far larger than Scanner's 64 KiB
|
||||
// default cap.
|
||||
sc.Buffer(make([]byte, 0, 64*1024), 16<<20)
|
||||
return &stream{m: m, body: httpResp.Body, sc: sc}, nil
|
||||
}
|
||||
|
||||
// do builds and performs the HTTP request. Transport failures are wrapped
|
||||
// raw (never as *llm.APIError) so llm.Classify still sees net.Error,
|
||||
// syscall errnos, and context errors underneath.
|
||||
func (m *model) do(ctx context.Context, req llm.Request, stream bool) (*http.Response, error) {
|
||||
if m.p.apiKey == "" {
|
||||
// Why a synthetic 401: the constructor never fails, so a missing
|
||||
// key must surface at request time as the auth failure it is —
|
||||
// permanent under llm.Classify, like a real 401.
|
||||
return nil, &llm.APIError{
|
||||
Provider: m.p.name,
|
||||
Model: m.id,
|
||||
Status: http.StatusUnauthorized,
|
||||
Code: "missing_api_key",
|
||||
Message: "no API key configured: set OPENAI_API_KEY or use WithAPIKey",
|
||||
}
|
||||
}
|
||||
body, err := json.Marshal(m.buildRequest(req, stream))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai: encode request: %w", err)
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, m.p.baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+m.p.apiKey)
|
||||
if stream {
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
httpResp, err := m.p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai: do request: %w", err)
|
||||
}
|
||||
return httpResp, nil
|
||||
}
|
||||
|
||||
// apiError converts a non-2xx response into *llm.APIError, pulling code and
|
||||
// message from the {"error":{...}} body when it parses.
|
||||
func (m *model) apiError(httpResp *http.Response) error {
|
||||
apiErr := &llm.APIError{Provider: m.p.name, Model: m.id, Status: httpResp.StatusCode}
|
||||
body, _ := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20))
|
||||
var env errorEnvelope
|
||||
if err := json.Unmarshal(body, &env); err == nil &&
|
||||
(env.Error.Message != "" || env.Error.Type != "" || env.Error.Code != "") {
|
||||
apiErr.Message = env.Error.Message
|
||||
apiErr.Code = env.Error.Code
|
||||
if apiErr.Code == "" {
|
||||
apiErr.Code = env.Error.Type
|
||||
}
|
||||
} else {
|
||||
// Why: compat servers emit all sorts of error bodies; a raw snippet
|
||||
// beats silence when the canonical envelope is absent.
|
||||
apiErr.Message = strings.TrimSpace(string(body))
|
||||
}
|
||||
return apiErr
|
||||
}
|
||||
|
||||
// toResponse maps the wire response onto the canonical llm.Response.
|
||||
func (m *model) toResponse(wire *chatResponse) *llm.Response {
|
||||
resp := &llm.Response{Model: m.p.name + "/" + m.id, Raw: wire}
|
||||
if wire.Usage != nil {
|
||||
resp.Usage = llm.Usage{
|
||||
InputTokens: wire.Usage.PromptTokens,
|
||||
OutputTokens: wire.Usage.CompletionTokens,
|
||||
}
|
||||
}
|
||||
if len(wire.Choices) == 0 {
|
||||
resp.FinishReason = llm.FinishOther
|
||||
return resp
|
||||
}
|
||||
choice := wire.Choices[0]
|
||||
if choice.Message.Content != "" {
|
||||
resp.Parts = append(resp.Parts, llm.TextPart{Text: choice.Message.Content})
|
||||
}
|
||||
for i, tc := range choice.Message.ToolCalls {
|
||||
id := tc.ID
|
||||
if id == "" {
|
||||
// Why: ToolResult.ID must echo ToolCall.ID, so calls from compat
|
||||
// servers that omit ids get synthesized ones.
|
||||
id = fmt.Sprintf("call_%d", i)
|
||||
}
|
||||
resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{
|
||||
ID: id,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: json.RawMessage(tc.Function.Arguments),
|
||||
})
|
||||
}
|
||||
resp.FinishReason = mapFinish(choice.FinishReason, len(resp.ToolCalls) > 0)
|
||||
return resp
|
||||
}
|
||||
|
||||
// mapFinish maps a wire finish_reason to the canonical enum. Tool-call
|
||||
// presence wins over the reported reason: a forced (named tool_choice) call
|
||||
// can finish with "stop" while still carrying tool_calls.
|
||||
func mapFinish(reason string, hasToolCalls bool) llm.FinishReason {
|
||||
if hasToolCalls {
|
||||
return llm.FinishToolCalls
|
||||
}
|
||||
switch reason {
|
||||
case "stop":
|
||||
return llm.FinishStop
|
||||
case "length":
|
||||
return llm.FinishLength
|
||||
case "tool_calls":
|
||||
return llm.FinishToolCalls
|
||||
case "content_filter":
|
||||
return llm.FinishContentFilter
|
||||
default:
|
||||
return llm.FinishOther
|
||||
}
|
||||
}
|
||||
|
||||
// checkRequest enforces the model's effective capabilities. Why enforcement
|
||||
// rather than normalization: a separate media layer resizes/transcodes
|
||||
// images BEFORE requests reach the provider; this check is the honest
|
||||
// backstop that refuses, with llm.ErrUnsupported, what the target
|
||||
// declaredly cannot serve (chains advance past it penalty-free).
|
||||
func checkRequest(caps llm.Capabilities, req llm.Request) error {
|
||||
if len(req.Tools) > 0 && !caps.SupportsTools {
|
||||
return fmt.Errorf("%w: tools not supported", llm.ErrUnsupported)
|
||||
}
|
||||
if len(req.Schema) > 0 && !caps.SupportsStructured {
|
||||
return fmt.Errorf("%w: structured output not supported", llm.ErrUnsupported)
|
||||
}
|
||||
images := 0
|
||||
for _, msg := range req.Messages {
|
||||
for _, part := range msg.Parts {
|
||||
img, ok := part.(llm.ImagePart)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
images++
|
||||
if !caps.SupportsImages() {
|
||||
return fmt.Errorf("%w: image input not supported", llm.ErrUnsupported)
|
||||
}
|
||||
if !caps.MIMEAllowed(img.MIME) {
|
||||
return fmt.Errorf("%w: image MIME type %q not allowed (allowed: %s)",
|
||||
llm.ErrUnsupported, img.MIME, strings.Join(caps.AllowedImageMIME, ", "))
|
||||
}
|
||||
if caps.MaxImageBytes > 0 && len(img.Data) > caps.MaxImageBytes {
|
||||
return fmt.Errorf("%w: image is %d bytes, limit is %d",
|
||||
llm.ErrUnsupported, len(img.Data), caps.MaxImageBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
if images > caps.MaxImagesPerReq {
|
||||
return fmt.Errorf("%w: request carries %d images, limit is %d",
|
||||
llm.ErrUnsupported, images, caps.MaxImagesPerReq)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
// Package openai implements llm.Provider for the OpenAI Chat Completions
|
||||
// API and, via WithBaseURL/WithName, any OpenAI-compatible endpoint
|
||||
// (vLLM, Groq, Together, LM Studio, Ollama's /v1 shim, ...).
|
||||
//
|
||||
// Targeted API surface (verified against developers.openai.com, June 2026):
|
||||
// POST {base}/chat/completions with
|
||||
// - messages: plain-string content for text-only turns, part arrays with
|
||||
// base64 data-URL image_url entries for multimodal turns, assistant
|
||||
// tool_calls history, and {"role":"tool","tool_call_id",...} results;
|
||||
// - tools as {"type":"function","function":{...}} with tool_choice
|
||||
// "auto"/"none"/"required" or a named-function object;
|
||||
// - response_format {"type":"json_schema",...} structured output;
|
||||
// - max_completion_tokens (or legacy max_tokens via WithLegacyMaxTokens
|
||||
// for compat servers), temperature, top_p, stop, reasoning_effort;
|
||||
// - data-only SSE streaming with stream_options.include_usage, the
|
||||
// "data: [DONE]" sentinel, and tool-call deltas accumulated by index.
|
||||
//
|
||||
// Newer response fields (refusal, annotations, usage *_details, delta
|
||||
// obfuscation) are tolerated and ignored so both api.openai.com and older
|
||||
// compat servers decode cleanly.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// defaultBaseURL is the endpoint root New uses unless WithBaseURL overrides it.
const defaultBaseURL = "https://api.openai.com/v1"
|
||||
|
||||
// Provider is an llm.Provider backed by an OpenAI Chat Completions endpoint.
|
||||
type Provider struct {
|
||||
name string
|
||||
apiKey string
|
||||
baseURL string
|
||||
client *http.Client
|
||||
caps llm.Capabilities
|
||||
legacyMaxTokens bool
|
||||
}
|
||||
|
||||
// Option configures the provider at construction.
|
||||
type Option func(*Provider)
|
||||
|
||||
// WithAPIKey sets the API key. When absent, New reads OPENAI_API_KEY from
|
||||
// the environment at construction time.
|
||||
func WithAPIKey(key string) Option {
|
||||
return func(p *Provider) { p.apiKey = key }
|
||||
}
|
||||
|
||||
// WithBaseURL points the client at a different endpoint (compat servers).
|
||||
// The path "/chat/completions" is appended; a trailing slash is trimmed.
|
||||
func WithBaseURL(u string) Option {
|
||||
return func(p *Provider) { p.baseURL = u }
|
||||
}
|
||||
|
||||
// WithHTTPClient substitutes the HTTP client (timeouts, proxies, tests).
|
||||
func WithHTTPClient(c *http.Client) Option {
|
||||
return func(p *Provider) {
|
||||
if c != nil {
|
||||
p.client = c
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithName overrides the registry name ("openai" by default). Why: the same
|
||||
// client serves many OpenAI-compatible endpoints, and each needs a distinct
|
||||
// name in "provider/model" specs and error reporting.
|
||||
func WithName(name string) Option {
|
||||
return func(p *Provider) { p.name = name }
|
||||
}
|
||||
|
||||
// WithDefaultCapabilities replaces the provider-default capabilities.
|
||||
// Per-model overrides via llm.WithCapabilities still take precedence.
|
||||
func WithDefaultCapabilities(caps llm.Capabilities) Option {
|
||||
return func(p *Provider) { p.caps = caps }
|
||||
}
|
||||
|
||||
// WithLegacyMaxTokens sends Request.MaxTokens as "max_tokens" instead of
|
||||
// "max_completion_tokens". Why: OpenAI deprecated max_tokens, but many
|
||||
// third-party compat servers still only honor the legacy field.
|
||||
func WithLegacyMaxTokens() Option {
|
||||
return func(p *Provider) { p.legacyMaxTokens = true }
|
||||
}
|
||||
|
||||
// defaultCapabilities reflects OpenAI's current vision-capable chat models.
|
||||
// Why these limits: the published per-request caps (1500 images, 512 MB)
|
||||
// are far beyond what compat servers accept; 100 images / 20 MB each is a
|
||||
// conservative envelope, and the MIME list is the documented set (PNG,
|
||||
// JPEG, WEBP, non-animated GIF).
|
||||
func defaultCapabilities() llm.Capabilities {
|
||||
return llm.Capabilities{
|
||||
SupportsTools: true,
|
||||
SupportsStructured: true,
|
||||
SupportsStreaming: true,
|
||||
MaxImagesPerReq: 100,
|
||||
MaxImageBytes: 20 << 20,
|
||||
AllowedImageMIME: []string{"image/jpeg", "image/png", "image/webp", "image/gif"},
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a Provider. It never fails: a missing API key surfaces as a
|
||||
// 401-style *llm.APIError at request time, not at construction.
|
||||
func New(opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
name: "openai",
|
||||
apiKey: os.Getenv("OPENAI_API_KEY"),
|
||||
baseURL: defaultBaseURL,
|
||||
client: http.DefaultClient,
|
||||
caps: defaultCapabilities(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(p)
|
||||
}
|
||||
p.baseURL = strings.TrimRight(p.baseURL, "/")
|
||||
return p
|
||||
}
|
||||
|
||||
// Name implements llm.Provider.
|
||||
func (p *Provider) Name() string { return p.name }
|
||||
|
||||
// Model implements llm.Provider. The id is passed through verbatim — no
|
||||
// catalog validation; unknown models fail at request time with the
|
||||
// backend's own error.
|
||||
func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) {
|
||||
cfg := llm.ApplyModelOptions(opts)
|
||||
caps := p.caps
|
||||
if cfg.Capabilities != nil {
|
||||
caps = *cfg.Capabilities
|
||||
}
|
||||
return &model{p: p, id: id, caps: caps}, nil
|
||||
}
|
||||
@@ -0,0 +1,614 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
var (
|
||||
_ llm.Provider = (*Provider)(nil)
|
||||
_ llm.Model = (*model)(nil)
|
||||
_ llm.Stream = (*stream)(nil)
|
||||
)
|
||||
|
||||
// textResponse is a canned non-streaming chat completion: one assistant
// choice saying "hello" with finish_reason "stop", usage of 19 prompt and
// 10 completion tokens, and several fields the client ignores.
const textResponse = `{
	"id": "chatcmpl-1", "object": "chat.completion", "created": 1741570283, "model": "gpt-test",
	"choices": [{
		"index": 0,
		"message": {"role": "assistant", "content": "hello", "refusal": null, "annotations": []},
		"logprobs": null,
		"finish_reason": "stop"
	}],
	"usage": {
		"prompt_tokens": 19, "completion_tokens": 10, "total_tokens": 29,
		"prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
		"completion_tokens_details": {"reasoning_tokens": 0}
	},
	"service_tier": "default", "system_fingerprint": "fp_x"
}`
|
||||
|
||||
// recorded holds what the test server saw on its most recent request,
// plus a running count of requests.
type recorded struct {
	body   map[string]any // decoded JSON request body, when one was sent
	header http.Header    // clone of the request headers
	path   string         // request URL path
	hits   int            // number of requests handled so far
}
|
||||
|
||||
// newServer starts a test server that records the request and replies with
|
||||
// a fixed status and body.
|
||||
func newServer(t *testing.T, status int, respBody string) (*httptest.Server, *recorded) {
|
||||
t.Helper()
|
||||
rec := &recorded{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rec.hits++
|
||||
rec.header = r.Header.Clone()
|
||||
rec.path = r.URL.Path
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read request body: %v", err)
|
||||
}
|
||||
if len(raw) > 0 {
|
||||
if err := json.Unmarshal(raw, &rec.body); err != nil {
|
||||
t.Errorf("request body is not JSON: %v\n%s", err, raw)
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
io.WriteString(w, respBody)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
return srv, rec
|
||||
}
|
||||
|
||||
func testModel(t *testing.T, srv *httptest.Server, popts []Option, mopts ...llm.ModelOption) llm.Model {
|
||||
t.Helper()
|
||||
opts := append([]Option{WithAPIKey("test-key"), WithBaseURL(srv.URL)}, popts...)
|
||||
m, err := New(opts...).Model("gpt-test", mopts...)
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// fptr returns a pointer to a fresh copy of f, for optional float fields.
func fptr(f float64) *float64 {
	v := f
	return &v
}
|
||||
|
||||
func TestGenerateRequestShape(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
|
||||
req := llm.Request{
|
||||
System: "base system",
|
||||
Messages: []llm.Message{
|
||||
llm.SystemText("folded system"),
|
||||
llm.UserParts(llm.Text("look:"), llm.Image("image/png", []byte{1, 2, 3})),
|
||||
{
|
||||
Role: llm.RoleAssistant,
|
||||
Parts: []llm.Part{llm.Text("checking")},
|
||||
ToolCalls: []llm.ToolCall{
|
||||
{ID: "call_1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Boston"}`)},
|
||||
},
|
||||
},
|
||||
llm.ToolResultsMessage(
|
||||
llm.ToolResult{ID: "call_1", Name: "get_weather", Content: "72F"},
|
||||
llm.ToolResult{ID: "call_2", Name: "get_weather", Content: "boom", IsError: true},
|
||||
),
|
||||
llm.UserText("thanks"),
|
||||
},
|
||||
Tools: []llm.Tool{{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
|
||||
}},
|
||||
ToolChoice: "auto",
|
||||
Temperature: fptr(0.5),
|
||||
TopP: fptr(0.9),
|
||||
MaxTokens: 256,
|
||||
StopSequences: []string{"END"},
|
||||
ReasoningEffort: "high",
|
||||
Schema: json.RawMessage(`{"type":"object","properties":{"ok":{"type":"boolean"}}}`),
|
||||
SchemaName: "verdict",
|
||||
}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
|
||||
want := map[string]any{
|
||||
"model": "gpt-test",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "system", "content": "base system\n\nfolded system"},
|
||||
map[string]any{"role": "user", "content": []any{
|
||||
map[string]any{"type": "text", "text": "look:"},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,AQID"}},
|
||||
}},
|
||||
map[string]any{"role": "assistant", "content": "checking", "tool_calls": []any{
|
||||
map[string]any{"id": "call_1", "type": "function", "function": map[string]any{
|
||||
"name": "get_weather", "arguments": `{"city":"Boston"}`,
|
||||
}},
|
||||
}},
|
||||
map[string]any{"role": "tool", "content": "72F", "tool_call_id": "call_1"},
|
||||
map[string]any{"role": "tool", "content": "ERROR: boom", "tool_call_id": "call_2"},
|
||||
map[string]any{"role": "user", "content": "thanks"},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{"type": "function", "function": map[string]any{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather",
|
||||
"parameters": map[string]any{"type": "object", "properties": map[string]any{"city": map[string]any{"type": "string"}}},
|
||||
}},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.9,
|
||||
"max_completion_tokens": float64(256),
|
||||
"stop": []any{"END"},
|
||||
"reasoning_effort": "high",
|
||||
"response_format": map[string]any{"type": "json_schema", "json_schema": map[string]any{
|
||||
"name": "verdict",
|
||||
"schema": map[string]any{"type": "object", "properties": map[string]any{"ok": map[string]any{"type": "boolean"}}},
|
||||
}},
|
||||
}
|
||||
if !reflect.DeepEqual(rec.body, want) {
|
||||
got, _ := json.MarshalIndent(rec.body, "", " ")
|
||||
exp, _ := json.MarshalIndent(want, "", " ")
|
||||
t.Errorf("request body mismatch\ngot:\n%s\nwant:\n%s", got, exp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolChoiceForms(t *testing.T) {
|
||||
tests := []struct {
|
||||
choice string
|
||||
want any // nil = key absent
|
||||
}{
|
||||
{"", nil},
|
||||
{"auto", "auto"},
|
||||
{"none", "none"},
|
||||
{"required", "required"},
|
||||
{"get_weather", map[string]any{"type": "function", "function": map[string]any{"name": "get_weather"}}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run("choice="+tt.choice, func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
req := llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
Tools: []llm.Tool{{Name: "get_weather"}},
|
||||
ToolChoice: tt.choice,
|
||||
}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
got, present := rec.body["tool_choice"]
|
||||
if tt.want == nil {
|
||||
if present {
|
||||
t.Errorf("tool_choice present, want omitted: %v", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("tool_choice = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxTokensField(t *testing.T) {
|
||||
t.Run("default uses max_completion_tokens", func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := rec.body["max_completion_tokens"]; got != float64(64) {
|
||||
t.Errorf("max_completion_tokens = %v, want 64", got)
|
||||
}
|
||||
if _, present := rec.body["max_tokens"]; present {
|
||||
t.Error("max_tokens present, want omitted")
|
||||
}
|
||||
})
|
||||
t.Run("WithLegacyMaxTokens uses max_tokens", func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, []Option{WithLegacyMaxTokens()})
|
||||
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := rec.body["max_tokens"]; got != float64(64) {
|
||||
t.Errorf("max_tokens = %v, want 64", got)
|
||||
}
|
||||
if _, present := rec.body["max_completion_tokens"]; present {
|
||||
t.Error("max_completion_tokens present, want omitted")
|
||||
}
|
||||
})
|
||||
t.Run("zero omits both", func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if _, present := rec.body["max_tokens"]; present {
|
||||
t.Error("max_tokens present, want omitted")
|
||||
}
|
||||
if _, present := rec.body["max_completion_tokens"]; present {
|
||||
t.Error("max_completion_tokens present, want omitted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSchemaNameDefault(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
req := llm.Request{
|
||||
Messages: []llm.Message{llm.UserText("hi")},
|
||||
Schema: json.RawMessage(`{"type":"object"}`),
|
||||
}
|
||||
if _, err := m.Generate(context.Background(), req); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
rf, ok := rec.body["response_format"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("response_format missing: %v", rec.body)
|
||||
}
|
||||
js, ok := rf["json_schema"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("json_schema missing: %v", rf)
|
||||
}
|
||||
if js["name"] != "response" {
|
||||
t.Errorf("schema name = %v, want %q", js["name"], "response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTextResponse(t *testing.T) {
|
||||
srv, _ := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := resp.Text(); got != "hello" {
|
||||
t.Errorf("Text = %q, want %q", got, "hello")
|
||||
}
|
||||
if resp.FinishReason != llm.FinishStop {
|
||||
t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishStop)
|
||||
}
|
||||
if resp.Usage != (llm.Usage{InputTokens: 19, OutputTokens: 10}) {
|
||||
t.Errorf("Usage = %+v, want {19 10}", resp.Usage)
|
||||
}
|
||||
if resp.Model != "openai/gpt-test" {
|
||||
t.Errorf("Model = %q, want %q", resp.Model, "openai/gpt-test")
|
||||
}
|
||||
if len(resp.ToolCalls) != 0 {
|
||||
t.Errorf("ToolCalls = %v, want none", resp.ToolCalls)
|
||||
}
|
||||
if resp.Raw == nil {
|
||||
t.Error("Raw is nil, want wire response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateToolCallResponse(t *testing.T) {
|
||||
const body = `{
|
||||
"id": "chatcmpl-2", "object": "chat.completion", "created": 1, "model": "gpt-test",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": null, "tool_calls": [
|
||||
{"id": "call_9", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Boston\"}"}},
|
||||
{"id": "", "type": "function", "function": {"name": "get_time", "arguments": "{}"}}
|
||||
]},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}
|
||||
}`
|
||||
srv, _ := newServer(t, http.StatusOK, body)
|
||||
m := testModel(t, srv, nil)
|
||||
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 2 {
|
||||
t.Fatalf("ToolCalls = %d, want 2", len(resp.ToolCalls))
|
||||
}
|
||||
tc := resp.ToolCalls[0]
|
||||
if tc.ID != "call_9" || tc.Name != "get_weather" || string(tc.Arguments) != `{"city":"Boston"}` {
|
||||
t.Errorf("ToolCalls[0] = %+v", tc)
|
||||
}
|
||||
if resp.ToolCalls[1].ID != "call_1" {
|
||||
t.Errorf("synthesized ID = %q, want %q", resp.ToolCalls[1].ID, "call_1")
|
||||
}
|
||||
// finish_reason "stop" with tool_calls present: presence wins.
|
||||
if resp.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishToolCalls)
|
||||
}
|
||||
if len(resp.Parts) != 0 {
|
||||
t.Errorf("Parts = %v, want none", resp.Parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinishReasonMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
wire string
|
||||
want llm.FinishReason
|
||||
}{
|
||||
{"stop", llm.FinishStop},
|
||||
{"length", llm.FinishLength},
|
||||
{"tool_calls", llm.FinishToolCalls},
|
||||
{"content_filter", llm.FinishContentFilter},
|
||||
{"function_call", llm.FinishOther},
|
||||
{"weird_new_reason", llm.FinishOther},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.wire, func(t *testing.T) {
|
||||
body := `{"choices":[{"index":0,"message":{"role":"assistant","content":"x"},"finish_reason":"` + tt.wire + `"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`
|
||||
srv, _ := newServer(t, http.StatusOK, body)
|
||||
m := testModel(t, srv, nil)
|
||||
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if resp.FinishReason != tt.want {
|
||||
t.Errorf("FinishReason = %v, want %v", resp.FinishReason, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIErrorMapping(t *testing.T) {
|
||||
t.Run("429 rate limit is transient", func(t *testing.T) {
|
||||
const body = `{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}`
|
||||
srv, _ := newServer(t, http.StatusTooManyRequests, body)
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusTooManyRequests {
|
||||
t.Errorf("Status = %d, want 429", apiErr.Status)
|
||||
}
|
||||
if apiErr.Code != "rate_limit_exceeded" {
|
||||
t.Errorf("Code = %q, want %q", apiErr.Code, "rate_limit_exceeded")
|
||||
}
|
||||
if apiErr.Message != "Rate limit reached" {
|
||||
t.Errorf("Message = %q", apiErr.Message)
|
||||
}
|
||||
if apiErr.Provider != "openai" || apiErr.Model != "gpt-test" {
|
||||
t.Errorf("Provider/Model = %q/%q", apiErr.Provider, apiErr.Model)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassTransient {
|
||||
t.Errorf("Classify = %v, want transient", got)
|
||||
}
|
||||
})
|
||||
t.Run("401 code null falls back to type, permanent", func(t *testing.T) {
|
||||
const body = `{"error":{"message":"Incorrect API key provided","type":"authentication_error","param":null,"code":null}}`
|
||||
srv, _ := newServer(t, http.StatusUnauthorized, body)
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" {
|
||||
t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassPermanent {
|
||||
t.Errorf("Classify = %v, want permanent", got)
|
||||
}
|
||||
})
|
||||
t.Run("non-JSON body becomes message", func(t *testing.T) {
|
||||
srv, _ := newServer(t, http.StatusServiceUnavailable, "upstream exploded\n")
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusServiceUnavailable || apiErr.Message != "upstream exploded" {
|
||||
t.Errorf("Status/Message = %d/%q", apiErr.Status, apiErr.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMissingAPIKey(t *testing.T) {
|
||||
t.Setenv("OPENAI_API_KEY", "")
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m, err := New(WithBaseURL(srv.URL)).Model("gpt-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
_, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "missing_api_key" {
|
||||
t.Errorf("Status/Code = %d/%q, want 401/missing_api_key", apiErr.Status, apiErr.Code)
|
||||
}
|
||||
if rec.hits != 0 {
|
||||
t.Errorf("server hit %d times, want 0", rec.hits)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvAPIKeyReadAtConstruction(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
t.Setenv("OPENAI_API_KEY", "env-secret")
|
||||
p := New(WithBaseURL(srv.URL))
|
||||
t.Setenv("OPENAI_API_KEY", "changed-later") // must not affect p
|
||||
m, err := p.Model("gpt-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := rec.header.Get("Authorization"); got != "Bearer env-secret" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer env-secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAndContentTypeHeaders(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil)
|
||||
if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if got := rec.header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer test-key")
|
||||
}
|
||||
if got := rec.header.Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", got)
|
||||
}
|
||||
if rec.path != "/chat/completions" {
|
||||
t.Errorf("path = %q, want /chat/completions", rec.path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompatEndpointNameAndBaseURL(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
p := New(WithName("groq"), WithAPIKey("k"), WithBaseURL(srv.URL+"/openai/v1/"))
|
||||
if p.Name() != "groq" {
|
||||
t.Errorf("Name = %q, want groq", p.Name())
|
||||
}
|
||||
m, err := p.Model("llama-3.3-70b")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Generate: %v", err)
|
||||
}
|
||||
if rec.path != "/openai/v1/chat/completions" {
|
||||
t.Errorf("path = %q, want /openai/v1/chat/completions (trailing slash trimmed)", rec.path)
|
||||
}
|
||||
if resp.Model != "groq/llama-3.3-70b" {
|
||||
t.Errorf("Model = %q, want groq/llama-3.3-70b", resp.Model)
|
||||
}
|
||||
if rec.body["model"] != "llama-3.3-70b" {
|
||||
t.Errorf("wire model = %v, want llama-3.3-70b (verbatim)", rec.body["model"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapabilityEnforcement(t *testing.T) {
|
||||
img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) }
|
||||
tests := []struct {
|
||||
name string
|
||||
caps *llm.Capabilities // nil = provider defaults
|
||||
msg llm.Message
|
||||
}{
|
||||
{
|
||||
name: "images unsupported",
|
||||
caps: &llm.Capabilities{SupportsTools: true, SupportsStreaming: true},
|
||||
msg: llm.UserParts(img("image/png", 4)),
|
||||
},
|
||||
{
|
||||
name: "too many images",
|
||||
caps: &llm.Capabilities{MaxImagesPerReq: 1},
|
||||
msg: llm.UserParts(img("image/png", 4), img("image/png", 4)),
|
||||
},
|
||||
{
|
||||
name: "disallowed MIME under defaults",
|
||||
msg: llm.UserParts(img("image/bmp", 4)),
|
||||
},
|
||||
{
|
||||
name: "image too large",
|
||||
caps: &llm.Capabilities{MaxImagesPerReq: 4, MaxImageBytes: 2},
|
||||
msg: llm.UserParts(img("image/png", 3)),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
var mopts []llm.ModelOption
|
||||
if tt.caps != nil {
|
||||
mopts = append(mopts, llm.WithCapabilities(*tt.caps))
|
||||
}
|
||||
m := testModel(t, srv, nil, mopts...)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{tt.msg}})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassPermanent {
|
||||
t.Errorf("Classify = %v, want permanent", got)
|
||||
}
|
||||
if rec.hits != 0 {
|
||||
t.Errorf("server hit %d times, want 0 (must refuse before sending)", rec.hits)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("streaming unsupported", func(t *testing.T) {
|
||||
srv, rec := newServer(t, http.StatusOK, textResponse)
|
||||
m := testModel(t, srv, nil, llm.WithCapabilities(llm.Capabilities{SupportsTools: true}))
|
||||
_, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if !errors.Is(err, llm.ErrUnsupported) {
|
||||
t.Fatalf("err = %v, want ErrUnsupported", err)
|
||||
}
|
||||
if rec.hits != 0 {
|
||||
t.Errorf("server hit %d times, want 0", rec.hits)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelCapabilitiesOverride(t *testing.T) {
|
||||
p := New(WithAPIKey("k"))
|
||||
def, err := p.Model("a")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
if caps := def.Capabilities(); !caps.SupportsTools || caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 20<<20 {
|
||||
t.Errorf("default caps = %+v", caps)
|
||||
}
|
||||
custom := llm.Capabilities{SupportsStreaming: true, ContextWindow: 8192}
|
||||
ovr, err := p.Model("b", llm.WithCapabilities(custom))
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
if got := ovr.Capabilities(); !reflect.DeepEqual(got, custom) {
|
||||
t.Errorf("override caps = %+v, want %+v", got, custom)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportErrorIsNotAPIError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
url := srv.URL
|
||||
srv.Close() // guarantee connection refused
|
||||
p := New(WithAPIKey("k"), WithBaseURL(url))
|
||||
m, err := p.Model("gpt-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Model: %v", err)
|
||||
}
|
||||
_, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil {
|
||||
t.Fatal("Generate succeeded against closed server")
|
||||
}
|
||||
if _, ok := errors.AsType[*llm.APIError](err); ok {
|
||||
t.Errorf("transport error wrapped in APIError: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "openai: do request") {
|
||||
t.Errorf("err = %v, want openai: do request context", err)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassTransient {
|
||||
t.Errorf("Classify = %v, want transient (net error must stay visible)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeErrorWrapped(t *testing.T) {
|
||||
srv, _ := newServer(t, http.StatusOK, "{not json")
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err == nil || !strings.Contains(err.Error(), "openai: decode response") {
|
||||
t.Errorf("err = %v, want decode response context", err)
|
||||
}
|
||||
if _, ok := errors.AsType[*llm.APIError](err); ok {
|
||||
t.Errorf("decode error wrapped in APIError: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// stream consumes the data-only SSE stream of chat.completion.chunk events.
|
||||
//
|
||||
// Delivery contract: TextDelta events as content fragments arrive; ToolCall
|
||||
// events only once fully assembled (fragments are buffered internally and
|
||||
// flushed at stream end — simplest correct handling of interleaved parallel
|
||||
// calls); exactly one final Response event; then io.EOF.
|
||||
type stream struct {
|
||||
m *model
|
||||
body io.ReadCloser
|
||||
sc *bufio.Scanner
|
||||
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
|
||||
queue []llm.StreamEvent
|
||||
done bool // finalize ran; drain queue then io.EOF
|
||||
|
||||
text strings.Builder
|
||||
calls []*toolCallAcc // first-appearance order
|
||||
byIndex map[int]*toolCallAcc
|
||||
finish string
|
||||
usage llm.Usage
|
||||
}
|
||||
|
||||
// toolCallAcc collects the pieces of a single streamed tool call. The first
// fragment seen for an index carries the id and name; the arguments arrive
// as string pieces that are appended to args in arrival order.
type toolCallAcc struct {
	id, name string
	args     strings.Builder
}
|
||||
|
||||
// Next implements llm.Stream.
|
||||
func (s *stream) Next() (llm.StreamEvent, error) {
|
||||
for {
|
||||
if len(s.queue) > 0 {
|
||||
ev := s.queue[0]
|
||||
s.queue = s.queue[1:]
|
||||
return ev, nil
|
||||
}
|
||||
if s.done {
|
||||
return llm.StreamEvent{}, io.EOF
|
||||
}
|
||||
if !s.sc.Scan() {
|
||||
if err := s.sc.Err(); err != nil {
|
||||
return llm.StreamEvent{}, fmt.Errorf("openai: read stream: %w", err)
|
||||
}
|
||||
// Why: some compat servers close the body without a [DONE]
|
||||
// sentinel; a clean EOF still finalizes with what arrived.
|
||||
s.finalize()
|
||||
continue
|
||||
}
|
||||
line := strings.TrimSpace(s.sc.Text())
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue // SSE comments, event:/id: fields, blank separators
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "" {
|
||||
continue
|
||||
}
|
||||
if payload == "[DONE]" {
|
||||
s.finalize()
|
||||
continue
|
||||
}
|
||||
if err := s.handleChunk([]byte(payload)); err != nil {
|
||||
return llm.StreamEvent{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleChunk folds one chat.completion.chunk into the stream state,
|
||||
// queueing any events it produces.
|
||||
func (s *stream) handleChunk(data []byte) error {
|
||||
var chunk streamChunk
|
||||
if err := json.Unmarshal(data, &chunk); err != nil {
|
||||
return fmt.Errorf("openai: decode stream chunk: %w", err)
|
||||
}
|
||||
if chunk.Error != nil {
|
||||
// Mid-stream error event on an otherwise-200 stream. Status stays 0:
|
||||
// there is no failing HTTP status to report.
|
||||
apiErr := &llm.APIError{
|
||||
Provider: s.m.p.name,
|
||||
Model: s.m.id,
|
||||
Code: chunk.Error.Code,
|
||||
Message: chunk.Error.Message,
|
||||
}
|
||||
if apiErr.Code == "" {
|
||||
apiErr.Code = chunk.Error.Type
|
||||
}
|
||||
return apiErr
|
||||
}
|
||||
if chunk.Usage != nil {
|
||||
s.usage = llm.Usage{
|
||||
InputTokens: chunk.Usage.PromptTokens,
|
||||
OutputTokens: chunk.Usage.CompletionTokens,
|
||||
}
|
||||
}
|
||||
// Why the guard: the include_usage chunk arrives with an EMPTY choices
|
||||
// array; indexing choices[0] unconditionally would panic on it.
|
||||
if len(chunk.Choices) == 0 {
|
||||
return nil
|
||||
}
|
||||
choice := chunk.Choices[0]
|
||||
if choice.FinishReason != "" {
|
||||
s.finish = choice.FinishReason
|
||||
}
|
||||
if choice.Delta.Content != "" {
|
||||
s.text.WriteString(choice.Delta.Content)
|
||||
s.queue = append(s.queue, llm.StreamEvent{TextDelta: choice.Delta.Content})
|
||||
}
|
||||
for _, tc := range choice.Delta.ToolCalls {
|
||||
acc := s.byIndex[tc.Index]
|
||||
if acc == nil {
|
||||
if s.byIndex == nil {
|
||||
s.byIndex = make(map[int]*toolCallAcc)
|
||||
}
|
||||
acc = &toolCallAcc{}
|
||||
s.byIndex[tc.Index] = acc
|
||||
s.calls = append(s.calls, acc)
|
||||
}
|
||||
if tc.ID != "" {
|
||||
acc.id = tc.ID
|
||||
}
|
||||
if tc.Function.Name != "" {
|
||||
acc.name = tc.Function.Name
|
||||
}
|
||||
acc.args.WriteString(tc.Function.Arguments)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// finalize assembles the buffered tool calls and the final Response, queues
|
||||
// them (ToolCall events first, Response last), and marks the stream done.
|
||||
func (s *stream) finalize() {
|
||||
if s.done {
|
||||
return
|
||||
}
|
||||
s.done = true
|
||||
resp := &llm.Response{Model: s.m.p.name + "/" + s.m.id, Usage: s.usage}
|
||||
if s.text.Len() > 0 {
|
||||
resp.Parts = []llm.Part{llm.TextPart{Text: s.text.String()}}
|
||||
}
|
||||
for i, acc := range s.calls {
|
||||
id := acc.id
|
||||
if id == "" {
|
||||
// Why: ToolResult.ID must echo ToolCall.ID; synthesize for
|
||||
// compat servers that stream calls without ids.
|
||||
id = fmt.Sprintf("call_%d", i)
|
||||
}
|
||||
resp.ToolCalls = append(resp.ToolCalls, llm.ToolCall{
|
||||
ID: id,
|
||||
Name: acc.name,
|
||||
Arguments: json.RawMessage(acc.args.String()),
|
||||
})
|
||||
}
|
||||
resp.FinishReason = mapFinish(s.finish, len(resp.ToolCalls) > 0)
|
||||
for i := range resp.ToolCalls {
|
||||
tc := resp.ToolCalls[i] // copy so the event doesn't alias the slice
|
||||
s.queue = append(s.queue, llm.StreamEvent{ToolCall: &tc})
|
||||
}
|
||||
s.queue = append(s.queue, llm.StreamEvent{Response: resp})
|
||||
}
|
||||
|
||||
// Close implements llm.Stream. Closing the body unblocks any in-flight read
|
||||
// and aborts the HTTP stream; safe to call at any time, including twice.
|
||||
func (s *stream) Close() error {
|
||||
s.closeOnce.Do(func() { s.closeErr = s.body.Close() })
|
||||
return s.closeErr
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// sseServer streams each payload as one "data: <payload>" SSE event and
|
||||
// records the request like newServer.
|
||||
func sseServer(t *testing.T, payloads ...string) (*httptest.Server, *recorded) {
|
||||
t.Helper()
|
||||
rec := &recorded{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rec.hits++
|
||||
rec.header = r.Header.Clone()
|
||||
rec.path = r.URL.Path
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("read request body: %v", err)
|
||||
}
|
||||
if len(raw) > 0 {
|
||||
if err := json.Unmarshal(raw, &rec.body); err != nil {
|
||||
t.Errorf("request body is not JSON: %v\n%s", err, raw)
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
for _, p := range payloads {
|
||||
io.WriteString(w, "data: "+p+"\n\n")
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
return srv, rec
|
||||
}
|
||||
|
||||
// collect drains a stream to io.EOF, failing the test on any other error.
|
||||
func collect(t *testing.T, s llm.Stream) []llm.StreamEvent {
|
||||
t.Helper()
|
||||
var events []llm.StreamEvent
|
||||
for {
|
||||
ev, err := s.Next()
|
||||
if err == io.EOF {
|
||||
return events
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Next: %v", err)
|
||||
}
|
||||
events = append(events, ev)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamText(t *testing.T) {
|
||||
srv, rec := sseServer(t,
|
||||
`{"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-test","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}],"obfuscation":"xK9q"}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"lo"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
`{"choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
|
||||
`[DONE]`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
events := collect(t, s)
|
||||
|
||||
// Request shape: stream flag, usage opt-in, SSE accept header.
|
||||
if rec.body["stream"] != true {
|
||||
t.Errorf("stream = %v, want true", rec.body["stream"])
|
||||
}
|
||||
so, _ := rec.body["stream_options"].(map[string]any)
|
||||
if so == nil || so["include_usage"] != true {
|
||||
t.Errorf("stream_options = %v, want include_usage true", rec.body["stream_options"])
|
||||
}
|
||||
if got := rec.header.Get("Accept"); got != "text/event-stream" {
|
||||
t.Errorf("Accept = %q, want text/event-stream", got)
|
||||
}
|
||||
|
||||
if len(events) != 3 {
|
||||
t.Fatalf("got %d events, want 3: %+v", len(events), events)
|
||||
}
|
||||
if events[0].TextDelta != "Hel" || events[1].TextDelta != "lo" {
|
||||
t.Errorf("deltas = %q, %q, want Hel, lo", events[0].TextDelta, events[1].TextDelta)
|
||||
}
|
||||
final := events[2].Response
|
||||
if final == nil {
|
||||
t.Fatal("last event has no Response")
|
||||
}
|
||||
if got := final.Text(); got != "Hello" {
|
||||
t.Errorf("final text = %q, want Hello", got)
|
||||
}
|
||||
if final.FinishReason != llm.FinishStop {
|
||||
t.Errorf("FinishReason = %v, want stop", final.FinishReason)
|
||||
}
|
||||
if final.Usage != (llm.Usage{InputTokens: 5, OutputTokens: 2}) {
|
||||
t.Errorf("Usage = %+v, want {5 2}", final.Usage)
|
||||
}
|
||||
if final.Model != "openai/gpt-test" {
|
||||
t.Errorf("Model = %q, want openai/gpt-test", final.Model)
|
||||
}
|
||||
|
||||
// Next after EOF keeps returning EOF; Close is idempotent.
|
||||
if _, err := s.Next(); err != io.EOF {
|
||||
t.Errorf("Next after EOF = %v, want io.EOF", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("first Close: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("second Close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamParallelToolCalls(t *testing.T) {
|
||||
// Two interleaved calls with distinct indexes; id/name only on the first
|
||||
// fragment of each; arguments split across fragments.
|
||||
srv, _ := sseServer(t,
|
||||
`{"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_a","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"city\":"}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_b","type":"function","function":{"name":"get_time","arguments":"{\"tz\":"}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"Boston\"}"}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"\"EST\"}"}}]},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
`{"choices":[],"usage":{"prompt_tokens":11,"completion_tokens":9,"total_tokens":20}}`,
|
||||
`[DONE]`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
events := collect(t, s)
|
||||
|
||||
if len(events) != 3 {
|
||||
t.Fatalf("got %d events, want 3 (two tool calls + response): %+v", len(events), events)
|
||||
}
|
||||
a, b := events[0].ToolCall, events[1].ToolCall
|
||||
if a == nil || b == nil {
|
||||
t.Fatalf("events 0/1 are not tool calls: %+v", events)
|
||||
}
|
||||
if a.ID != "call_a" || a.Name != "get_weather" || string(a.Arguments) != `{"city":"Boston"}` {
|
||||
t.Errorf("first call = %+v", a)
|
||||
}
|
||||
if b.ID != "call_b" || b.Name != "get_time" || string(b.Arguments) != `{"tz":"EST"}` {
|
||||
t.Errorf("second call = %+v", b)
|
||||
}
|
||||
final := events[2].Response
|
||||
if final == nil {
|
||||
t.Fatal("last event has no Response")
|
||||
}
|
||||
if len(final.ToolCalls) != 2 {
|
||||
t.Fatalf("final ToolCalls = %d, want 2", len(final.ToolCalls))
|
||||
}
|
||||
if final.ToolCalls[0].ID != "call_a" || final.ToolCalls[1].ID != "call_b" {
|
||||
t.Errorf("final ToolCalls order = %q, %q", final.ToolCalls[0].ID, final.ToolCalls[1].ID)
|
||||
}
|
||||
if final.FinishReason != llm.FinishToolCalls {
|
||||
t.Errorf("FinishReason = %v, want tool_calls", final.FinishReason)
|
||||
}
|
||||
if final.Usage != (llm.Usage{InputTokens: 11, OutputTokens: 9}) {
|
||||
t.Errorf("Usage = %+v, want {11 9}", final.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamMidStreamError(t *testing.T) {
|
||||
srv, _ := sseServer(t,
|
||||
`{"choices":[{"index":0,"delta":{"content":"par"},"finish_reason":null}]}`,
|
||||
`{"error":{"message":"The server had an error while processing your request","type":"server_error","param":null,"code":null}}`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
ev, err := s.Next()
|
||||
if err != nil || ev.TextDelta != "par" {
|
||||
t.Fatalf("first event = %+v, %v; want TextDelta par", ev, err)
|
||||
}
|
||||
_, err = s.Next()
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
||||
}
|
||||
if apiErr.Code != "server_error" {
|
||||
t.Errorf("Code = %q, want server_error", apiErr.Code)
|
||||
}
|
||||
if apiErr.Message != "The server had an error while processing your request" {
|
||||
t.Errorf("Message = %q", apiErr.Message)
|
||||
}
|
||||
if apiErr.Status != 0 {
|
||||
t.Errorf("Status = %d, want 0 (the HTTP stream was 200)", apiErr.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamHTTPError(t *testing.T) {
|
||||
srv, _ := newServer(t, http.StatusTooManyRequests,
|
||||
`{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}`)
|
||||
m := testModel(t, srv, nil)
|
||||
_, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
apiErr, ok := errors.AsType[*llm.APIError](err)
|
||||
if !ok {
|
||||
t.Fatalf("err = %v (%T), want *llm.APIError from Stream itself", err, err)
|
||||
}
|
||||
if apiErr.Status != http.StatusTooManyRequests || apiErr.Code != "rate_limit_exceeded" {
|
||||
t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code)
|
||||
}
|
||||
if got := llm.Classify(err); got != llm.ClassTransient {
|
||||
t.Errorf("Classify = %v, want transient", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamWithoutDoneSentinel(t *testing.T) {
|
||||
// Why: some compat servers close the connection without "data: [DONE]";
|
||||
// a clean EOF must still produce the final Response.
|
||||
srv, _ := sseServer(t,
|
||||
`{"choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
events := collect(t, s)
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("got %d events, want 2: %+v", len(events), events)
|
||||
}
|
||||
final := events[1].Response
|
||||
if final == nil || final.Text() != "ok" || final.FinishReason != llm.FinishStop {
|
||||
t.Errorf("final = %+v", final)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamCloseEarly(t *testing.T) {
|
||||
srv, _ := sseServer(t,
|
||||
`{"choices":[{"index":0,"delta":{"content":"a"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{"content":"b"},"finish_reason":null}]}`,
|
||||
`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
`[DONE]`,
|
||||
)
|
||||
m := testModel(t, srv, nil)
|
||||
s, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
||||
if err != nil {
|
||||
t.Fatalf("Stream: %v", err)
|
||||
}
|
||||
if _, err := s.Next(); err != nil {
|
||||
t.Fatalf("Next: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("Close mid-stream: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Errorf("Close again: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,321 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// --- request wire shapes ---
|
||||
|
||||
// Request-side wire shapes. Field order and JSON tags determine the bytes
// that are sent, so both are kept exactly as the API expects them.
type (
	// chatRequest is the wire form of one chat request.
	chatRequest struct {
		Model    string        `json:"model"`
		Messages []wireMessage `json:"messages"`
		Tools    []wireTool    `json:"tools,omitempty"`
		// ToolChoice holds either one of the strings "auto"/"none"/"required"
		// or a named-function object; typing it as any lets one field serve
		// the single wire key.
		ToolChoice  any      `json:"tool_choice,omitempty"`
		Temperature *float64 `json:"temperature,omitempty"`
		TopP        *float64 `json:"top_p,omitempty"`
		// MaxCompletionTokens and MaxTokens are two spellings of the output
		// limit; a zero value leaves the key off the wire.
		MaxCompletionTokens int                `json:"max_completion_tokens,omitempty"`
		MaxTokens           int                `json:"max_tokens,omitempty"`
		Stop                []string           `json:"stop,omitempty"`
		ReasoningEffort     string             `json:"reasoning_effort,omitempty"`
		ResponseFormat      *wireRespFormat    `json:"response_format,omitempty"`
		Stream              bool               `json:"stream,omitempty"`
		StreamOptions       *wireStreamOptions `json:"stream_options,omitempty"`
	}

	// wireMessage is one entry of the messages array.
	wireMessage struct {
		Role string `json:"role"`
		// Content is a plain string for text-only turns, an array of parts
		// for multimodal turns, or nil (encoded as null) for assistant turns
		// that consist only of tool calls.
		Content    any            `json:"content"`
		ToolCalls  []wireToolCall `json:"tool_calls,omitempty"`
		ToolCallID string         `json:"tool_call_id,omitempty"`
	}

	// wireTextPart is a text element of a multimodal content array.
	wireTextPart struct {
		Type string `json:"type"`
		Text string `json:"text"`
	}

	// wireImagePart is an image element of a multimodal content array.
	wireImagePart struct {
		Type     string       `json:"type"`
		ImageURL wireImageURL `json:"image_url"`
	}

	// wireImageURL wraps the URL string of an image part.
	wireImageURL struct {
		URL string `json:"url"`
	}

	// wireToolCall is one tool call attached to an assistant message.
	wireToolCall struct {
		ID       string           `json:"id"`
		Type     string           `json:"type"`
		Function wireFunctionCall `json:"function"`
	}

	// wireFunctionCall names the function of a tool call and carries its
	// arguments.
	wireFunctionCall struct {
		Name string `json:"name"`
		// Arguments is a JSON document encoded as a STRING, as the wire
		// format requires; it is not a nested object.
		Arguments string `json:"arguments"`
	}

	// wireTool declares one tool offered to the model.
	wireTool struct {
		Type     string           `json:"type"`
		Function wireToolFunction `json:"function"`
	}

	// wireToolFunction describes an offered tool: name, optional
	// description, and optional parameters passed through as raw JSON.
	wireToolFunction struct {
		Name        string          `json:"name"`
		Description string          `json:"description,omitempty"`
		Parameters  json.RawMessage `json:"parameters,omitempty"`
	}

	// wireNamedToolChoice is the object form of tool_choice that names one
	// specific function.
	wireNamedToolChoice struct {
		Type     string       `json:"type"`
		Function wireToolName `json:"function"`
	}

	// wireToolName carries just the function name of a named tool choice.
	wireToolName struct {
		Name string `json:"name"`
	}

	// wireRespFormat is the response_format object; JSONSchema is left off
	// the wire when nil.
	wireRespFormat struct {
		Type       string          `json:"type"`
		JSONSchema *wireJSONSchema `json:"json_schema,omitempty"`
	}

	// wireJSONSchema deliberately has no strict flag: strict mode demands
	// schema rewrites (every property required, additionalProperties:false
	// at every level) that are the caller's business, not the transport's.
	wireJSONSchema struct {
		Name   string          `json:"name"`
		Schema json.RawMessage `json:"schema"`
	}

	// wireStreamOptions is the stream_options object.
	wireStreamOptions struct {
		IncludeUsage bool `json:"include_usage"`
	}
)
|
||||
|
||||
// --- response wire shapes (loose: unknown fields ignored) ---
|
||||
|
||||
type chatResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []chatChoice `json:"choices"`
|
||||
Usage *wireUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type chatChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message wireRespMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type wireRespMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"` // null decodes to ""
|
||||
Refusal string `json:"refusal"` // tolerated, unused
|
||||
ToolCalls []wireToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
type wireUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type errorEnvelope struct {
|
||||
Error wireError `json:"error"`
|
||||
}
|
||||
|
||||
// wireError is the error object found in an error response body and in a
// mid-stream error event.
type wireError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	// Code is "" when the wire value is null or absent. A non-string code
	// (some compatible endpoints send a number such as 429) is kept as its
	// JSON text, e.g. "429".
	Code string `json:"code"`
}

// UnmarshalJSON decodes an error object while tolerating a "code" that is not
// a JSON string.
//
// Why: with a plain string field, a numeric code makes encoding/json reject
// the whole object, so the message and type are lost along with the code.
// Decoding the code as raw JSON first keeps every other field usable.
//
// A JSON null input is a no-op, following the encoding/json convention for
// Unmarshaler implementations.
func (e *wireError) UnmarshalJSON(data []byte) error {
	if strings.TrimSpace(string(data)) == "null" {
		return nil
	}
	var raw struct {
		Message string          `json:"message"`
		Type    string          `json:"type"`
		Code    json.RawMessage `json:"code"`
	}
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	e.Message, e.Type, e.Code = raw.Message, raw.Type, ""

	code := strings.TrimSpace(string(raw.Code))
	switch {
	case code == "" || code == "null":
		// Absent or null: Code stays empty, as before.
	case code[0] == '"':
		// Ordinary string code: decode it so escapes are resolved.
		return json.Unmarshal(raw.Code, &e.Code)
	default:
		// Number, boolean, or other non-string value: keep its JSON text.
		e.Code = code
	}
	return nil
}
|
||||
|
||||
// --- streaming wire shapes ---
|
||||
|
||||
type streamChunk struct {
|
||||
Choices []streamChoice `json:"choices"`
|
||||
Usage *wireUsage `json:"usage"`
|
||||
Error *wireError `json:"error"` // mid-stream error event
|
||||
}
|
||||
|
||||
type streamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta streamDelta `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"` // null decodes to ""
|
||||
}
|
||||
|
||||
type streamDelta struct {
|
||||
Content string `json:"content"` // null decodes to ""
|
||||
ToolCalls []streamToolCallDelta `json:"tool_calls"`
|
||||
}
|
||||
|
||||
// streamToolCallDelta is one tool-call fragment. The id and name appear only
|
||||
// on a call's first fragment; later fragments carry just index + an
|
||||
// arguments substring. Accumulation keys on Index, never ID.
|
||||
type streamToolCallDelta struct {
|
||||
Index int `json:"index"`
|
||||
ID string `json:"id"`
|
||||
Function wireFunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// --- mapping: llm.Request -> chatRequest ---
|
||||
|
||||
// buildRequest translates the canonical request to the wire shape. The
|
||||
// capability check has already passed by the time this runs.
|
||||
func (m *model) buildRequest(req llm.Request, stream bool) *chatRequest {
|
||||
out := &chatRequest{
|
||||
Model: m.id,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stop: req.StopSequences,
|
||||
ReasoningEffort: req.ReasoningEffort,
|
||||
}
|
||||
|
||||
// Fold Request.System and every RoleSystem message into one leading
|
||||
// system message, System field first. Why: the canonical contract allows
|
||||
// system content in both places; OpenAI wants one system mechanism.
|
||||
var sys []string
|
||||
if req.System != "" {
|
||||
sys = append(sys, req.System)
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
if msg.Role == llm.RoleSystem {
|
||||
if t := msg.Text(); t != "" {
|
||||
sys = append(sys, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
if joined := strings.Join(sys, "\n\n"); joined != "" {
|
||||
out.Messages = append(out.Messages, wireMessage{Role: "system", Content: joined})
|
||||
}
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
switch msg.Role {
|
||||
case llm.RoleSystem:
|
||||
// Folded above; excluded from the normal message list.
|
||||
case llm.RoleUser:
|
||||
out.Messages = append(out.Messages, wireMessage{Role: "user", Content: contentValue(msg.Parts)})
|
||||
case llm.RoleAssistant:
|
||||
wm := wireMessage{Role: "assistant"}
|
||||
if text := msg.Text(); text != "" {
|
||||
wm.Content = text
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
args := string(tc.Arguments)
|
||||
if args == "" {
|
||||
// Why: arguments must be a valid JSON document string;
|
||||
// an empty string is not one.
|
||||
args = "{}"
|
||||
}
|
||||
wm.ToolCalls = append(wm.ToolCalls, wireToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Function: wireFunctionCall{Name: tc.Name, Arguments: args},
|
||||
})
|
||||
}
|
||||
out.Messages = append(out.Messages, wm)
|
||||
case llm.RoleTool:
|
||||
// One wire message per result: the API pairs each tool output
|
||||
// with its call via tool_call_id, one message each.
|
||||
for _, tr := range msg.ToolResults {
|
||||
content := tr.Content
|
||||
if tr.IsError {
|
||||
content = "ERROR: " + content
|
||||
}
|
||||
out.Messages = append(out.Messages, wireMessage{
|
||||
Role: "tool",
|
||||
Content: content,
|
||||
ToolCallID: tr.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range req.Tools {
|
||||
out.Tools = append(out.Tools, wireTool{
|
||||
Type: "function",
|
||||
Function: wireToolFunction{Name: t.Name, Description: t.Description, Parameters: t.Parameters},
|
||||
})
|
||||
}
|
||||
|
||||
switch req.ToolChoice {
|
||||
case "":
|
||||
// Omit: provider default ("auto" when tools are present).
|
||||
case "auto", "none", "required":
|
||||
out.ToolChoice = req.ToolChoice
|
||||
default:
|
||||
// Any other value names the one tool the model must call.
|
||||
out.ToolChoice = wireNamedToolChoice{Type: "function", Function: wireToolName{Name: req.ToolChoice}}
|
||||
}
|
||||
|
||||
if req.MaxTokens > 0 {
|
||||
if m.p.legacyMaxTokens {
|
||||
out.MaxTokens = req.MaxTokens
|
||||
} else {
|
||||
out.MaxCompletionTokens = req.MaxTokens
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.Schema) > 0 {
|
||||
name := req.SchemaName
|
||||
if name == "" {
|
||||
name = "response"
|
||||
}
|
||||
out.ResponseFormat = &wireRespFormat{
|
||||
Type: "json_schema",
|
||||
JSONSchema: &wireJSONSchema{Name: name, Schema: req.Schema},
|
||||
}
|
||||
}
|
||||
|
||||
if stream {
|
||||
out.Stream = true
|
||||
// Why: without include_usage the stream never reports token counts;
|
||||
// the usage arrives in one extra chunk with an empty choices array.
|
||||
out.StreamOptions = &wireStreamOptions{IncludeUsage: true}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// contentValue renders message parts as the wire content value: a plain
|
||||
// string when text-only (maximum compat), a part array when images are
|
||||
// present.
|
||||
func contentValue(parts []llm.Part) any {
|
||||
multimodal := false
|
||||
for _, p := range parts {
|
||||
if _, ok := p.(llm.ImagePart); ok {
|
||||
multimodal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !multimodal {
|
||||
var b strings.Builder
|
||||
for _, p := range parts {
|
||||
if t, ok := p.(llm.TextPart); ok {
|
||||
b.WriteString(t.Text)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
out := make([]any, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
switch v := p.(type) {
|
||||
case llm.TextPart:
|
||||
out = append(out, wireTextPart{Type: "text", Text: v.Text})
|
||||
case llm.ImagePart:
|
||||
url := "data:" + v.MIME + ";base64," + base64.StdEncoding.EncodeToString(v.Data)
|
||||
out = append(out, wireImagePart{Type: "image_url", ImageURL: wireImageURL{URL: url}})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user