go-llm/v2/openaicompat/openaicompat.go
steve cbaf41f50c
feat(v2): add ReasoningLevel option; thinking/reasoning across providers
Introduces an opt-in level-based reasoning toggle (low/medium/high) that
each provider translates to its native parameter:

- Anthropic: thinking.budget_tokens (1024/8000/24000), with temperature
  forced to default and MaxTokens auto-grown above the budget.
- OpenAI/xAI/Groq via openaicompat: reasoning_effort string, gated by a
  new Rules.SupportsReasoning predicate so non-reasoning models don't
  receive the parameter. xAI uses Rules.MapReasoningEffort to remap
  "medium" to "high" since its API only accepts low|high.
- Google: thinking_config.thinking_budget + include_thoughts:true.
- DeepSeek: SupportsReasoning=false (reasoner is always-on; the
  reasoning_content trace was already extracted via openaicompat).

Reasoning content is surfaced as Response.Thinking on Complete and as
StreamEventThinking deltas during streaming. Provider-side: extracted
from Anthropic thinking content blocks, Google's part.Thought=true
parts, and the non-standard reasoning_content field that DeepSeek and
Groq emit (parsed out of raw JSON since openai-go doesn't type it).

Public API:
  - llm.ReasoningLevel + ReasoningLow/Medium/High constants
  - llm.WithReasoning(level) request option
  - Model.WithReasoning(level) for baked-in defaults
  - provider.Request.Reasoning, provider.Response.Thinking
  - provider.StreamEventThinking

Tests cover Rules-based gating, MapReasoningEffort, reasoning_content
extraction (Complete + Stream), Anthropic budget mapping, and
temperature suppression when thinking is enabled. Existing behavior is
unchanged when Reasoning is the empty string.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-25 03:58:42 +00:00
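
The SupportsReasoning gating and MapReasoningEffort remapping described above can be sketched as a self-contained program. The types and model names here are illustrative stand-ins, not the real go-llm/openaicompat API:

```go
package main

import "fmt"

// rules is a minimal stand-in for the two reasoning-related Rules fields.
type rules struct {
	SupportsReasoning  func(model string) bool
	MapReasoningEffort func(level string) string
}

// effortFor mirrors the gating in buildRequest: drop reasoning entirely for
// models the predicate rejects, then remap the level for providers whose
// APIs accept a different set of effort strings.
func effortFor(r rules, model, level string) string {
	if level == "" {
		return ""
	}
	if r.SupportsReasoning != nil && !r.SupportsReasoning(model) {
		return ""
	}
	if r.MapReasoningEffort != nil {
		return r.MapReasoningEffort(level)
	}
	return level
}

func main() {
	// Hypothetical xAI-style rules: only one model reasons, and the API
	// accepts only "low"|"high", so "medium" is remapped to "high".
	xai := rules{
		SupportsReasoning: func(m string) bool { return m == "grok-mini" },
		MapReasoningEffort: func(l string) string {
			if l == "medium" {
				return "high"
			}
			return l
		},
	}
	fmt.Println(effortFor(xai, "grok-mini", "medium")) // high
	fmt.Println(effortFor(xai, "grok-full", "medium")) // "" (silently dropped)
}
```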

// Package openaicompat implements a shared chat-completion Provider for any
// service that speaks the OpenAI Chat Completions API (OpenAI itself, DeepSeek,
// Moonshot, xAI, Groq, Ollama, and friends).
//
// Most providers differ from vanilla OpenAI only in endpoint URL and a handful
// of per-model quirks (e.g., "this model is text-only", "this model doesn't
// accept tools", "drop temperature on reasoning models"). Those quirks are
// captured declaratively via Rules, so a concrete provider package is usually
// a one-function wrapper that calls New with its own base URL and Rules.
package openaicompat

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"path"
	"strings"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// Rules encodes provider-specific constraints on top of the OpenAI wire
// protocol. The zero value means "no restrictions" and behaves like vanilla
// OpenAI. Individual fields are documented inline.
type Rules struct {
	// MaxImagesPerMessage rejects any request in which a single message
	// carries more images than this cap. 0 means "no cap".
	MaxImagesPerMessage int

	// MaxAudioPerMessage rejects any request in which a single message
	// carries more audio attachments than this cap. 0 means "no cap".
	MaxAudioPerMessage int

	// SupportsVision, when non-nil, is consulted for every request that
	// includes any image attachments. If it returns false for the request's
	// model, the call fails with a FeatureUnsupportedError before hitting
	// the network.
	SupportsVision func(model string) bool

	// SupportsTools, when non-nil, is consulted for every request that
	// includes any tool definitions. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError before hitting the
	// network.
	SupportsTools func(model string) bool

	// SupportsAudio, when non-nil, is consulted for every request that
	// includes any audio attachments. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError.
	SupportsAudio func(model string) bool

	// RestrictTemperature, when non-nil and returning true for the request's
	// model, causes the Temperature field to be silently dropped from the
	// outgoing request. Used by the OpenAI o-series and gpt-5* models, which
	// reject a user-provided temperature.
	RestrictTemperature func(model string) bool

	// CustomizeRequest is a last-mile hook invoked after buildRequest but
	// before the call is sent. It receives the fully built OpenAI SDK
	// parameters and may mutate them freely (add headers, flip flags, tweak
	// response_format, etc.).
	CustomizeRequest func(params *openai.ChatCompletionNewParams)

	// SupportsReasoning, when non-nil and returning false for the request's
	// model, causes the request's Reasoning field to be silently dropped
	// from the outgoing request. Used by providers (e.g., OpenAI) whose APIs
	// reject reasoning_effort on non-reasoning models. nil means "always
	// pass reasoning_effort through when set".
	SupportsReasoning func(model string) bool

	// MapReasoningEffort, when non-nil, maps the standardized go-llm
	// ReasoningLevel ("low"|"medium"|"high") to the provider's wire-level
	// effort string. Used by xAI, whose API accepts only "low"|"high", so
	// its provider remaps "medium" to "high". nil means pass through
	// unchanged.
	MapReasoningEffort func(level string) string
}
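
// An illustrative Rules value for a hypothetical provider whose "mini"
// models are text-only and whose reasoning models reject a user temperature.
// The model-name checks below are made up, not any real provider's:
//
//	openaicompat.Rules{
//		SupportsVision:      func(m string) bool { return !strings.Contains(m, "mini") },
//		RestrictTemperature: func(m string) bool { return strings.HasPrefix(m, "reasoner") },
//		MapReasoningEffort: func(level string) string {
//			if level == "medium" {
//				return "high"
//			}
//			return level
//		},
//	}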

// FeatureUnsupportedError is returned when a Rules predicate rejects a request
// because the target model does not support a feature the caller included.
type FeatureUnsupportedError struct {
	Feature string
	Model   string
}

func (e *FeatureUnsupportedError) Error() string {
	return fmt.Sprintf("openaicompat: model %q does not support %s", e.Model, e.Feature)
}
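
// Callers can detect the rejection without string matching via errors.As
// (sketch):
//
//	var fe *openaicompat.FeatureUnsupportedError
//	if errors.As(err, &fe) {
//		// fe.Feature is "vision", "audio", or "tools"; fe.Model is the model name.
//	}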

// Provider implements provider.Provider for any OpenAI-compatible endpoint.
type Provider struct {
	apiKey  string
	baseURL string
	rules   Rules
}

// New creates a Provider. baseURL may be empty to let the OpenAI SDK use its
// default; in practice concrete provider packages always pass a default.
func New(apiKey, baseURL string, rules Rules) *Provider {
	return &Provider{apiKey: apiKey, baseURL: baseURL, rules: rules}
}

// Complete performs a non-streaming completion.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	if err := p.checkRules(req); err != nil {
		return provider.Response{}, err
	}

	cl := openai.NewClient(p.requestOptions()...)

	oaiReq := p.buildRequest(req)
	if p.rules.CustomizeRequest != nil {
		p.rules.CustomizeRequest(&oaiReq)
	}

	resp, err := cl.Chat.Completions.New(ctx, oaiReq)
	if err != nil {
		return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
	}
	return p.convertResponse(resp), nil
}

// Stream performs a streaming completion.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	if err := p.checkRules(req); err != nil {
		return err
	}

	cl := openai.NewClient(p.requestOptions()...)

	oaiReq := p.buildRequest(req)
	oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}
	if p.rules.CustomizeRequest != nil {
		p.rules.CustomizeRequest(&oaiReq)
	}

	stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)

	var fullText strings.Builder
	var fullThinking strings.Builder
	var toolCalls []provider.ToolCall
	toolCallArgs := map[int]*strings.Builder{}
	var usage *provider.Usage

	for stream.Next() {
		chunk := stream.Current()

		// Capture usage from the final chunk (present when
		// StreamOptions.IncludeUsage is true).
		if chunk.Usage.TotalTokens > 0 {
			usage = &provider.Usage{
				InputTokens:  int(chunk.Usage.PromptTokens),
				OutputTokens: int(chunk.Usage.CompletionTokens),
				TotalTokens:  int(chunk.Usage.TotalTokens),
				Details:      extractUsageDetails(chunk.Usage),
			}
		}

		for _, choice := range chunk.Choices {
			// Text delta.
			if choice.Delta.Content != "" {
				fullText.WriteString(choice.Delta.Content)
				events <- provider.StreamEvent{
					Type: provider.StreamEventText,
					Text: choice.Delta.Content,
				}
			}

			// Reasoning/thinking delta: DeepSeek and Groq use a non-standard
			// "reasoning_content" field on the delta. Extract it from the
			// raw JSON since the OpenAI SDK doesn't surface it as a typed
			// field.
			if rc := extractReasoningContent(choice.Delta.RawJSON()); rc != "" {
				fullThinking.WriteString(rc)
				events <- provider.StreamEvent{
					Type: provider.StreamEventThinking,
					Text: rc,
				}
			}

			// Tool call deltas.
			for _, tc := range choice.Delta.ToolCalls {
				idx := int(tc.Index)
				if tc.ID != "" {
					// A new tool call is starting.
					for len(toolCalls) <= idx {
						toolCalls = append(toolCalls, provider.ToolCall{})
					}
					toolCalls[idx].ID = tc.ID
					toolCalls[idx].Name = tc.Function.Name
					toolCallArgs[idx] = &strings.Builder{}
					events <- provider.StreamEvent{
						Type: provider.StreamEventToolStart,
						ToolCall: &provider.ToolCall{
							ID:   tc.ID,
							Name: tc.Function.Name,
						},
						ToolIndex: idx,
					}
				}
				if tc.Function.Arguments != "" {
					if b, ok := toolCallArgs[idx]; ok {
						b.WriteString(tc.Function.Arguments)
					}
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolDelta,
						ToolIndex: idx,
						ToolCall: &provider.ToolCall{
							Arguments: tc.Function.Arguments,
						},
					}
				}
			}
		}
	}
	if err := stream.Err(); err != nil {
		return fmt.Errorf("openai stream error: %w", err)
	}

	// Finalize tool calls.
	for idx := range toolCalls {
		if b, ok := toolCallArgs[idx]; ok {
			toolCalls[idx].Arguments = b.String()
		}
		events <- provider.StreamEvent{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: idx,
			ToolCall:  &toolCalls[idx],
		}
	}

	events <- provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			Thinking:  fullThinking.String(),
			ToolCalls: toolCalls,
			Usage:     usage,
		},
	}
	return nil
}

func (p *Provider) requestOptions() []option.RequestOption {
	opts := []option.RequestOption{option.WithAPIKey(p.apiKey)}
	if p.baseURL != "" {
		opts = append(opts, option.WithBaseURL(p.baseURL))
	}
	return opts
}

// checkRules applies all Rules predicates against a request and returns an
// error if any constraint is violated. Runs before any network call.
func (p *Provider) checkRules(req provider.Request) error {
	var hasImages, hasAudio bool
	for _, msg := range req.Messages {
		if len(msg.Images) > 0 {
			hasImages = true
		}
		if len(msg.Audio) > 0 {
			hasAudio = true
		}
		if p.rules.MaxImagesPerMessage > 0 && len(msg.Images) > p.rules.MaxImagesPerMessage {
			return fmt.Errorf("openaicompat: message has %d images, max allowed is %d for model %q",
				len(msg.Images), p.rules.MaxImagesPerMessage, req.Model)
		}
		if p.rules.MaxAudioPerMessage > 0 && len(msg.Audio) > p.rules.MaxAudioPerMessage {
			return fmt.Errorf("openaicompat: message has %d audio attachments, max allowed is %d for model %q",
				len(msg.Audio), p.rules.MaxAudioPerMessage, req.Model)
		}
	}
	if hasImages && p.rules.SupportsVision != nil && !p.rules.SupportsVision(req.Model) {
		return &FeatureUnsupportedError{Feature: "vision", Model: req.Model}
	}
	if hasAudio && p.rules.SupportsAudio != nil && !p.rules.SupportsAudio(req.Model) {
		return &FeatureUnsupportedError{Feature: "audio", Model: req.Model}
	}
	if len(req.Tools) > 0 && p.rules.SupportsTools != nil && !p.rules.SupportsTools(req.Model) {
		return &FeatureUnsupportedError{Feature: "tools", Model: req.Model}
	}
	return nil
}

func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams {
	oaiReq := openai.ChatCompletionNewParams{
		Model: req.Model,
	}

	for _, msg := range req.Messages {
		oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model))
	}

	for _, tool := range req.Tools {
		oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{
			Type: "function",
			Function: shared.FunctionDefinitionParam{
				Name:        tool.Name,
				Description: openai.String(tool.Description),
				Parameters:  openai.FunctionParameters(tool.Schema),
			},
		})
	}

	if req.Temperature != nil {
		if p.rules.RestrictTemperature == nil || !p.rules.RestrictTemperature(req.Model) {
			oaiReq.Temperature = openai.Float(*req.Temperature)
		}
	}
	if req.MaxTokens != nil {
		oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens))
	}
	if req.TopP != nil {
		oaiReq.TopP = openai.Float(*req.TopP)
	}
	if len(req.Stop) > 0 {
		oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])}
	}
	if req.Reasoning != "" {
		if p.rules.SupportsReasoning == nil || p.rules.SupportsReasoning(req.Model) {
			effort := req.Reasoning
			if p.rules.MapReasoningEffort != nil {
				effort = p.rules.MapReasoningEffort(effort)
			}
			oaiReq.ReasoningEffort = shared.ReasoningEffort(effort)
		}
	}
	return oaiReq
}

func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion {
	var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam
	var textContent param.Opt[string]

	for _, img := range msg.Images {
		var url string
		if img.Base64 != "" {
			url = "data:" + img.ContentType + ";base64," + img.Base64
		} else if img.URL != "" {
			url = img.URL
		}
		if url != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfImageURL: &openai.ChatCompletionContentPartImageParam{
						ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
							URL: url,
						},
					},
				},
			)
		}
	}

	for _, aud := range msg.Audio {
		var b64Data string
		var format string
		if aud.Base64 != "" {
			b64Data = aud.Base64
			format = audioFormat(aud.ContentType)
		} else if aud.URL != "" {
			resp, err := http.Get(aud.URL)
			if err != nil {
				continue
			}
			data, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				continue
			}
			b64Data = base64.StdEncoding.EncodeToString(data)
			ct := resp.Header.Get("Content-Type")
			if ct == "" {
				ct = aud.ContentType
			}
			if ct == "" {
				ct = audioFormatFromURL(aud.URL)
			}
			format = audioFormat(ct)
		}
		if b64Data != "" && format != "" {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{
						InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
							Data:   b64Data,
							Format: format,
						},
					},
				},
			)
		}
	}

	if msg.Content != "" {
		if len(arrayOfContentParts) > 0 {
			arrayOfContentParts = append(arrayOfContentParts,
				openai.ChatCompletionContentPartUnionParam{
					OfText: &openai.ChatCompletionContentPartTextParam{
						Text: msg.Content,
					},
				},
			)
		} else {
			textContent = openai.String(msg.Content)
		}
	}

	// Determine whether this model takes "developer" messages instead of
	// "system" (the OpenAI o-series).
	useDeveloper := false
	parts := strings.Split(model, "-")
	if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' {
		useDeveloper = true
	}

	switch msg.Role {
	case "system":
		if useDeveloper {
			return openai.ChatCompletionMessageParamUnion{
				OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{
					Content: openai.ChatCompletionDeveloperMessageParamContentUnion{
						OfString: textContent,
					},
				},
			}
		}
		return openai.ChatCompletionMessageParamUnion{
			OfSystem: &openai.ChatCompletionSystemMessageParam{
				Content: openai.ChatCompletionSystemMessageParamContentUnion{
					OfString: textContent,
				},
			},
		}
	case "user":
		return openai.ChatCompletionMessageParamUnion{
			OfUser: &openai.ChatCompletionUserMessageParam{
				Content: openai.ChatCompletionUserMessageParamContentUnion{
					OfString:              textContent,
					OfArrayOfContentParts: arrayOfContentParts,
				},
			},
		}
	case "assistant":
		as := &openai.ChatCompletionAssistantMessageParam{}
		if msg.Content != "" {
			as.Content.OfString = openai.String(msg.Content)
		}
		for _, tc := range msg.ToolCalls {
			as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{
				ID: tc.ID,
				Function: openai.ChatCompletionMessageToolCallFunctionParam{
					Name:      tc.Name,
					Arguments: tc.Arguments,
				},
			})
		}
		return openai.ChatCompletionMessageParamUnion{OfAssistant: as}
	case "tool":
		return openai.ChatCompletionMessageParamUnion{
			OfTool: &openai.ChatCompletionToolMessageParam{
				ToolCallID: msg.ToolCallID,
				Content: openai.ChatCompletionToolMessageParamContentUnion{
					OfString: openai.String(msg.Content),
				},
			},
		}
	}

	// Fall back to a user message for unknown roles.
	return openai.ChatCompletionMessageParamUnion{
		OfUser: &openai.ChatCompletionUserMessageParam{
			Content: openai.ChatCompletionUserMessageParamContentUnion{
				OfString: textContent,
			},
		},
	}
}

func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response {
	var res provider.Response
	if resp == nil || len(resp.Choices) == 0 {
		return res
	}

	choice := resp.Choices[0]
	res.Text = choice.Message.Content
	res.Thinking = extractReasoningContent(choice.Message.RawJSON())

	for _, tc := range choice.Message.ToolCalls {
		res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
			ID:        tc.ID,
			Name:      tc.Function.Name,
			Arguments: strings.TrimSpace(tc.Function.Arguments),
		})
	}

	if resp.Usage.TotalTokens > 0 {
		res.Usage = &provider.Usage{
			InputTokens:  int(resp.Usage.PromptTokens),
			OutputTokens: int(resp.Usage.CompletionTokens),
			TotalTokens:  int(resp.Usage.TotalTokens),
		}
		res.Usage.Details = extractUsageDetails(resp.Usage)
	}
	return res
}

// audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3").
func audioFormat(contentType string) string {
	ct := strings.ToLower(contentType)
	switch {
	case strings.Contains(ct, "wav"):
		return "wav"
	case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"):
		return "mp3"
	default:
		return "wav"
	}
}

// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage.
func extractUsageDetails(usage openai.CompletionUsage) map[string]int {
	details := map[string]int{}
	if usage.CompletionTokensDetails.ReasoningTokens > 0 {
		details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens)
	}
	if usage.CompletionTokensDetails.AudioTokens > 0 {
		details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens)
	}
	if usage.PromptTokensDetails.CachedTokens > 0 {
		details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens)
	}
	if usage.PromptTokensDetails.AudioTokens > 0 {
		details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens)
	}
	if len(details) == 0 {
		return nil
	}
	return details
}

// extractReasoningContent pulls the non-standard "reasoning_content" string
// from the raw JSON of a message or delta. DeepSeek's reasoner and several
// Groq-hosted reasoning models put their thinking trace in this field rather
// than in OpenAI's standard "reasoning_summary" blocks; the OpenAI Go SDK
// doesn't surface it as a typed field, so we re-parse the raw JSON. Returns
// empty string when the field is absent or unparseable.
func extractReasoningContent(rawJSON string) string {
	if rawJSON == "" || !strings.Contains(rawJSON, "reasoning_content") {
		return ""
	}
	var d struct {
		ReasoningContent string `json:"reasoning_content"`
	}
	if err := json.Unmarshal([]byte(rawJSON), &d); err != nil {
		return ""
	}
	return d.ReasoningContent
}
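
// For example, a DeepSeek-style delta such as
//
//	{"role":"assistant","content":"","reasoning_content":"Let me check..."}
//
// yields "Let me check...", while a payload without the field yields "".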

// audioFormatFromURL guesses the audio MIME type from a URL's file extension.
func audioFormatFromURL(u string) string {
	ext := strings.ToLower(path.Ext(u))
	switch ext {
	case ".mp3":
		return "audio/mp3"
	case ".wav":
		return "audio/wav"
	default:
		return "audio/wav"
	}
}