package openaicompat_test

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"

	"github.com/openai/openai-go"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// newTestServer returns a httptest server that captures the raw request body
// on POST /chat/completions and returns a canned OpenAI response so Complete()
// succeeds. Dereference the returned *[]byte after a call to assert on what
// the provider sent over the wire.
func newTestServer(t *testing.T) (*httptest.Server, *[]byte) {
	t.Helper()
	var body []byte
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/completions" {
			http.NotFound(w, r)
			return
		}
		b, err := io.ReadAll(r.Body)
		if err != nil {
			// Errorf (not Fatalf) — Fatalf must not be called from a
			// non-test goroutine such as an HTTP handler.
			t.Errorf("read body: %v", err)
		}
		body = b
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{
			"id": "cmpl-1",
			"object": "chat.completion",
			"choices": [{
				"index": 0,
				"message": {"role":"assistant","content":"ok"},
				"finish_reason": "stop"
			}],
			"usage": {"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
		}`)
	}))
	return srv, &body
}

// textReq builds a minimal single-user-message request for the given model.
func textReq(model, content string) provider.Request {
	return provider.Request{
		Model:    model,
		Messages: []provider.Message{{Role: "user", Content: content}},
	}
}

func TestComplete_ZeroRulesPassesThrough(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	temp := 0.7
	req := textReq("gpt-4o", "hi")
	req.Temperature = &temp

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
	resp, err := p.Complete(context.Background(), req)
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if resp.Text != "ok" {
		t.Errorf("Text = %q, want %q", resp.Text, "ok")
	}

	// Temperature should be present since RestrictTemperature is nil.
	var parsed map[string]any
	if err := json.Unmarshal(*body, &parsed); err != nil {
		t.Fatalf("unmarshal request body: %v", err)
	}
	if _, ok := parsed["temperature"]; !ok {
		t.Errorf("expected temperature in request body, got: %s", *body)
	}
}

func TestComplete_RestrictTemperatureDropsField(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	temp := 0.7
	req := textReq("o1", "hi")
	req.Temperature = &temp

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		RestrictTemperature: func(m string) bool { return strings.HasPrefix(m, "o") },
	})
	if _, err := p.Complete(context.Background(), req); err != nil {
		t.Fatalf("Complete: %v", err)
	}

	var parsed map[string]any
	if err := json.Unmarshal(*body, &parsed); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if _, ok := parsed["temperature"]; ok {
		t.Errorf("temperature should be dropped for o1, got: %s", *body)
	}
}

func TestComplete_SupportsVisionRejectsWhenFalse(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model: "deepseek-chat",
		Messages: []provider.Message{{
			Role:    "user",
			Content: "describe",
			Images:  []provider.Image{{URL: "https://example.com/a.png"}},
		}},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsVision: func(string) bool { return false },
	})

	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if fue.Feature != "vision" || fue.Model != "deepseek-chat" {
		t.Errorf("unexpected err: %+v", fue)
	}
}

func TestComplete_SupportsToolsRejectsWhenFalse(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model:    "deepseek-reasoner",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
		Tools: []provider.ToolDef{
			{Name: "get_weather", Description: "weather", Schema: map[string]any{"type": "object"}},
		},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsTools: func(m string) bool { return !strings.Contains(m, "reasoner") },
	})

	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if fue.Feature != "tools" {
		t.Errorf("feature = %q, want tools", fue.Feature)
	}
}

func TestComplete_SupportsAudioRejectsWhenFalse(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model: "groq-llama",
		Messages: []provider.Message{{
			Role:  "user",
			Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}},
		}},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsAudio: func(string) bool { return false },
	})

	_, err := p.Complete(context.Background(), req)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if fue.Feature != "audio" {
		t.Errorf("feature = %q, want audio", fue.Feature)
	}
}

func TestComplete_MaxImagesPerMessage(t *testing.T) {
	srv, _ := newTestServer(t)
	defer srv.Close()

	req := provider.Request{
		Model: "anything",
		Messages: []provider.Message{{
			Role: "user",
			Images: []provider.Image{
				{URL: "a"}, {URL: "b"}, {URL: "c"},
			},
		}},
	}
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{MaxImagesPerMessage: 2})

	_, err := p.Complete(context.Background(), req)
	if err == nil || !strings.Contains(err.Error(), "max allowed is 2") {
		t.Fatalf("want max-images error, got %v", err)
	}

	// Exactly at limit succeeds.
	req.Messages[0].Images = req.Messages[0].Images[:2]
	if _, err := p.Complete(context.Background(), req); err != nil {
		t.Errorf("at-limit request should succeed, got %v", err)
	}
}

func TestComplete_CustomizeRequestInvoked(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	called := false
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		CustomizeRequest: func(params *openai.ChatCompletionNewParams) {
			called = true
			// Confirm we receive a non-empty built request.
			if params.Model != "gpt-4o" {
				t.Errorf("CustomizeRequest saw model %q, want gpt-4o", params.Model)
			}
			// Mutation here should end up on the wire.
			params.User = openai.String("test-user")
		},
	})

	if _, err := p.Complete(context.Background(), textReq("gpt-4o", "hi")); err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if !called {
		t.Fatal("CustomizeRequest hook was not invoked")
	}
	if !strings.Contains(string(*body), `"user":"test-user"`) {
		t.Errorf("mutation from CustomizeRequest not reflected on wire: %s", *body)
	}
}

func TestStream_EmitsDoneAndText(t *testing.T) {
	// SSE stream with one content chunk then [DONE].
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		for _, line := range []string{
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hel"}}]}`,
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"lo"}}]}`,
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
			`data: [DONE]`,
		} {
			_, _ = io.WriteString(w, line+"\n\n")
			if flusher != nil {
				flusher.Flush()
			}
		}
	}))
	defer srv.Close()

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
	events := make(chan provider.StreamEvent, 16)
	go func() {
		_ = p.Stream(context.Background(), textReq("gpt-4o", "hi"), events)
		close(events)
	}()

	var text strings.Builder
	var sawDone bool
	var doneUsage *provider.Usage
	for ev := range events {
		switch ev.Type {
		case provider.StreamEventText:
			text.WriteString(ev.Text)
		case provider.StreamEventDone:
			sawDone = true
			if ev.Response != nil {
				doneUsage = ev.Response.Usage
			}
		}
	}

	if text.String() != "hello" {
		t.Errorf("got text %q, want %q", text.String(), "hello")
	}
	if !sawDone {
		t.Fatal("no Done event emitted")
	}
	if doneUsage == nil || doneUsage.TotalTokens != 3 {
		t.Errorf("usage on Done = %+v, want TotalTokens=3", doneUsage)
	}
}

func TestComplete_ReasoningEffortPassthrough(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	req := textReq("o3-mini", "hi")
	req.Reasoning = "high"

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
	if _, err := p.Complete(context.Background(), req); err != nil {
		t.Fatalf("Complete: %v", err)
	}

	var parsed map[string]any
	if err := json.Unmarshal(*body, &parsed); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if parsed["reasoning_effort"] != "high" {
		t.Errorf("reasoning_effort = %v, want \"high\"; body: %s",
			parsed["reasoning_effort"], *body)
	}
}

func TestComplete_SupportsReasoningGate(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	req := textReq("gpt-4o", "hi")
	req.Reasoning = "high"

	// SupportsReasoning returns false → reasoning_effort must NOT be sent.
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsReasoning: func(string) bool { return false },
	})
	if _, err := p.Complete(context.Background(), req); err != nil {
		t.Fatalf("Complete: %v", err)
	}

	var parsed map[string]any
	if err := json.Unmarshal(*body, &parsed); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if _, ok := parsed["reasoning_effort"]; ok {
		t.Errorf("reasoning_effort should be absent when SupportsReasoning=false; body: %s", *body)
	}
}

func TestComplete_MapReasoningEffort(t *testing.T) {
	srv, body := newTestServer(t)
	defer srv.Close()

	req := textReq("grok-3-mini", "hi")
	req.Reasoning = "medium"

	// xAI-style mapping: medium → high.
	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		MapReasoningEffort: func(level string) string {
			if level == "medium" {
				return "high"
			}
			return level
		},
	})
	if _, err := p.Complete(context.Background(), req); err != nil {
		t.Fatalf("Complete: %v", err)
	}

	var parsed map[string]any
	if err := json.Unmarshal(*body, &parsed); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if parsed["reasoning_effort"] != "high" {
		t.Errorf("reasoning_effort = %v, want \"high\" after medium→high remap; body: %s",
			parsed["reasoning_effort"], *body)
	}
}

func TestComplete_ReasoningContentExtracted(t *testing.T) {
	// Server returns a DeepSeek-style response with reasoning_content alongside content.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{
			"id": "cmpl-1",
			"object": "chat.completion",
			"choices": [{
				"index": 0,
				"message": {
					"role":"assistant",
					"content":"42",
					"reasoning_content":"the user asked for the answer..."
				},
				"finish_reason": "stop"
			}],
			"usage": {"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}
		}`)
	}))
	defer srv.Close()

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
	resp, err := p.Complete(context.Background(), textReq("deepseek-reasoner", "?"))
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if resp.Text != "42" {
		t.Errorf("Text = %q, want %q", resp.Text, "42")
	}
	if !strings.Contains(resp.Thinking, "the user asked for") {
		t.Errorf("Thinking = %q, want it to contain the reasoning trace", resp.Thinking)
	}
}

func TestStream_ReasoningContentEmitsThinkingEvents(t *testing.T) {
	// Two SSE chunks, each with a reasoning_content delta, then a final done chunk.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		for _, line := range []string{
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"reasoning_content":"think "}}]}`,
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"reasoning_content":"hard","content":"42"}}]}`,
			`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
			`data: [DONE]`,
		} {
			_, _ = io.WriteString(w, line+"\n\n")
			if flusher != nil {
				flusher.Flush()
			}
		}
	}))
	defer srv.Close()

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
	events := make(chan provider.StreamEvent, 32)
	go func() {
		_ = p.Stream(context.Background(), textReq("deepseek-reasoner", "?"), events)
		close(events)
	}()

	var thinking strings.Builder
	var sawDone bool
	var doneThinking string
	for ev := range events {
		switch ev.Type {
		case provider.StreamEventThinking:
			thinking.WriteString(ev.Text)
		case provider.StreamEventDone:
			sawDone = true
			if ev.Response != nil {
				doneThinking = ev.Response.Thinking
			}
		}
	}

	if thinking.String() != "think hard" {
		t.Errorf("streamed thinking = %q, want %q", thinking.String(), "think hard")
	}
	if !sawDone {
		t.Fatal("no Done event")
	}
	if doneThinking != "think hard" {
		t.Errorf("Done.Response.Thinking = %q, want %q", doneThinking, "think hard")
	}
}

func TestStream_RulesCheckedBeforeNetwork(t *testing.T) {
	// Server should never be hit when rules reject up front. Use an atomic so
	// that, if the implementation regresses and does contact the server, the
	// test fails cleanly instead of tripping the race detector (the handler
	// runs on a different goroutine than the final read of `hit`).
	var hit atomic.Bool
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hit.Store(true)
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer srv.Close()

	p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
		SupportsVision: func(string) bool { return false },
	})
	req := provider.Request{
		Model: "no-vision-model",
		Messages: []provider.Message{{
			Role:   "user",
			Images: []provider.Image{{URL: "a"}},
		}},
	}

	events := make(chan provider.StreamEvent, 4)
	err := p.Stream(context.Background(), req, events)
	var fue *openaicompat.FeatureUnsupportedError
	if !errors.As(err, &fue) {
		t.Fatalf("want FeatureUnsupportedError, got %v", err)
	}
	if hit.Load() {
		t.Error("server was contacted despite Rules violation")
	}
}