package openai import ( "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "reflect" "strings" "testing" "gitea.stevedudenhoeffer.com/steve/majordomo/llm" ) var ( _ llm.Provider = (*Provider)(nil) _ llm.Model = (*model)(nil) _ llm.Stream = (*stream)(nil) ) const textResponse = `{ "id": "chatcmpl-1", "object": "chat.completion", "created": 1741570283, "model": "gpt-test", "choices": [{ "index": 0, "message": {"role": "assistant", "content": "hello", "refusal": null, "annotations": []}, "logprobs": null, "finish_reason": "stop" }], "usage": { "prompt_tokens": 19, "completion_tokens": 10, "total_tokens": 29, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0} }, "service_tier": "default", "system_fingerprint": "fp_x" }` // recorded captures the last request a test server received. type recorded struct { body map[string]any header http.Header path string hits int } // newServer starts a test server that records the request and replies with // a fixed status and body. func newServer(t *testing.T, status int, respBody string) (*httptest.Server, *recorded) { t.Helper() rec := &recorded{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rec.hits++ rec.header = r.Header.Clone() rec.path = r.URL.Path raw, err := io.ReadAll(r.Body) if err != nil { t.Errorf("read request body: %v", err) } if len(raw) > 0 { if err := json.Unmarshal(raw, &rec.body); err != nil { t.Errorf("request body is not JSON: %v\n%s", err, raw) } } w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) io.WriteString(w, respBody) })) t.Cleanup(srv.Close) return srv, rec } func testModel(t *testing.T, srv *httptest.Server, popts []Option, mopts ...llm.ModelOption) llm.Model { t.Helper() opts := append([]Option{WithAPIKey("test-key"), WithBaseURL(srv.URL)}, popts...) m, err := New(opts...).Model("gpt-test", mopts...) if err != nil { t.Fatalf("Model: %v", err) } return m } func fptr(f float64) *float64 { return &f } func TestGenerateRequestShape(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) m := testModel(t, srv, nil) req := llm.Request{ System: "base system", Messages: []llm.Message{ llm.SystemText("folded system"), llm.UserParts(llm.Text("look:"), llm.Image("image/png", []byte{1, 2, 3})), { Role: llm.RoleAssistant, Parts: []llm.Part{llm.Text("checking")}, ToolCalls: []llm.ToolCall{ {ID: "call_1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Boston"}`)}, }, }, llm.ToolResultsMessage( llm.ToolResult{ID: "call_1", Name: "get_weather", Content: "72F"}, llm.ToolResult{ID: "call_2", Name: "get_weather", Content: "boom", IsError: true}, ), llm.UserText("thanks"), }, Tools: []llm.Tool{{ Name: "get_weather", Description: "Get current weather", Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`), }}, ToolChoice: "auto", Temperature: fptr(0.5), TopP: fptr(0.9), MaxTokens: 256, StopSequences: []string{"END"}, ReasoningEffort: "high", Schema: json.RawMessage(`{"type":"object","properties":{"ok":{"type":"boolean"}}}`), SchemaName: "verdict", } if _, err := m.Generate(context.Background(), req); err != nil { t.Fatalf("Generate: %v", err) } want := map[string]any{ "model": "gpt-test", "messages": []any{ map[string]any{"role": "system", "content": "base system\n\nfolded system"}, map[string]any{"role": "user", "content": []any{ map[string]any{"type": "text", "text": "look:"}, map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,AQID"}}, }}, map[string]any{"role": "assistant", "content": "checking", "tool_calls": []any{ map[string]any{"id": "call_1", "type": "function", "function": map[string]any{ "name": "get_weather", "arguments": `{"city":"Boston"}`, }}, }}, map[string]any{"role": "tool", "content": "72F", "tool_call_id": "call_1"}, map[string]any{"role": "tool", "content": "ERROR: boom", "tool_call_id": "call_2"}, map[string]any{"role": "user", "content": "thanks"}, }, "tools": []any{ map[string]any{"type": "function", "function": map[string]any{ "name": "get_weather", "description": "Get current weather", "parameters": map[string]any{"type": "object", "properties": map[string]any{"city": map[string]any{"type": "string"}}}, }}, }, "tool_choice": "auto", "temperature": 0.5, "top_p": 0.9, "max_completion_tokens": float64(256), "stop": []any{"END"}, "reasoning_effort": "high", "response_format": map[string]any{"type": "json_schema", "json_schema": map[string]any{ "name": "verdict", "schema": map[string]any{"type": "object", "properties": map[string]any{"ok": map[string]any{"type": "boolean"}}}, }}, } if !reflect.DeepEqual(rec.body, want) { got, _ := json.MarshalIndent(rec.body, "", " ") exp, _ := json.MarshalIndent(want, "", " ") t.Errorf("request body mismatch\ngot:\n%s\nwant:\n%s", got, exp) } } func TestToolChoiceForms(t *testing.T) { tests := []struct { choice string want any // nil = key absent }{ {"", nil}, {"auto", "auto"}, {"none", "none"}, {"required", "required"}, {"get_weather", map[string]any{"type": "function", "function": map[string]any{"name": "get_weather"}}}, } for _, tt := range tests { t.Run("choice="+tt.choice, func(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) m := testModel(t, srv, nil) req := llm.Request{ Messages: []llm.Message{llm.UserText("hi")}, Tools: []llm.Tool{{Name: "get_weather"}}, ToolChoice: tt.choice, } if _, err := m.Generate(context.Background(), req); err != nil { t.Fatalf("Generate: %v", err) } got, present := rec.body["tool_choice"] if tt.want == nil { if present { t.Errorf("tool_choice present, want omitted: %v", got) } return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("tool_choice = %v, want %v", got, tt.want) } }) } } func TestMaxTokensField(t *testing.T) { t.Run("default uses max_completion_tokens", func(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) m := testModel(t, srv, nil) req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64} if _, err := m.Generate(context.Background(), req); err != nil { t.Fatalf("Generate: %v", err) } if got := rec.body["max_completion_tokens"]; got != float64(64) { t.Errorf("max_completion_tokens = %v, want 64", got) } if _, present := rec.body["max_tokens"]; present { t.Error("max_tokens present, want omitted") } }) t.Run("WithLegacyMaxTokens uses max_tokens", func(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) m := testModel(t, srv, []Option{WithLegacyMaxTokens()}) req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}, MaxTokens: 64} if _, err := m.Generate(context.Background(), req); err != nil { t.Fatalf("Generate: %v", err) } if got := rec.body["max_tokens"]; got != float64(64) { t.Errorf("max_tokens = %v, want 64", got) } if _, present := rec.body["max_completion_tokens"]; present { t.Error("max_completion_tokens present, want omitted") } }) t.Run("zero omits both", func(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) m := testModel(t, srv, nil) req := llm.Request{Messages: []llm.Message{llm.UserText("hi")}} if _, err := m.Generate(context.Background(), req); err != nil { t.Fatalf("Generate: %v", err) } if _, present := rec.body["max_tokens"]; present { t.Error("max_tokens present, want omitted") } if _, present := rec.body["max_completion_tokens"]; present { t.Error("max_completion_tokens present, want omitted") } }) } func TestSchemaNameDefault(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) m := testModel(t, srv, nil) req := llm.Request{ Messages: []llm.Message{llm.UserText("hi")}, Schema: json.RawMessage(`{"type":"object"}`), } if _, err := m.Generate(context.Background(), req); err != nil { t.Fatalf("Generate: %v", err) } rf, ok := rec.body["response_format"].(map[string]any) if !ok { t.Fatalf("response_format missing: %v", rec.body) } js, ok := rf["json_schema"].(map[string]any) if !ok { t.Fatalf("json_schema missing: %v", rf) } if js["name"] != "response" { t.Errorf("schema name = %v, want %q", js["name"], "response") } } func TestGenerateTextResponse(t *testing.T) { srv, _ := newServer(t, http.StatusOK, textResponse) m := testModel(t, srv, nil) resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err != nil { t.Fatalf("Generate: %v", err) } if got := resp.Text(); got != "hello" { t.Errorf("Text = %q, want %q", got, "hello") } if resp.FinishReason != llm.FinishStop { t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishStop) } if resp.Usage != (llm.Usage{InputTokens: 19, OutputTokens: 10}) { t.Errorf("Usage = %+v, want {19 10}", resp.Usage) } if resp.Model != "openai/gpt-test" { t.Errorf("Model = %q, want %q", resp.Model, "openai/gpt-test") } if len(resp.ToolCalls) != 0 { t.Errorf("ToolCalls = %v, want none", resp.ToolCalls) } if resp.Raw == nil { t.Error("Raw is nil, want wire response") } } func TestGenerateToolCallResponse(t *testing.T) { const body = `{ "id": "chatcmpl-2", "object": "chat.completion", "created": 1, "model": "gpt-test", "choices": [{ "index": 0, "message": {"role": "assistant", "content": null, "tool_calls": [ {"id": "call_9", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Boston\"}"}}, {"id": "", "type": "function", "function": {"name": "get_time", "arguments": "{}"}} ]}, "finish_reason": "stop" }], "usage": {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7} }` srv, _ := newServer(t, http.StatusOK, body) m := testModel(t, srv, nil) resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err != nil { t.Fatalf("Generate: %v", err) } if len(resp.ToolCalls) != 2 { t.Fatalf("ToolCalls = %d, want 2", len(resp.ToolCalls)) } tc := resp.ToolCalls[0] if tc.ID != "call_9" || tc.Name != "get_weather" || string(tc.Arguments) != `{"city":"Boston"}` { t.Errorf("ToolCalls[0] = %+v", tc) } if resp.ToolCalls[1].ID != "call_1" { t.Errorf("synthesized ID = %q, want %q", resp.ToolCalls[1].ID, "call_1") } // finish_reason "stop" with tool_calls present: presence wins. if resp.FinishReason != llm.FinishToolCalls { t.Errorf("FinishReason = %v, want %v", resp.FinishReason, llm.FinishToolCalls) } if len(resp.Parts) != 0 { t.Errorf("Parts = %v, want none", resp.Parts) } } func TestFinishReasonMapping(t *testing.T) { tests := []struct { wire string want llm.FinishReason }{ {"stop", llm.FinishStop}, {"length", llm.FinishLength}, {"tool_calls", llm.FinishToolCalls}, {"content_filter", llm.FinishContentFilter}, {"function_call", llm.FinishOther}, {"weird_new_reason", llm.FinishOther}, } for _, tt := range tests { t.Run(tt.wire, func(t *testing.T) { body := `{"choices":[{"index":0,"message":{"role":"assistant","content":"x"},"finish_reason":"` + tt.wire + `"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` srv, _ := newServer(t, http.StatusOK, body) m := testModel(t, srv, nil) resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err != nil { t.Fatalf("Generate: %v", err) } if resp.FinishReason != tt.want { t.Errorf("FinishReason = %v, want %v", resp.FinishReason, tt.want) } }) } } func TestAPIErrorMapping(t *testing.T) { t.Run("429 rate limit is transient", func(t *testing.T) { const body = `{"error":{"message":"Rate limit reached","type":"rate_limit_error","param":null,"code":"rate_limit_exceeded"}}` srv, _ := newServer(t, http.StatusTooManyRequests, body) m := testModel(t, srv, nil) _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) apiErr, ok := errors.AsType[*llm.APIError](err) if !ok { t.Fatalf("err = %v (%T), want *llm.APIError", err, err) } if apiErr.Status != http.StatusTooManyRequests { t.Errorf("Status = %d, want 429", apiErr.Status) } if apiErr.Code != "rate_limit_exceeded" { t.Errorf("Code = %q, want %q", apiErr.Code, "rate_limit_exceeded") } if apiErr.Message != "Rate limit reached" { t.Errorf("Message = %q", apiErr.Message) } if apiErr.Provider != "openai" || apiErr.Model != "gpt-test" { t.Errorf("Provider/Model = %q/%q", apiErr.Provider, apiErr.Model) } if got := llm.Classify(err); got != llm.ClassTransient { t.Errorf("Classify = %v, want transient", got) } }) t.Run("401 code null falls back to type, permanent", func(t *testing.T) { const body = `{"error":{"message":"Incorrect API key provided","type":"authentication_error","param":null,"code":null}}` srv, _ := newServer(t, http.StatusUnauthorized, body) m := testModel(t, srv, nil) _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) apiErr, ok := errors.AsType[*llm.APIError](err) if !ok { t.Fatalf("err = %v (%T), want *llm.APIError", err, err) } if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "authentication_error" { t.Errorf("Status/Code = %d/%q", apiErr.Status, apiErr.Code) } if got := llm.Classify(err); got != llm.ClassPermanent { t.Errorf("Classify = %v, want permanent", got) } }) t.Run("non-JSON body becomes message", func(t *testing.T) { srv, _ := newServer(t, http.StatusServiceUnavailable, "upstream exploded\n") m := testModel(t, srv, nil) _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) apiErr, ok := errors.AsType[*llm.APIError](err) if !ok { t.Fatalf("err = %v (%T), want *llm.APIError", err, err) } if apiErr.Status != http.StatusServiceUnavailable || apiErr.Message != "upstream exploded" { t.Errorf("Status/Message = %d/%q", apiErr.Status, apiErr.Message) } }) } func TestMissingAPIKey(t *testing.T) { t.Setenv("OPENAI_API_KEY", "") srv, rec := newServer(t, http.StatusOK, textResponse) m, err := New(WithBaseURL(srv.URL)).Model("gpt-test") if err != nil { t.Fatalf("Model: %v", err) } _, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) apiErr, ok := errors.AsType[*llm.APIError](err) if !ok { t.Fatalf("err = %v (%T), want *llm.APIError", err, err) } if apiErr.Status != http.StatusUnauthorized || apiErr.Code != "missing_api_key" { t.Errorf("Status/Code = %d/%q, want 401/missing_api_key", apiErr.Status, apiErr.Code) } if rec.hits != 0 { t.Errorf("server hit %d times, want 0", rec.hits) } } func TestEnvAPIKeyReadAtConstruction(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) t.Setenv("OPENAI_API_KEY", "env-secret") p := New(WithBaseURL(srv.URL)) t.Setenv("OPENAI_API_KEY", "changed-later") // must not affect p m, err := p.Model("gpt-test") if err != nil { t.Fatalf("Model: %v", err) } if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil { t.Fatalf("Generate: %v", err) } if got := rec.header.Get("Authorization"); got != "Bearer env-secret" { t.Errorf("Authorization = %q, want %q", got, "Bearer env-secret") } } func TestAuthAndContentTypeHeaders(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) m := testModel(t, srv, nil) if _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}); err != nil { t.Fatalf("Generate: %v", err) } if got := rec.header.Get("Authorization"); got != "Bearer test-key" { t.Errorf("Authorization = %q, want %q", got, "Bearer test-key") } if got := rec.header.Get("Content-Type"); got != "application/json" { t.Errorf("Content-Type = %q, want application/json", got) } if rec.path != "/chat/completions" { t.Errorf("path = %q, want /chat/completions", rec.path) } } func TestCompatEndpointNameAndBaseURL(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) p := New(WithName("groq"), WithAPIKey("k"), WithBaseURL(srv.URL+"/openai/v1/")) if p.Name() != "groq" { t.Errorf("Name = %q, want groq", p.Name()) } m, err := p.Model("llama-3.3-70b") if err != nil { t.Fatalf("Model: %v", err) } resp, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err != nil { t.Fatalf("Generate: %v", err) } if rec.path != "/openai/v1/chat/completions" { t.Errorf("path = %q, want /openai/v1/chat/completions (trailing slash trimmed)", rec.path) } if resp.Model != "groq/llama-3.3-70b" { t.Errorf("Model = %q, want groq/llama-3.3-70b", resp.Model) } if rec.body["model"] != "llama-3.3-70b" { t.Errorf("wire model = %v, want llama-3.3-70b (verbatim)", rec.body["model"]) } } func TestCapabilityEnforcement(t *testing.T) { img := func(mime string, n int) llm.Part { return llm.Image(mime, make([]byte, n)) } tests := []struct { name string caps *llm.Capabilities // nil = provider defaults msg llm.Message }{ { name: "images unsupported", caps: &llm.Capabilities{SupportsTools: true, SupportsStreaming: true}, msg: llm.UserParts(img("image/png", 4)), }, { name: "too many images", caps: &llm.Capabilities{MaxImagesPerReq: 1}, msg: llm.UserParts(img("image/png", 4), img("image/png", 4)), }, { name: "disallowed MIME under defaults", msg: llm.UserParts(img("image/bmp", 4)), }, { name: "image too large", caps: &llm.Capabilities{MaxImagesPerReq: 4, MaxImageBytes: 2}, msg: llm.UserParts(img("image/png", 3)), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) var mopts []llm.ModelOption if tt.caps != nil { mopts = append(mopts, llm.WithCapabilities(*tt.caps)) } m := testModel(t, srv, nil, mopts...) _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{tt.msg}}) if !errors.Is(err, llm.ErrUnsupported) { t.Fatalf("err = %v, want ErrUnsupported", err) } if got := llm.Classify(err); got != llm.ClassPermanent { t.Errorf("Classify = %v, want permanent", got) } if rec.hits != 0 { t.Errorf("server hit %d times, want 0 (must refuse before sending)", rec.hits) } }) } t.Run("streaming unsupported", func(t *testing.T) { srv, rec := newServer(t, http.StatusOK, textResponse) m := testModel(t, srv, nil, llm.WithCapabilities(llm.Capabilities{SupportsTools: true})) _, err := m.Stream(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if !errors.Is(err, llm.ErrUnsupported) { t.Fatalf("err = %v, want ErrUnsupported", err) } if rec.hits != 0 { t.Errorf("server hit %d times, want 0", rec.hits) } }) } func TestModelCapabilitiesOverride(t *testing.T) { p := New(WithAPIKey("k")) def, err := p.Model("a") if err != nil { t.Fatalf("Model: %v", err) } if caps := def.Capabilities(); !caps.SupportsTools || caps.MaxImagesPerReq != 100 || caps.MaxImageBytes != 20<<20 { t.Errorf("default caps = %+v", caps) } custom := llm.Capabilities{SupportsStreaming: true, ContextWindow: 8192} ovr, err := p.Model("b", llm.WithCapabilities(custom)) if err != nil { t.Fatalf("Model: %v", err) } if got := ovr.Capabilities(); !reflect.DeepEqual(got, custom) { t.Errorf("override caps = %+v, want %+v", got, custom) } } func TestTransportErrorIsNotAPIError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) url := srv.URL srv.Close() // guarantee connection refused p := New(WithAPIKey("k"), WithBaseURL(url)) m, err := p.Model("gpt-test") if err != nil { t.Fatalf("Model: %v", err) } _, err = m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err == nil { t.Fatal("Generate succeeded against closed server") } if _, ok := errors.AsType[*llm.APIError](err); ok { t.Errorf("transport error wrapped in APIError: %v", err) } if !strings.Contains(err.Error(), "openai: do request") { t.Errorf("err = %v, want openai: do request context", err) } if got := llm.Classify(err); got != llm.ClassTransient { t.Errorf("Classify = %v, want transient (net error must stay visible)", got) } } func TestDecodeErrorWrapped(t *testing.T) { srv, _ := newServer(t, http.StatusOK, "{not json") m := testModel(t, srv, nil) _, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{llm.UserText("hi")}}) if err == nil || !strings.Contains(err.Error(), "openai: decode response") { t.Errorf("err = %v, want decode response context", err) } if _, ok := errors.AsType[*llm.APIError](err); ok { t.Errorf("decode error wrapped in APIError: %v", err) } }