043249e0e1
Phase 3: - provider/openai: Chat Completions for OpenAI + compat endpoints (SSE streaming with by-index tool-call assembly, response_format json_schema, legacy max_tokens option, reasoning_effort) - provider/anthropic: Messages API (tool_use/tool_result, GA structured output via output_config.format, full SSE event parser, 529 transient) - provider/ollama: one native /api/chat client behind the ollama, ollama-cloud, and foreman built-ins (presets; NDJSON streaming tolerant of foreman's buffered single-object responses; object tool arguments; format-schema structured output; think mapping) - media/: capability normalization (sniff, downscale, transcode, byte ladder, ErrUnsupported), wired into the chain executor per target with penalty-free advance past incapable elements - registry: real provider + scheme wiring, WithHTTPClient option, required env-foreman TLS chat round-trip test - ADR-0009 multimodal strategy, ADR-0010 tools/structured mapping; README matrix + CLAUDE.md synced Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
615 lines
21 KiB
Go
615 lines
21 KiB
Go
package openai
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
|
)
|
|
|
|
var (
|
|
_ llm.Provider = (*Provider)(nil)
|
|
_ llm.Model = (*model)(nil)
|
|
_ llm.Stream = (*stream)(nil)
|
|
)
|
|
|
|
// textResponse is a canned non-streaming chat completion body: a single
// assistant choice with content "hello", finish_reason "stop", and a usage
// block reporting 19 prompt and 10 completion tokens.
const textResponse = `{
"id": "chatcmpl-1", "object": "chat.completion", "created": 1741570283, "model": "gpt-test",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": "hello", "refusal": null, "annotations": []},
"logprobs": null,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 19, "completion_tokens": 10, "total_tokens": 29,
"prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
"completion_tokens_details": {"reasoning_tokens": 0}
},
"service_tier": "default", "system_fingerprint": "fp_x"
}`
|
|
|
|
// recorded captures the last request a test server received.
type recorded struct {
	body   map[string]any // decoded JSON body of the last request; nil if it had none
	header http.Header    // copy of the last request's headers
	path   string         // URL path of the last request
	hits   int            // total number of requests served
}

// newServer starts a test server that records the request and replies with
// a fixed status and body.
//
// Every request overwrites the previous snapshot in the returned *recorded
// and increments its hit counter. A request body that cannot be read or is
// not valid JSON is reported on t via Errorf. The server is closed
// automatically through t.Cleanup.
func newServer(t *testing.T, status int, respBody string) (*httptest.Server, *recorded) {
	t.Helper()
	rec := &recorded{}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rec.hits++
		rec.header = r.Header.Clone()
		rec.path = r.URL.Path
		// json.Unmarshal reuses a non-nil map and keeps its existing entries,
		// so drop the previous body first; otherwise keys from an earlier
		// request would leak into what is documented as the *last* request.
		rec.body = nil
		raw, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read request body: %v", err)
		}
		if len(raw) > 0 {
			if err := json.Unmarshal(raw, &rec.body); err != nil {
				t.Errorf("request body is not JSON: %v\n%s", err, raw)
			}
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		// Best effort: a failed write means the client went away, which the
		// calling test observes through its own request error.
		_, _ = io.WriteString(w, respBody)
	}))
	t.Cleanup(srv.Close)
	return srv, rec
}
|
|
|
|
func testModel(t *testing.T, srv *httptest.Server, popts []Option, mopts ...llm.ModelOption) llm.Model {
|
|
t.Helper()
|
|
opts := append([]Option{WithAPIKey("test-key"), WithBaseURL(srv.URL)}, popts...)
|
|
m, err := New(opts...).Model("gpt-test", mopts...)
|
|
if err != nil {
|
|
t.Fatalf("Model: %v", err)
|
|
}
|
|
return m
|
|
}
|
|
|
|
// fptr returns a pointer to a fresh float64 holding f, for populating
// optional pointer-typed request fields from a literal.
func fptr(f float64) *float64 {
	v := f
	return &v
}
|
|
|
|
func TestGenerateRequestShape(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
m := testModel(t, srv, nil)
|
|
|
|
req := llm.Request{
|
|
System: "base system",
|
|
Messages: []llm.Message{
|
|
llm.SystemText("folded system"),
|
|
llm.UserParts(llm.Text("look:"), llm.Image("image/png", []byte{1, 2, 3})),
|
|
{
|
|
Role: llm.RoleAssistant,
|
|
Parts: []llm.Part{llm.Text("checking")},
|
|
ToolCalls: []llm.ToolCall{
|
|
{ID: "call_1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Boston"}`)},
|
|
},
|
|
},
|
|
llm.ToolResultsMessage(
|
|
llm.ToolResult{ID: "call_1", Name: "get_weather", Content: "72F"},
|
|
llm.ToolResult{ID: "call_2", Name: "get_weather", Content: "boom", IsError: true},
|
|
),
|
|
llm.UserText("thanks"),
|
|
},
|
|
Tools: []llm.Tool{{
|
|
Name: "get_weather",
|
|
Description: "Get current weather",
|
|
Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
|
|
}},
|
|
ToolChoice: "auto",
|
|
Temperature: fptr(0.5),
|
|
TopP: fptr(0.9),
|
|
MaxTokens: 256,
|
|
StopSequences: []string{"END"},
|
|
ReasoningEffort: "high",
|
|
Schema: json.RawMessage(`{"type":"object","properties":{"ok":{"type":"boolean"}}}`),
|
|
SchemaName: "verdict",
|
|
}
|
|
if _, err := m.Generate(context.Background(), req); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
|
|
want := map[string]any{
|
|
"model": "gpt-test",
|
|
"messages": []any{
|
|
map[string]any{"role": "system", "content": "base system\n\nfolded system"},
|
|
map[string]any{"role": "user", "content": []any{
|
|
map[string]any{"type": "text", "text": "look:"},
|
|
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,AQID"}},
|
|
}},
|
|
map[string]any{"role": "assistant", "content": "checking", "tool_calls": []any{
|
|
map[string]any{"id": "call_1", "type": "function", "function": map[string]any{
|
|
"name": "get_weather", "arguments": `{"city":"Boston"}`,
|
|
}},
|
|
}},
|
|
map[string]any{"role": "tool", "content": "72F", "tool_call_id": "call_1"},
|
|
map[string]any{"role": "tool", "content": "ERROR: boom", "tool_call_id": "call_2"},
|
|
map[string]any{"role": "user", "content": "thanks"},
|
|
},
|
|
"tools": []any{
|
|
map[string]any{"type": "function", "function": map[string]any{
|
|
"name": "get_weather",
|
|
"description": "Get current weather",
|
|
"parameters": map[string]any{"type": "object", "properties": map[string]any{"city": map[string]any{"type": "string"}}},
|
|
}},
|
|
},
|
|
"tool_choice": "auto",
|
|
"temperature": 0.5,
|
|
"top_p": 0.9,
|
|
"max_completion_tokens": float64(256),
|
|
"stop": []any{"END"},
|
|
"reasoning_effort": "high",
|
|
"response_format": map[string]any{"type": "json_schema", "json_schema": map[string]any{
|
|
"name": "verdict",
|
|
"schema": map[string]any{"type": "object", "properties": map[string]any{"ok": map[string]any{"type": "boolean"}}},
|
|
}},
|
|
}
|
|
if !reflect.DeepEqual(rec.body, want) {
|
|
got, _ := json.MarshalIndent(rec.body, "", " ")
|
|
exp, _ := json.MarshalIndent(want, "", " ")
|
|
t.Errorf("request body mismatch\ngot:\n%s\nwant:\n%s", got, exp)
|
|
}
|
|
}
|
|
|
|
func TestToolChoiceForms(t *testing.T) {
|
|
tests := []struct {
|
|
choice string
|
|
want any // nil = key absent
|
|
}{
|
|
{"", nil},
|
|
{"auto", "auto"},
|
|
{"none", "none"},
|
|
{"required", "required"},
|
|
{"get_weather", map[string]any{"type": "function", "function": map[string]any{"name": "get_weather"}}},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run("choice="+tt.choice, func(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
m := testModel(t, srv, nil)
|
|
req := llm.Request{
|
|
Messages: []llm.Message{llm.UserText("hi")},
|
|
Tools: []llm.Tool{{Name: "get_weather"}},
|
|
ToolChoice: tt.choice,
|
|
}
|
|
if _, err := m.Generate(context.Background(), req); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
got, present := rec.body["tool_choice"]
|
|
if tt.want == nil {
|
|
if present {
|
|
t.Errorf("tool_choice present, want omitted: %v", got)
|
|
}
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("tool_choice = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMaxTokensField(t *testing.T) {
|
|
t.Run("default uses max_completion_tokens", func(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
m := testModel(t, srv, nil)
|
|
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64}
|
|
if _, err := m.Generate(context.Background(), req); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if got := rec.body["max_completion_tokens"]; got != float64(64) {
|
|
t.Errorf("max_completion_tokens = %v, want 64", got)
|
|
}
|
|
if _, present := rec.body["max_tokens"]; present {
|
|
t.Error("max_tokens present, want omitted")
|
|
}
|
|
})
|
|
t.Run("WithLegacyMaxTokens uses max_tokens", func(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
m := testModel(t, srv, []Option{WithLegacyMaxTokens()})
|
|
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64}
|
|
if _, err := m.Generate(context.Background(), req); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if got := rec.body["max_tokens"]; got != float64(64) {
|
|
t.Errorf("max_tokens = %v, want 64", got)
|
|
}
|
|
if _, present := rec.body["max_completion_tokens"]; present {
|
|
t.Error("max_completion_tokens present, want omitted")
|
|
}
|
|
})
|
|
t.Run("zero omits both", func(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
m := testModel(t, srv, nil)
|
|
req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
|
|
if _, err := m.Generate(context.Background(), req); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if _, present := rec.body["max_tokens"]; present {
|
|
t.Error("max_tokens present, want omitted")
|
|
}
|
|
if _, present := rec.body["max_completion_tokens"]; present {
|
|
t.Error("max_completion_tokens present, want omitted")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestSchemaNameDefault(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
m := testModel(t, srv, nil)
|
|
req := llm.Request{
|
|
Messages: []llm.Message{llm.UserText("hi")},
|
|
Schema: json.RawMessage(`{"type":"object"}`),
|
|
}
|
|
if _, err := m.Generate(context.Background(), req); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
rf, ok := rec.body["response_format"].(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("response_format missing: %v", rec.body)
|
|
}
|
|
js, ok := rf["json_schema"].(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("json_schema missing: %v", rf)
|
|
}
|
|
if js["name"] != "response" {
|
|
t.Errorf("schema name = %v, want %q", js["name"], "response")
|
|
}
|
|
}
|
|
|
|
func TestGenerateTextResponse(t *testing.T) {
|
|
srv, _ := newServer(t, http.StatusOK, textResponse)
|
|
m := testModel(t, srv, nil)
|
|
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if got := resp.Text(); got != "hello" {
|
|
t.Errorf("Text = %q, want %q", got, "hello")
|
|
}
|
|
if resp.FinishReason != llm.FinishStop {
|
|
t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishStop)
|
|
}
|
|
if resp.Usage != (llm.Usage{InputTokens: 19, OutputTokens: 10}) {
|
|
t.Errorf("Usage = %+v, want {19 10}", resp.Usage)
|
|
}
|
|
if resp.Model != "openai/gpt-test" {
|
|
t.Errorf("Model = %q, want %q", resp.Model, "openai/gpt-test")
|
|
}
|
|
if len(resp.ToolCalls) != 0 {
|
|
t.Errorf("ToolCalls = %v, want none", resp.ToolCalls)
|
|
}
|
|
if resp.Raw == nil {
|
|
t.Error("Raw is nil, want wire response")
|
|
}
|
|
}
|
|
|
|
func TestGenerateToolCallResponse(t *testing.T) {
|
|
const body = `{
|
|
"id": "chatcmpl-2", "object": "chat.completion", "created": 1, "model": "gpt-test",
|
|
"choices": [{
|
|
"index": 0,
|
|
"message": {"role": "assistant", "content": null, "tool_calls": [
|
|
{"id": "call_9", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Boston\"}"}},
|
|
{"id": "", "type": "function", "function": {"name": "get_time", "arguments": "{}"}}
|
|
]},
|
|
"finish_reason": "stop"
|
|
}],
|
|
"usage": {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}
|
|
}`
|
|
srv, _ := newServer(t, http.StatusOK, body)
|
|
m := testModel(t, srv, nil)
|
|
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if len(resp.ToolCalls) != 2 {
|
|
t.Fatalf("ToolCalls = %d, want 2", len(resp.ToolCalls))
|
|
}
|
|
tc := resp.ToolCalls[0]
|
|
if tc.ID != "call_9" || tc.Name != "get_weather" || string(tc.Arguments) != `{"city":"Boston"}` {
|
|
t.Errorf("ToolCalls[0] = %+v", tc)
|
|
}
|
|
if resp.ToolCalls[1].ID != "call_1" {
|
|
t.Errorf("synthesized ID = %q, want %q", resp.ToolCalls[1].ID, "call_1")
|
|
}
|
|
// finish_reason "stop" with tool_calls present: presence wins.
|
|
if resp.FinishReason != llm.FinishToolCalls {
|
|
t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishToolCalls)
|
|
}
|
|
if len(resp.Parts) != 0 {
|
|
t.Errorf("Parts = %v, want none", resp.Parts)
|
|
}
|
|
}
|
|
|
|
func TestFinishReasonMapping(t *testing.T) {
|
|
tests := []struct {
|
|
wire string
|
|
want llm.FinishReason
|
|
}{
|
|
{"stop", llm.FinishStop},
|
|
{"length", llm.FinishLength},
|
|
{"tool_calls", llm.FinishToolCalls},
|
|
{"content_filter", llm.FinishContentFilter},
|
|
{"function_call", llm.FinishOther},
|
|
{"weird_new_reason", llm.FinishOther},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.wire, func(t *testing.T) {
|
|
body := `{"choices":[{"index":0,"message":{"role":"assistant","content":"x"},"finish_reason":"` + tt.wire + `"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`
|
|
srv, _ := newServer(t, http.StatusOK, body)
|
|
m := testModel(t, srv, nil)
|
|
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if resp.FinishReason != tt.want {
|
|
t.Errorf("FinishReason = %v, want %v", resp.FinishReason, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAPIErrorMapping(t *testing.T) {
|
|
t.Run("429 rate limit is transient", func(t *testing.T) {
|
|
const body = `{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}`
|
|
srv, _ := newServer(t, http.StatusTooManyRequests, body)
|
|
m := testModel(t, srv, nil)
|
|
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
apiErr, ok := errors.AsType[*llm.APIError](err)
|
|
if !ok {
|
|
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
|
}
|
|
if apiErr.Status != http.StatusTooManyRequests {
|
|
t.Errorf("Status = %d, want 429", apiErr.Status)
|
|
}
|
|
if apiErr.Code != "rate_limit_exceeded" {
|
|
t.Errorf("Code = %q, want %q", apiErr.Code, "rate_limit_exceeded")
|
|
}
|
|
if apiErr.Message != "Rate limit reached" {
|
|
t.Errorf("Message = %q", apiErr.Message)
|
|
}
|
|
if apiErr.Provider != "openai" || apiErr.Model != "gpt-test" {
|
|
t.Errorf("Provider/Model = %q/%q", apiErr.Provider, apiErr.Model)
|
|
}
|
|
if got := llm.Classify(err); got != llm.ClassTransient {
|
|
t.Errorf("Classify = %v, want transient", got)
|
|
}
|
|
})
|
|
t.Run("401 code null falls back to type, permanent", func(t *testing.T) {
|
|
const body = `{"error":{"message":"Incorrect API key provided","type":"authentication_error","param":null,"code":null}}`
|
|
srv, _ := newServer(t, http.StatusUnauthorized, body)
|
|
m := testModel(t, srv, nil)
|
|
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
apiErr, ok := errors.AsType[*llm.APIError](err)
|
|
if !ok {
|
|
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
|
}
|
|
if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" {
|
|
t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code)
|
|
}
|
|
if got := llm.Classify(err); got != llm.ClassPermanent {
|
|
t.Errorf("Classify = %v, want permanent", got)
|
|
}
|
|
})
|
|
t.Run("non-JSON body becomes message", func(t *testing.T) {
|
|
srv, _ := newServer(t, http.StatusServiceUnavailable, "upstream exploded\n")
|
|
m := testModel(t, srv, nil)
|
|
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
apiErr, ok := errors.AsType[*llm.APIError](err)
|
|
if !ok {
|
|
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
|
}
|
|
if apiErr.Status != http.StatusServiceUnavailable || apiErr.Message != "upstream exploded" {
|
|
t.Errorf("Status/Message = %d/%q", apiErr.Status, apiErr.Message)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestMissingAPIKey(t *testing.T) {
|
|
t.Setenv("OPENAI_API_KEY", "")
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
m, err := New(WithBaseURL(srv.URL)).Model("gpt-test")
|
|
if err != nil {
|
|
t.Fatalf("Model: %v", err)
|
|
}
|
|
_, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
apiErr, ok := errors.AsType[*llm.APIError](err)
|
|
if !ok {
|
|
t.Fatalf("err = %v (%T), want *llm.APIError", err, err)
|
|
}
|
|
if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "missing_api_key" {
|
|
t.Errorf("Status/Code = %d/%q, want 401/missing_api_key", apiErr.Status, apiErr.Code)
|
|
}
|
|
if rec.hits != 0 {
|
|
t.Errorf("server hit %d times, want 0", rec.hits)
|
|
}
|
|
}
|
|
|
|
func TestEnvAPIKeyReadAtConstruction(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
t.Setenv("OPENAI_API_KEY", "env-secret")
|
|
p := New(WithBaseURL(srv.URL))
|
|
t.Setenv("OPENAI_API_KEY", "changed-later") // must not affect p
|
|
m, err := p.Model("gpt-test")
|
|
if err != nil {
|
|
t.Fatalf("Model: %v", err)
|
|
}
|
|
if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if got := rec.header.Get("Authorization"); got != "Bearer env-secret" {
|
|
t.Errorf("Authorization = %q, want %q", got, "Bearer env-secret")
|
|
}
|
|
}
|
|
|
|
func TestAuthAndContentTypeHeaders(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
m := testModel(t, srv, nil)
|
|
if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if got := rec.header.Get("Authorization"); got != "Bearer test-key" {
|
|
t.Errorf("Authorization = %q, want %q", got, "Bearer test-key")
|
|
}
|
|
if got := rec.header.Get("Content-Type"); got != "application/json" {
|
|
t.Errorf("Content-Type = %q, want application/json", got)
|
|
}
|
|
if rec.path != "/chat/completions" {
|
|
t.Errorf("path = %q, want /chat/completions", rec.path)
|
|
}
|
|
}
|
|
|
|
func TestCompatEndpointNameAndBaseURL(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
p := New(WithName("groq"), WithAPIKey("k"), WithBaseURL(srv.URL+"/openai/v1/"))
|
|
if p.Name() != "groq" {
|
|
t.Errorf("Name = %q, want groq", p.Name())
|
|
}
|
|
m, err := p.Model("llama-3.3-70b")
|
|
if err != nil {
|
|
t.Fatalf("Model: %v", err)
|
|
}
|
|
resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if rec.path != "/openai/v1/chat/completions" {
|
|
t.Errorf("path = %q, want /openai/v1/chat/completions (trailing slash trimmed)", rec.path)
|
|
}
|
|
if resp.Model != "groq/llama-3.3-70b" {
|
|
t.Errorf("Model = %q, want groq/llama-3.3-70b", resp.Model)
|
|
}
|
|
if rec.body["model"] != "llama-3.3-70b" {
|
|
t.Errorf("wire model = %v, want llama-3.3-70b (verbatim)", rec.body["model"])
|
|
}
|
|
}
|
|
|
|
func TestCapabilityEnforcement(t *testing.T) {
|
|
img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) }
|
|
tests := []struct {
|
|
name string
|
|
caps *llm.Capabilities // nil = provider defaults
|
|
msg llm.Message
|
|
}{
|
|
{
|
|
name: "images unsupported",
|
|
caps: &llm.Capabilities{SupportsTools: true, SupportsStreaming: true},
|
|
msg: llm.UserParts(img("image/png", 4)),
|
|
},
|
|
{
|
|
name: "too many images",
|
|
caps: &llm.Capabilities{MaxImagesPerReq: 1},
|
|
msg: llm.UserParts(img("image/png", 4), img("image/png", 4)),
|
|
},
|
|
{
|
|
name: "disallowed MIME under defaults",
|
|
msg: llm.UserParts(img("image/bmp", 4)),
|
|
},
|
|
{
|
|
name: "image too large",
|
|
caps: &llm.Capabilities{MaxImagesPerReq: 4, MaxImageBytes: 2},
|
|
msg: llm.UserParts(img("image/png", 3)),
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
var mopts []llm.ModelOption
|
|
if tt.caps != nil {
|
|
mopts = append(mopts, llm.WithCapabilities(*tt.caps))
|
|
}
|
|
m := testModel(t, srv, nil, mopts...)
|
|
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{tt.msg}})
|
|
if !errors.Is(err, llm.ErrUnsupported) {
|
|
t.Fatalf("err = %v, want ErrUnsupported", err)
|
|
}
|
|
if got := llm.Classify(err); got != llm.ClassPermanent {
|
|
t.Errorf("Classify = %v, want permanent", got)
|
|
}
|
|
if rec.hits != 0 {
|
|
t.Errorf("server hit %d times, want 0 (must refuse before sending)", rec.hits)
|
|
}
|
|
})
|
|
}
|
|
|
|
t.Run("streaming unsupported", func(t *testing.T) {
|
|
srv, rec := newServer(t, http.StatusOK, textResponse)
|
|
m := testModel(t, srv, nil, llm.WithCapabilities(llm.Capabilities{SupportsTools: true}))
|
|
_, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
if !errors.Is(err, llm.ErrUnsupported) {
|
|
t.Fatalf("err = %v, want ErrUnsupported", err)
|
|
}
|
|
if rec.hits != 0 {
|
|
t.Errorf("server hit %d times, want 0", rec.hits)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestModelCapabilitiesOverride(t *testing.T) {
|
|
p := New(WithAPIKey("k"))
|
|
def, err := p.Model("a")
|
|
if err != nil {
|
|
t.Fatalf("Model: %v", err)
|
|
}
|
|
if caps := def.Capabilities(); !caps.SupportsTools || caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 20<<20 {
|
|
t.Errorf("default caps = %+v", caps)
|
|
}
|
|
custom := llm.Capabilities{SupportsStreaming: true, ContextWindow: 8192}
|
|
ovr, err := p.Model("b", llm.WithCapabilities(custom))
|
|
if err != nil {
|
|
t.Fatalf("Model: %v", err)
|
|
}
|
|
if got := ovr.Capabilities(); !reflect.DeepEqual(got, custom) {
|
|
t.Errorf("override caps = %+v, want %+v", got, custom)
|
|
}
|
|
}
|
|
|
|
func TestTransportErrorIsNotAPIError(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
|
url := srv.URL
|
|
srv.Close() // guarantee connection refused
|
|
p := New(WithAPIKey("k"), WithBaseURL(url))
|
|
m, err := p.Model("gpt-test")
|
|
if err != nil {
|
|
t.Fatalf("Model: %v", err)
|
|
}
|
|
_, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
if err == nil {
|
|
t.Fatal("Generate succeeded against closed server")
|
|
}
|
|
if _, ok := errors.AsType[*llm.APIError](err); ok {
|
|
t.Errorf("transport error wrapped in APIError: %v", err)
|
|
}
|
|
if !strings.Contains(err.Error(), "openai: do request") {
|
|
t.Errorf("err = %v, want openai: do request context", err)
|
|
}
|
|
if got := llm.Classify(err); got != llm.ClassTransient {
|
|
t.Errorf("Classify = %v, want transient (net error must stay visible)", got)
|
|
}
|
|
}
|
|
|
|
func TestDecodeErrorWrapped(t *testing.T) {
|
|
srv, _ := newServer(t, http.StatusOK, "{not json")
|
|
m := testModel(t, srv, nil)
|
|
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}})
|
|
if err == nil || !strings.Contains(err.Error(), "openai: decode response") {
|
|
t.Errorf("err = %v, want decode response context", err)
|
|
}
|
|
if _, ok := errors.AsType[*llm.APIError](err); ok {
|
|
t.Errorf("decode error wrapped in APIError: %v", err)
|
|
}
|
|
}
|