// Package openaicompat implements a shared chat-completion Provider for any
// service that speaks the OpenAI Chat Completions API (OpenAI itself, DeepSeek,
// Moonshot, xAI, Groq, Ollama, and friends).
//
// Most providers differ from vanilla OpenAI only in endpoint URL and a handful
// of per-model quirks (e.g., "this model is text-only", "this model doesn't
// accept tools", "drop temperature on reasoning models"). Those quirks are
// captured declaratively via Rules, so a concrete provider package is usually
// a one-function wrapper that calls New with its own base URL and Rules.
package openaicompat

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"path"
	"strings"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// Rules encodes provider-specific constraints on top of the OpenAI wire
// protocol. The zero value means "no restrictions" and behaves like vanilla
// OpenAI. Individual fields are documented inline.
type Rules struct {
	// MaxImagesPerMessage rejects requests in which any single message
	// carries more images than this cap. 0 means "no cap".
	MaxImagesPerMessage int

	// MaxAudioPerMessage rejects requests in which any single message
	// carries more audio attachments than this cap. 0 means "no cap".
	MaxAudioPerMessage int

	// SupportsVision, when non-nil, is consulted for every request that
	// includes any image attachments. If it returns false for the request's
	// model, the call fails with a FeatureUnsupportedError before hitting
	// the network.
	SupportsVision func(model string) bool

	// SupportsTools, when non-nil, is consulted for every request that
	// includes any tool definitions. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError before hitting the
	// network.
	SupportsTools func(model string) bool

	// SupportsAudio, when non-nil, is consulted for every request that
	// includes any audio attachments. If it returns false for the model,
	// the call fails with a FeatureUnsupportedError.
	SupportsAudio func(model string) bool

	// RestrictTemperature, when non-nil and returning true for the request's
	// model, causes the Temperature field to be silently dropped from the
	// outgoing request. Used by OpenAI o-series and gpt-5* which reject a
	// user-provided temperature.
	RestrictTemperature func(model string) bool

	// CustomizeRequest is a last-mile hook invoked after buildRequest but
	// before the call is sent. It receives the fully built OpenAI SDK
	// parameters and may mutate them freely (add headers, flip flags, tweak
	// response_format, etc.).
	CustomizeRequest func(params *openai.ChatCompletionNewParams)

	// SupportsReasoning, when non-nil and returning false for the request's
	// model, causes the request's Reasoning field to be silently dropped
	// from the outgoing request. Used by providers (e.g., OpenAI) where
	// reasoning_effort is rejected on non-reasoning models. nil = always
	// pass reasoning_effort through when set.
	SupportsReasoning func(model string) bool

	// MapReasoningEffort, when non-nil, maps the standardized go-llm
	// ReasoningLevel ("low"|"medium"|"high") to the provider's wire-level
	// effort string. Used by xAI which only accepts "low"|"high" (callers
	// remap "medium" to "high"). nil = pass-through unchanged.
	MapReasoningEffort func(level string) string
}

// FeatureUnsupportedError is returned when a Rules predicate rejects a request
// because the target model does not support a feature the caller included.
type FeatureUnsupportedError struct {
	Feature string // feature name, e.g. "vision", "audio", "tools"
	Model   string // model the request targeted
}

// Error implements the error interface.
func (e *FeatureUnsupportedError) Error() string {
	return fmt.Sprintf("openaicompat: model %q does not support %s", e.Model, e.Feature)
}

// Provider implements provider.Provider for any OpenAI-compatible endpoint.
type Provider struct {
	apiKey  string // bearer token, passed to the SDK via option.WithAPIKey
	baseURL string // endpoint override; empty lets the SDK use its default
	rules   Rules  // per-provider quirks; zero value = vanilla OpenAI
}

// New creates a Provider. baseURL may be empty to let the OpenAI SDK use its
// default; in practice concrete provider packages always pass a default.
func New(apiKey, baseURL string, rules Rules) *Provider {
	return &Provider{apiKey: apiKey, baseURL: baseURL, rules: rules}
}

// Complete performs a non-streaming completion.
//
// Rules are evaluated locally first (checkRules), so unsupported-feature
// errors are returned before any network traffic. CustomizeRequest, when
// set, is the last hook to run before the request is sent.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	if err := p.checkRules(req); err != nil {
		return provider.Response{}, err
	}
	cl := openai.NewClient(p.requestOptions()...)
	oaiReq := p.buildRequest(req)
	if p.rules.CustomizeRequest != nil {
		p.rules.CustomizeRequest(&oaiReq)
	}
	resp, err := cl.Chat.Completions.New(ctx, oaiReq)
	if err != nil {
		return provider.Response{}, fmt.Errorf("openai completion error: %w", err)
	}
	return p.convertResponse(resp), nil
}

// Stream performs a streaming completion, emitting events on the caller's
// channel as deltas arrive. Event order per stream: zero or more
// Text/Thinking/ToolStart/ToolDelta events, then one ToolEnd per tool call,
// then exactly one Done event carrying the fully assembled Response.
//
// NOTE(review): sends on events are unguarded; presumably the caller drains
// the channel until Done — a stalled consumer would block this goroutine.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	if err := p.checkRules(req); err != nil {
		return err
	}
	cl := openai.NewClient(p.requestOptions()...)
	oaiReq := p.buildRequest(req)
	// Ask the provider to append a final usage-only chunk to the stream.
	oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}
	if p.rules.CustomizeRequest != nil {
		p.rules.CustomizeRequest(&oaiReq)
	}
	stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)

	// Accumulators for the final Done event.
	var fullText strings.Builder
	var fullThinking strings.Builder
	var toolCalls []provider.ToolCall
	toolCallArgs := map[int]*strings.Builder{} // per-tool-call argument fragments, keyed by delta index
	var usage *provider.Usage

	for stream.Next() {
		chunk := stream.Current()
		// Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true)
		if chunk.Usage.TotalTokens > 0 {
			usage = &provider.Usage{
				InputTokens:  int(chunk.Usage.PromptTokens),
				OutputTokens: int(chunk.Usage.CompletionTokens),
				TotalTokens:  int(chunk.Usage.TotalTokens),
				Details:      extractUsageDetails(chunk.Usage),
			}
		}
		for _, choice := range chunk.Choices {
			// Text delta
			if choice.Delta.Content != "" {
				fullText.WriteString(choice.Delta.Content)
				events <- provider.StreamEvent{
					Type: provider.StreamEventText,
					Text: choice.Delta.Content,
				}
			}
			// Reasoning/thinking delta — DeepSeek and Groq use a non-standard
			// "reasoning_content" field on the delta. Extract it from the
			// raw JSON since the OpenAI SDK doesn't surface it as a typed
			// field.
			if rc := extractReasoningContent(choice.Delta.RawJSON()); rc != "" {
				fullThinking.WriteString(rc)
				events <- provider.StreamEvent{
					Type: provider.StreamEventThinking,
					Text: rc,
				}
			}
			// Tool call deltas
			for _, tc := range choice.Delta.ToolCalls {
				idx := int(tc.Index)
				if tc.ID != "" {
					// New tool call starting: grow the slice so idx is
					// addressable, then record identity and reset the
					// argument accumulator for this index.
					for len(toolCalls) <= idx {
						toolCalls = append(toolCalls, provider.ToolCall{})
					}
					toolCalls[idx].ID = tc.ID
					toolCalls[idx].Name = tc.Function.Name
					toolCallArgs[idx] = &strings.Builder{}
					events <- provider.StreamEvent{
						Type: provider.StreamEventToolStart,
						ToolCall: &provider.ToolCall{
							ID:   tc.ID,
							Name: tc.Function.Name,
						},
						ToolIndex: idx,
					}
				}
				if tc.Function.Arguments != "" {
					// Accumulate argument JSON fragments; the full string is
					// attached to the call at ToolEnd. Fragments arriving
					// before the ID chunk (no builder yet) are still
					// forwarded as deltas but not accumulated.
					if b, ok := toolCallArgs[idx]; ok {
						b.WriteString(tc.Function.Arguments)
					}
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolDelta,
						ToolIndex: idx,
						ToolCall: &provider.ToolCall{
							Arguments: tc.Function.Arguments,
						},
					}
				}
			}
		}
	}
	if err := stream.Err(); err != nil {
		return fmt.Errorf("openai stream error: %w", err)
	}
	// Finalize tool calls
	for idx := range toolCalls {
		if b, ok := toolCallArgs[idx]; ok {
			toolCalls[idx].Arguments = b.String()
		}
		events <- provider.StreamEvent{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: idx,
			ToolCall:  &toolCalls[idx],
		}
	}
	events <- provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			Thinking:  fullThinking.String(),
			ToolCalls: toolCalls,
			Usage:     usage,
		},
	}
	return nil
}

// requestOptions builds the SDK client options: API key always, base URL
// only when this Provider was constructed with a non-empty override.
func (p *Provider) requestOptions() []option.RequestOption {
	opts := []option.RequestOption{option.WithAPIKey(p.apiKey)}
	if p.baseURL != "" {
		opts = append(opts, option.WithBaseURL(p.baseURL))
	}
	return opts
}

// checkRules applies all Rules predicates against a request and returns an
// error if any constraint is violated. Runs before any network call.
func (p *Provider) checkRules(req provider.Request) error { var hasImages, hasAudio bool for _, msg := range req.Messages { if len(msg.Images) > 0 { hasImages = true } if len(msg.Audio) > 0 { hasAudio = true } if p.rules.MaxImagesPerMessage > 0 && len(msg.Images) > p.rules.MaxImagesPerMessage { return fmt.Errorf("openaicompat: message has %d images, max allowed is %d for model %q", len(msg.Images), p.rules.MaxImagesPerMessage, req.Model) } if p.rules.MaxAudioPerMessage > 0 && len(msg.Audio) > p.rules.MaxAudioPerMessage { return fmt.Errorf("openaicompat: message has %d audio attachments, max allowed is %d for model %q", len(msg.Audio), p.rules.MaxAudioPerMessage, req.Model) } } if hasImages && p.rules.SupportsVision != nil && !p.rules.SupportsVision(req.Model) { return &FeatureUnsupportedError{Feature: "vision", Model: req.Model} } if hasAudio && p.rules.SupportsAudio != nil && !p.rules.SupportsAudio(req.Model) { return &FeatureUnsupportedError{Feature: "audio", Model: req.Model} } if len(req.Tools) > 0 && p.rules.SupportsTools != nil && !p.rules.SupportsTools(req.Model) { return &FeatureUnsupportedError{Feature: "tools", Model: req.Model} } return nil } func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams { oaiReq := openai.ChatCompletionNewParams{ Model: req.Model, } for _, msg := range req.Messages { oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model)) } for _, tool := range req.Tools { oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{ Type: "function", Function: shared.FunctionDefinitionParam{ Name: tool.Name, Description: openai.String(tool.Description), Parameters: openai.FunctionParameters(tool.Schema), }, }) } if req.Temperature != nil { if p.rules.RestrictTemperature == nil || !p.rules.RestrictTemperature(req.Model) { oaiReq.Temperature = openai.Float(*req.Temperature) } } if req.MaxTokens != nil { oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens)) } if req.TopP != nil { 
oaiReq.TopP = openai.Float(*req.TopP) } if len(req.Stop) > 0 { oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])} } if req.Reasoning != "" { if p.rules.SupportsReasoning == nil || p.rules.SupportsReasoning(req.Model) { effort := req.Reasoning if p.rules.MapReasoningEffort != nil { effort = p.rules.MapReasoningEffort(effort) } oaiReq.ReasoningEffort = shared.ReasoningEffort(effort) } } return oaiReq } func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion { var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam var textContent param.Opt[string] for _, img := range msg.Images { var url string if img.Base64 != "" { url = "data:" + img.ContentType + ";base64," + img.Base64 } else if img.URL != "" { url = img.URL } if url != "" { arrayOfContentParts = append(arrayOfContentParts, openai.ChatCompletionContentPartUnionParam{ OfImageURL: &openai.ChatCompletionContentPartImageParam{ ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ URL: url, }, }, }, ) } } for _, aud := range msg.Audio { var b64Data string var format string if aud.Base64 != "" { b64Data = aud.Base64 format = audioFormat(aud.ContentType) } else if aud.URL != "" { resp, err := http.Get(aud.URL) if err != nil { continue } data, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { continue } b64Data = base64.StdEncoding.EncodeToString(data) ct := resp.Header.Get("Content-Type") if ct == "" { ct = aud.ContentType } if ct == "" { ct = audioFormatFromURL(aud.URL) } format = audioFormat(ct) } if b64Data != "" && format != "" { arrayOfContentParts = append(arrayOfContentParts, openai.ChatCompletionContentPartUnionParam{ OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{ InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{ Data: b64Data, Format: format, }, }, }, ) } } if msg.Content != "" { if len(arrayOfContentParts) > 0 { arrayOfContentParts = append(arrayOfContentParts, 
openai.ChatCompletionContentPartUnionParam{ OfText: &openai.ChatCompletionContentPartTextParam{ Text: msg.Content, }, }, ) } else { textContent = openai.String(msg.Content) } } // Determine if this model uses developer messages instead of system useDeveloper := false parts := strings.Split(model, "-") if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' { useDeveloper = true } switch msg.Role { case "system": if useDeveloper { return openai.ChatCompletionMessageParamUnion{ OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ OfString: textContent, }, }, } } return openai.ChatCompletionMessageParamUnion{ OfSystem: &openai.ChatCompletionSystemMessageParam{ Content: openai.ChatCompletionSystemMessageParamContentUnion{ OfString: textContent, }, }, } case "user": return openai.ChatCompletionMessageParamUnion{ OfUser: &openai.ChatCompletionUserMessageParam{ Content: openai.ChatCompletionUserMessageParamContentUnion{ OfString: textContent, OfArrayOfContentParts: arrayOfContentParts, }, }, } case "assistant": as := &openai.ChatCompletionAssistantMessageParam{} if msg.Content != "" { as.Content.OfString = openai.String(msg.Content) } for _, tc := range msg.ToolCalls { as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ ID: tc.ID, Function: openai.ChatCompletionMessageToolCallFunctionParam{ Name: tc.Name, Arguments: tc.Arguments, }, }) } return openai.ChatCompletionMessageParamUnion{OfAssistant: as} case "tool": return openai.ChatCompletionMessageParamUnion{ OfTool: &openai.ChatCompletionToolMessageParam{ ToolCallID: msg.ToolCallID, Content: openai.ChatCompletionToolMessageParamContentUnion{ OfString: openai.String(msg.Content), }, }, } } // Fallback to user message return openai.ChatCompletionMessageParamUnion{ OfUser: &openai.ChatCompletionUserMessageParam{ Content: openai.ChatCompletionUserMessageParamContentUnion{ OfString: textContent, }, }, } } func (p *Provider) 
convertResponse(resp *openai.ChatCompletion) provider.Response { var res provider.Response if resp == nil || len(resp.Choices) == 0 { return res } choice := resp.Choices[0] res.Text = choice.Message.Content res.Thinking = extractReasoningContent(choice.Message.RawJSON()) for _, tc := range choice.Message.ToolCalls { res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ ID: tc.ID, Name: tc.Function.Name, Arguments: strings.TrimSpace(tc.Function.Arguments), }) } if resp.Usage.TotalTokens > 0 { res.Usage = &provider.Usage{ InputTokens: int(resp.Usage.PromptTokens), OutputTokens: int(resp.Usage.CompletionTokens), TotalTokens: int(resp.Usage.TotalTokens), } res.Usage.Details = extractUsageDetails(resp.Usage) } return res } // audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3"). func audioFormat(contentType string) string { ct := strings.ToLower(contentType) switch { case strings.Contains(ct, "wav"): return "wav" case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"): return "mp3" default: return "wav" } } // extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage. 
func extractUsageDetails(usage openai.CompletionUsage) map[string]int { details := map[string]int{} if usage.CompletionTokensDetails.ReasoningTokens > 0 { details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens) } if usage.CompletionTokensDetails.AudioTokens > 0 { details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens) } if usage.PromptTokensDetails.CachedTokens > 0 { details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens) } if usage.PromptTokensDetails.AudioTokens > 0 { details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens) } if len(details) == 0 { return nil } return details } // extractReasoningContent pulls the non-standard "reasoning_content" string // from the raw JSON of a message or delta. DeepSeek's reasoner and several // Groq-hosted reasoning models put their thinking trace in this field rather // than in OpenAI's standard "reasoning_summary" blocks; the OpenAI Go SDK // doesn't surface it as a typed field, so we re-parse the raw JSON. Returns // empty string when the field is absent or unparseable. func extractReasoningContent(rawJSON string) string { if rawJSON == "" || !strings.Contains(rawJSON, "reasoning_content") { return "" } var d struct { ReasoningContent string `json:"reasoning_content"` } if err := json.Unmarshal([]byte(rawJSON), &d); err != nil { return "" } return d.ReasoningContent } // audioFormatFromURL guesses the audio format from a URL's file extension. func audioFormatFromURL(u string) string { ext := strings.ToLower(path.Ext(u)) switch ext { case ".mp3": return "audio/mp3" case ".wav": return "audio/wav" default: return "audio/wav" } }