cbaf41f50c
Introduces an opt-in level-based reasoning toggle (low/medium/high) that each provider translates to its native parameter: - Anthropic: thinking.budget_tokens (1024/8000/24000), with temperature forced to default and MaxTokens auto-grown above the budget. - OpenAI/xAI/Groq via openaicompat: reasoning_effort string, gated by a new Rules.SupportsReasoning predicate so non-reasoning models don't receive the parameter. xAI uses Rules.MapReasoningEffort to remap "medium" to "high" since its API only accepts low|high. - Google: thinking_config.thinking_budget + include_thoughts:true. - DeepSeek: SupportsReasoning=false (reasoner is always-on; the reasoning_content trace was already extracted via openaicompat). Reasoning content is surfaced as Response.Thinking on Complete and as StreamEventThinking deltas during streaming. Provider-side: extracted from Anthropic thinking content blocks, Google's part.Thought=true parts, and the non-standard reasoning_content field that DeepSeek and Groq emit (parsed out of raw JSON since openai-go doesn't type it). Public API: - llm.ReasoningLevel + ReasoningLow/Medium/High constants - llm.WithReasoning(level) request option - Model.WithReasoning(level) for baked-in defaults - provider.Request.Reasoning, provider.Response.Thinking - provider.StreamEventThinking Tests cover Rules-based gating, MapReasoningEffort, reasoning_content extraction (Complete + Stream), Anthropic budget mapping, and temperature suppression when thinking is enabled. Existing behavior is unchanged when Reasoning is the empty string. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
470 lines
14 KiB
Go
470 lines
14 KiB
Go
package openaicompat_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/openai/openai-go"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/openaicompat"
|
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
|
)
|
|
|
|
// newTestServer returns a httptest server that captures the raw request body
// on POST /chat/completions and returns a canned OpenAI response so Complete()
// succeeds. Use `captured` to assert on what the provider would send.
func newTestServer(t *testing.T) (*httptest.Server, *[]byte) {
	t.Helper()

	var captured []byte
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/completions" {
			http.NotFound(w, r)
			return
		}
		b, err := io.ReadAll(r.Body)
		if err != nil {
			// Fail the request loudly instead of capturing a nil body and
			// returning the canned success, which would mask the failure.
			t.Errorf("read body: %v", err)
			http.Error(w, "read body failed", http.StatusInternalServerError)
			return
		}
		captured = b

		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{
			"id": "cmpl-1",
			"object": "chat.completion",
			"choices": [{
				"index": 0,
				"message": {"role":"assistant","content":"ok"},
				"finish_reason": "stop"
			}],
			"usage": {"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
		}`)
	}))
	return srv, &captured
}
|
|
|
|
func textReq(model, content string) provider.Request {
|
|
return provider.Request{
|
|
Model: model,
|
|
Messages: []provider.Message{{Role: "user", Content: content}},
|
|
}
|
|
}
|
|
|
|
func TestComplete_ZeroRulesPassesThrough(t *testing.T) {
|
|
srv, body := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
temp := 0.7
|
|
req := textReq("gpt-4o", "hi")
|
|
req.Temperature = &temp
|
|
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
|
|
resp, err := p.Complete(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
if resp.Text != "ok" {
|
|
t.Errorf("Text = %q, want %q", resp.Text, "ok")
|
|
}
|
|
|
|
// Temperature should be present since RestrictTemperature is nil.
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal(*body, &parsed); err != nil {
|
|
t.Fatalf("unmarshal request body: %v", err)
|
|
}
|
|
if _, ok := parsed["temperature"]; !ok {
|
|
t.Errorf("expected temperature in request body, got: %s", *body)
|
|
}
|
|
}
|
|
|
|
func TestComplete_RestrictTemperatureDropsField(t *testing.T) {
|
|
srv, body := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
temp := 0.7
|
|
req := textReq("o1", "hi")
|
|
req.Temperature = &temp
|
|
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
|
RestrictTemperature: func(m string) bool { return strings.HasPrefix(m, "o") },
|
|
})
|
|
if _, err := p.Complete(context.Background(), req); err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal(*body, &parsed); err != nil {
|
|
t.Fatalf("unmarshal: %v", err)
|
|
}
|
|
if _, ok := parsed["temperature"]; ok {
|
|
t.Errorf("temperature should be dropped for o1, got: %s", *body)
|
|
}
|
|
}
|
|
|
|
func TestComplete_SupportsVisionRejectsWhenFalse(t *testing.T) {
|
|
srv, _ := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
req := provider.Request{
|
|
Model: "deepseek-chat",
|
|
Messages: []provider.Message{{
|
|
Role: "user",
|
|
Content: "describe",
|
|
Images: []provider.Image{{URL: "https://example.com/a.png"}},
|
|
}},
|
|
}
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
|
SupportsVision: func(string) bool { return false },
|
|
})
|
|
_, err := p.Complete(context.Background(), req)
|
|
var fue *openaicompat.FeatureUnsupportedError
|
|
if !errors.As(err, &fue) {
|
|
t.Fatalf("want FeatureUnsupportedError, got %v", err)
|
|
}
|
|
if fue.Feature != "vision" || fue.Model != "deepseek-chat" {
|
|
t.Errorf("unexpected err: %+v", fue)
|
|
}
|
|
}
|
|
|
|
func TestComplete_SupportsToolsRejectsWhenFalse(t *testing.T) {
|
|
srv, _ := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
req := provider.Request{
|
|
Model: "deepseek-reasoner",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
Tools: []provider.ToolDef{
|
|
{Name: "get_weather", Description: "weather", Schema: map[string]any{"type": "object"}},
|
|
},
|
|
}
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
|
SupportsTools: func(m string) bool { return !strings.Contains(m, "reasoner") },
|
|
})
|
|
_, err := p.Complete(context.Background(), req)
|
|
var fue *openaicompat.FeatureUnsupportedError
|
|
if !errors.As(err, &fue) {
|
|
t.Fatalf("want FeatureUnsupportedError, got %v", err)
|
|
}
|
|
if fue.Feature != "tools" {
|
|
t.Errorf("feature = %q, want tools", fue.Feature)
|
|
}
|
|
}
|
|
|
|
func TestComplete_SupportsAudioRejectsWhenFalse(t *testing.T) {
|
|
srv, _ := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
req := provider.Request{
|
|
Model: "groq-llama",
|
|
Messages: []provider.Message{{
|
|
Role: "user",
|
|
Audio: []provider.Audio{{Base64: "AAA=", ContentType: "audio/wav"}},
|
|
}},
|
|
}
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
|
SupportsAudio: func(string) bool { return false },
|
|
})
|
|
_, err := p.Complete(context.Background(), req)
|
|
var fue *openaicompat.FeatureUnsupportedError
|
|
if !errors.As(err, &fue) {
|
|
t.Fatalf("want FeatureUnsupportedError, got %v", err)
|
|
}
|
|
if fue.Feature != "audio" {
|
|
t.Errorf("feature = %q, want audio", fue.Feature)
|
|
}
|
|
}
|
|
|
|
func TestComplete_MaxImagesPerMessage(t *testing.T) {
|
|
srv, _ := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
req := provider.Request{
|
|
Model: "anything",
|
|
Messages: []provider.Message{{
|
|
Role: "user",
|
|
Images: []provider.Image{
|
|
{URL: "a"}, {URL: "b"}, {URL: "c"},
|
|
},
|
|
}},
|
|
}
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{MaxImagesPerMessage: 2})
|
|
_, err := p.Complete(context.Background(), req)
|
|
if err == nil || !strings.Contains(err.Error(), "max allowed is 2") {
|
|
t.Fatalf("want max-images error, got %v", err)
|
|
}
|
|
|
|
// Exactly at limit succeeds.
|
|
req.Messages[0].Images = req.Messages[0].Images[:2]
|
|
if _, err := p.Complete(context.Background(), req); err != nil {
|
|
t.Errorf("at-limit request should succeed, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestComplete_CustomizeRequestInvoked(t *testing.T) {
|
|
srv, body := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
called := false
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
|
CustomizeRequest: func(params *openai.ChatCompletionNewParams) {
|
|
called = true
|
|
// Confirm we receive a non-empty built request.
|
|
if params.Model != "gpt-4o" {
|
|
t.Errorf("CustomizeRequest saw model %q, want gpt-4o", params.Model)
|
|
}
|
|
// Mutation here should end up on the wire.
|
|
params.User = openai.String("test-user")
|
|
},
|
|
})
|
|
if _, err := p.Complete(context.Background(), textReq("gpt-4o", "hi")); err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
if !called {
|
|
t.Fatal("CustomizeRequest hook was not invoked")
|
|
}
|
|
if !strings.Contains(string(*body), `"user":"test-user"`) {
|
|
t.Errorf("mutation from CustomizeRequest not reflected on wire: %s", *body)
|
|
}
|
|
}
|
|
|
|
func TestStream_EmitsDoneAndText(t *testing.T) {
|
|
// SSE stream with one content chunk then [DONE].
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
flusher, _ := w.(http.Flusher)
|
|
for _, line := range []string{
|
|
`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hel"}}]}`,
|
|
`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"lo"}}]}`,
|
|
`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
|
|
`data: [DONE]`,
|
|
} {
|
|
_, _ = io.WriteString(w, line+"\n\n")
|
|
if flusher != nil {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
|
|
events := make(chan provider.StreamEvent, 16)
|
|
go func() {
|
|
_ = p.Stream(context.Background(), textReq("gpt-4o", "hi"), events)
|
|
close(events)
|
|
}()
|
|
|
|
var text strings.Builder
|
|
var sawDone bool
|
|
var doneUsage *provider.Usage
|
|
for ev := range events {
|
|
switch ev.Type {
|
|
case provider.StreamEventText:
|
|
text.WriteString(ev.Text)
|
|
case provider.StreamEventDone:
|
|
sawDone = true
|
|
if ev.Response != nil {
|
|
doneUsage = ev.Response.Usage
|
|
}
|
|
}
|
|
}
|
|
if text.String() != "hello" {
|
|
t.Errorf("got text %q, want %q", text.String(), "hello")
|
|
}
|
|
if !sawDone {
|
|
t.Fatal("no Done event emitted")
|
|
}
|
|
if doneUsage == nil || doneUsage.TotalTokens != 3 {
|
|
t.Errorf("usage on Done = %+v, want TotalTokens=3", doneUsage)
|
|
}
|
|
}
|
|
|
|
func TestComplete_ReasoningEffortPassthrough(t *testing.T) {
|
|
srv, body := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
req := textReq("o3-mini", "hi")
|
|
req.Reasoning = "high"
|
|
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
|
|
if _, err := p.Complete(context.Background(), req); err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal(*body, &parsed); err != nil {
|
|
t.Fatalf("unmarshal: %v", err)
|
|
}
|
|
if parsed["reasoning_effort"] != "high" {
|
|
t.Errorf("reasoning_effort = %v, want \"high\"; body: %s", parsed["reasoning_effort"], *body)
|
|
}
|
|
}
|
|
|
|
func TestComplete_SupportsReasoningGate(t *testing.T) {
|
|
srv, body := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
req := textReq("gpt-4o", "hi")
|
|
req.Reasoning = "high"
|
|
|
|
// SupportsReasoning returns false → reasoning_effort must NOT be sent.
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
|
SupportsReasoning: func(string) bool { return false },
|
|
})
|
|
if _, err := p.Complete(context.Background(), req); err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal(*body, &parsed); err != nil {
|
|
t.Fatalf("unmarshal: %v", err)
|
|
}
|
|
if _, ok := parsed["reasoning_effort"]; ok {
|
|
t.Errorf("reasoning_effort should be absent when SupportsReasoning=false; body: %s", *body)
|
|
}
|
|
}
|
|
|
|
func TestComplete_MapReasoningEffort(t *testing.T) {
|
|
srv, body := newTestServer(t)
|
|
defer srv.Close()
|
|
|
|
req := textReq("grok-3-mini", "hi")
|
|
req.Reasoning = "medium"
|
|
|
|
// xAI-style mapping: medium → high.
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
|
MapReasoningEffort: func(level string) string {
|
|
if level == "medium" {
|
|
return "high"
|
|
}
|
|
return level
|
|
},
|
|
})
|
|
if _, err := p.Complete(context.Background(), req); err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
var parsed map[string]any
|
|
if err := json.Unmarshal(*body, &parsed); err != nil {
|
|
t.Fatalf("unmarshal: %v", err)
|
|
}
|
|
if parsed["reasoning_effort"] != "high" {
|
|
t.Errorf("reasoning_effort = %v, want \"high\" after medium→high remap; body: %s", parsed["reasoning_effort"], *body)
|
|
}
|
|
}
|
|
|
|
func TestComplete_ReasoningContentExtracted(t *testing.T) {
|
|
// Server returns a DeepSeek-style response with reasoning_content alongside content.
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = io.WriteString(w, `{
|
|
"id": "cmpl-1",
|
|
"object": "chat.completion",
|
|
"choices": [{
|
|
"index": 0,
|
|
"message": {
|
|
"role":"assistant",
|
|
"content":"42",
|
|
"reasoning_content":"the user asked for the answer..."
|
|
},
|
|
"finish_reason": "stop"
|
|
}],
|
|
"usage": {"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}
|
|
}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
|
|
resp, err := p.Complete(context.Background(), textReq("deepseek-reasoner", "?"))
|
|
if err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
if resp.Text != "42" {
|
|
t.Errorf("Text = %q, want %q", resp.Text, "42")
|
|
}
|
|
if !strings.Contains(resp.Thinking, "the user asked for") {
|
|
t.Errorf("Thinking = %q, want it to contain the reasoning trace", resp.Thinking)
|
|
}
|
|
}
|
|
|
|
func TestStream_ReasoningContentEmitsThinkingEvents(t *testing.T) {
|
|
// Two SSE chunks, each with a reasoning_content delta, then a final done chunk.
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
flusher, _ := w.(http.Flusher)
|
|
for _, line := range []string{
|
|
`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"reasoning_content":"think "}}]}`,
|
|
`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"reasoning_content":"hard","content":"42"}}]}`,
|
|
`data: {"id":"1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
|
|
`data: [DONE]`,
|
|
} {
|
|
_, _ = io.WriteString(w, line+"\n\n")
|
|
if flusher != nil {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{})
|
|
events := make(chan provider.StreamEvent, 32)
|
|
go func() {
|
|
_ = p.Stream(context.Background(), textReq("deepseek-reasoner", "?"), events)
|
|
close(events)
|
|
}()
|
|
|
|
var thinking strings.Builder
|
|
var sawDone bool
|
|
var doneThinking string
|
|
for ev := range events {
|
|
switch ev.Type {
|
|
case provider.StreamEventThinking:
|
|
thinking.WriteString(ev.Text)
|
|
case provider.StreamEventDone:
|
|
sawDone = true
|
|
if ev.Response != nil {
|
|
doneThinking = ev.Response.Thinking
|
|
}
|
|
}
|
|
}
|
|
if thinking.String() != "think hard" {
|
|
t.Errorf("streamed thinking = %q, want %q", thinking.String(), "think hard")
|
|
}
|
|
if !sawDone {
|
|
t.Fatal("no Done event")
|
|
}
|
|
if doneThinking != "think hard" {
|
|
t.Errorf("Done.Response.Thinking = %q, want %q", doneThinking, "think hard")
|
|
}
|
|
}
|
|
|
|
func TestStream_RulesCheckedBeforeNetwork(t *testing.T) {
|
|
// Server should never be hit when rules reject up front.
|
|
hit := false
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
hit = true
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := openaicompat.New("test-key", srv.URL, openaicompat.Rules{
|
|
SupportsVision: func(string) bool { return false },
|
|
})
|
|
req := provider.Request{
|
|
Model: "no-vision-model",
|
|
Messages: []provider.Message{{
|
|
Role: "user",
|
|
Images: []provider.Image{{URL: "a"}},
|
|
}},
|
|
}
|
|
events := make(chan provider.StreamEvent, 4)
|
|
err := p.Stream(context.Background(), req, events)
|
|
var fue *openaicompat.FeatureUnsupportedError
|
|
if !errors.As(err, &fue) {
|
|
t.Fatalf("want FeatureUnsupportedError, got %v", err)
|
|
}
|
|
if hit {
|
|
t.Error("server was contacted despite Rules violation")
|
|
}
|
|
}
|