// Package openai implements the go-llm v2 provider interface for OpenAI. package openai import ( "context" "encoding/base64" "fmt" "io" "net/http" "path" "strings" "github.com/openai/openai-go" "github.com/openai/openai-go/option" "github.com/openai/openai-go/packages/param" "github.com/openai/openai-go/shared" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" ) // Provider implements the provider.Provider interface for OpenAI. type Provider struct { apiKey string baseURL string } // New creates a new OpenAI provider. func New(apiKey string, baseURL string) *Provider { return &Provider{apiKey: apiKey, baseURL: baseURL} } // Complete performs a non-streaming completion. func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { var opts []option.RequestOption opts = append(opts, option.WithAPIKey(p.apiKey)) if p.baseURL != "" { opts = append(opts, option.WithBaseURL(p.baseURL)) } cl := openai.NewClient(opts...) oaiReq := p.buildRequest(req) resp, err := cl.Chat.Completions.New(ctx, oaiReq) if err != nil { return provider.Response{}, fmt.Errorf("openai completion error: %w", err) } return p.convertResponse(resp), nil } // Stream performs a streaming completion. func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { var opts []option.RequestOption opts = append(opts, option.WithAPIKey(p.apiKey)) if p.baseURL != "" { opts = append(opts, option.WithBaseURL(p.baseURL)) } cl := openai.NewClient(opts...) oaiReq := p.buildRequest(req) stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq) var fullText strings.Builder var toolCalls []provider.ToolCall toolCallArgs := map[int]*strings.Builder{} for stream.Next() { chunk := stream.Current() for _, choice := range chunk.Choices { // Text delta if choice.Delta.Content != "" { fullText.WriteString(choice.Delta.Content) events <- provider.StreamEvent{ Type: provider.StreamEventText, Text: choice.Delta.Content, } } // Tool call deltas for _, tc := range choice.Delta.ToolCalls { idx := int(tc.Index) if tc.ID != "" { // New tool call starting for len(toolCalls) <= idx { toolCalls = append(toolCalls, provider.ToolCall{}) } toolCalls[idx].ID = tc.ID toolCalls[idx].Name = tc.Function.Name toolCallArgs[idx] = &strings.Builder{} events <- provider.StreamEvent{ Type: provider.StreamEventToolStart, ToolCall: &provider.ToolCall{ ID: tc.ID, Name: tc.Function.Name, }, ToolIndex: idx, } } if tc.Function.Arguments != "" { if b, ok := toolCallArgs[idx]; ok { b.WriteString(tc.Function.Arguments) } events <- provider.StreamEvent{ Type: provider.StreamEventToolDelta, ToolIndex: idx, ToolCall: &provider.ToolCall{ Arguments: tc.Function.Arguments, }, } } } } } if err := stream.Err(); err != nil { return fmt.Errorf("openai stream error: %w", err) } // Finalize tool calls for idx := range toolCalls { if b, ok := toolCallArgs[idx]; ok { toolCalls[idx].Arguments = b.String() } events <- provider.StreamEvent{ Type: provider.StreamEventToolEnd, ToolIndex: idx, ToolCall: &toolCalls[idx], } } // Send done event events <- provider.StreamEvent{ Type: provider.StreamEventDone, Response: &provider.Response{ Text: fullText.String(), ToolCalls: toolCalls, }, } return nil } func (p *Provider) buildRequest(req provider.Request) openai.ChatCompletionNewParams { oaiReq := openai.ChatCompletionNewParams{ Model: req.Model, } for _, msg := range req.Messages { oaiReq.Messages = append(oaiReq.Messages, convertMessage(msg, req.Model)) } for _, tool := range req.Tools { oaiReq.Tools = append(oaiReq.Tools, openai.ChatCompletionToolParam{ Type: "function", Function: shared.FunctionDefinitionParam{ Name: tool.Name, Description: openai.String(tool.Description), Parameters: openai.FunctionParameters(tool.Schema), }, }) } if req.Temperature != nil { // o* and gpt-5* models don't support custom temperatures if !strings.HasPrefix(req.Model, "o") && !strings.HasPrefix(req.Model, "gpt-5") { oaiReq.Temperature = openai.Float(*req.Temperature) } } if req.MaxTokens != nil { oaiReq.MaxCompletionTokens = openai.Int(int64(*req.MaxTokens)) } if req.TopP != nil { oaiReq.TopP = openai.Float(*req.TopP) } if len(req.Stop) > 0 { oaiReq.Stop = openai.ChatCompletionNewParamsStopUnion{OfString: openai.String(req.Stop[0])} } return oaiReq } func convertMessage(msg provider.Message, model string) openai.ChatCompletionMessageParamUnion { var arrayOfContentParts []openai.ChatCompletionContentPartUnionParam var textContent param.Opt[string] for _, img := range msg.Images { var url string if img.Base64 != "" { url = "data:" + img.ContentType + ";base64," + img.Base64 } else if img.URL != "" { url = img.URL } if url != "" { arrayOfContentParts = append(arrayOfContentParts, openai.ChatCompletionContentPartUnionParam{ OfImageURL: &openai.ChatCompletionContentPartImageParam{ ImageURL: openai.ChatCompletionContentPartImageImageURLParam{ URL: url, }, }, }, ) } } for _, aud := range msg.Audio { var b64Data string var format string if aud.Base64 != "" { b64Data = aud.Base64 format = audioFormat(aud.ContentType) } else if aud.URL != "" { resp, err := http.Get(aud.URL) if err != nil { continue } data, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { continue } b64Data = base64.StdEncoding.EncodeToString(data) ct := resp.Header.Get("Content-Type") if ct == "" { ct = aud.ContentType } if ct == "" { ct = audioFormatFromURL(aud.URL) } format = audioFormat(ct) } if b64Data != "" && format != "" { arrayOfContentParts = append(arrayOfContentParts, openai.ChatCompletionContentPartUnionParam{ OfInputAudio: &openai.ChatCompletionContentPartInputAudioParam{ InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{ Data: b64Data, Format: format, }, }, }, ) } } if msg.Content != "" { if len(arrayOfContentParts) > 0 { arrayOfContentParts = append(arrayOfContentParts, openai.ChatCompletionContentPartUnionParam{ OfText: &openai.ChatCompletionContentPartTextParam{ Text: msg.Content, }, }, ) } else { textContent = openai.String(msg.Content) } } // Determine if this model uses developer messages instead of system useDeveloper := false parts := strings.Split(model, "-") if len(parts) > 1 && len(parts[0]) > 0 && parts[0][0] == 'o' { useDeveloper = true } switch msg.Role { case "system": if useDeveloper { return openai.ChatCompletionMessageParamUnion{ OfDeveloper: &openai.ChatCompletionDeveloperMessageParam{ Content: openai.ChatCompletionDeveloperMessageParamContentUnion{ OfString: textContent, }, }, } } return openai.ChatCompletionMessageParamUnion{ OfSystem: &openai.ChatCompletionSystemMessageParam{ Content: openai.ChatCompletionSystemMessageParamContentUnion{ OfString: textContent, }, }, } case "user": return openai.ChatCompletionMessageParamUnion{ OfUser: &openai.ChatCompletionUserMessageParam{ Content: openai.ChatCompletionUserMessageParamContentUnion{ OfString: textContent, OfArrayOfContentParts: arrayOfContentParts, }, }, } case "assistant": as := &openai.ChatCompletionAssistantMessageParam{} if msg.Content != "" { as.Content.OfString = openai.String(msg.Content) } for _, tc := range msg.ToolCalls { as.ToolCalls = append(as.ToolCalls, openai.ChatCompletionMessageToolCallParam{ ID: tc.ID, Function: openai.ChatCompletionMessageToolCallFunctionParam{ Name: tc.Name, Arguments: tc.Arguments, }, }) } return openai.ChatCompletionMessageParamUnion{OfAssistant: as} case "tool": return openai.ChatCompletionMessageParamUnion{ OfTool: &openai.ChatCompletionToolMessageParam{ ToolCallID: msg.ToolCallID, Content: openai.ChatCompletionToolMessageParamContentUnion{ OfString: openai.String(msg.Content), }, }, } } // Fallback to user message return openai.ChatCompletionMessageParamUnion{ OfUser: &openai.ChatCompletionUserMessageParam{ Content: openai.ChatCompletionUserMessageParamContentUnion{ OfString: textContent, }, }, } } func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Response { var res provider.Response if resp == nil || len(resp.Choices) == 0 { return res } choice := resp.Choices[0] res.Text = choice.Message.Content for _, tc := range choice.Message.ToolCalls { res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ ID: tc.ID, Name: tc.Function.Name, Arguments: strings.TrimSpace(tc.Function.Arguments), }) } if resp.Usage.TotalTokens > 0 { res.Usage = &provider.Usage{ InputTokens: int(resp.Usage.PromptTokens), OutputTokens: int(resp.Usage.CompletionTokens), TotalTokens: int(resp.Usage.TotalTokens), } } return res } // audioFormat converts a MIME type to an OpenAI audio format string ("wav" or "mp3"). func audioFormat(contentType string) string { ct := strings.ToLower(contentType) switch { case strings.Contains(ct, "wav"): return "wav" case strings.Contains(ct, "mp3"), strings.Contains(ct, "mpeg"): return "mp3" default: return "wav" } } // audioFormatFromURL guesses the audio format from a URL's file extension. func audioFormatFromURL(u string) string { ext := strings.ToLower(path.Ext(u)) switch ext { case ".mp3": return "audio/mp3" case ".wav": return "audio/wav" default: return "audio/wav" } }