go-llm/openai.go
Steve Dudenhoeffer 7c9eb08cb4 Add support for integers and tool configuration in schema handling
This update introduces support for `jsonschema.Integer` types and updates the schema logic to handle nested items, adds a default error log for unknown types via `slog.Error`, and integrates tool configuration with a `FunctionCallingConfig` when `dontRequireTool` is false.
2025-04-06 01:23:10 -04:00

package go_llm

import (
	"context"
	"fmt"
	"strings"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)
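
// openaiImpl is the OpenAI-backed implementation of the LLM interface,
// holding the API key, model name, and an optional custom base URL.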
type openaiImpl struct {
	key     string
	model   string
	baseUrl string
}

var _ LLM = openaiImpl{}
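
// newRequestToOpenAIRequest translates a go-llm Request (conversation,
// messages, toolbox, and temperature) into the openai-go
// ChatCompletionNewParams expected by the client.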
func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams {
	res := openai.ChatCompletionNewParams{
		Model: o.model,
	}

	for _, i := range request.Conversation {
		res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...)
	}

	for _, msg := range request.Messages {
		res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
	}

	if request.Toolbox != nil {
		for _, tool := range request.Toolbox.funcs {
			res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
				Type: "function",
				Function: shared.FunctionDefinitionParam{
					Name:        tool.Name,
					Description: openai.String(tool.Description),
					Strict:      openai.Bool(tool.Strict),
					Parameters:  tool.Parameters.FunctionParameters(),
				},
			})
		}

		if !request.Toolbox.dontRequireTool {
			res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
				OfAuto: openai.String("required"),
			}
		}
	}

	if request.Temperature != nil {
		res.Temperature = openai.Float(*request.Temperature)
	}

	return res
}
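
// responseToLLMResponse maps an openai-go ChatCompletion back onto the
// package's Response type, copying content, role, refusal, and tool calls.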
func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response {
	var res Response

	if response == nil {
		return res
	}

	if len(response.Choices) == 0 {
		return res
	}

	for _, choice := range response.Choices {
		var toolCalls []ToolCall

		for _, call := range choice.Message.ToolCalls {
			toolCall := ToolCall{
				ID: call.ID,
				FunctionCall: FunctionCall{
					Name:      call.Function.Name,
					Arguments: strings.TrimSpace(call.Function.Arguments),
				},
			}

			toolCalls = append(toolCalls, toolCall)
		}

		res.Choices = append(res.Choices, ResponseChoice{
			Content: choice.Message.Content,
			Role:    Role(choice.Message.Role),
			Refusal: choice.Message.Refusal,
			Calls:   toolCalls,
		})
	}

	return res
}
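
// ChatComplete sends the request to the OpenAI chat completions endpoint,
// using a custom base URL when one is configured.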
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
	var opts = []option.RequestOption{
		option.WithAPIKey(o.key),
	}

	if o.baseUrl != "" {
		opts = append(opts, option.WithBaseURL(o.baseUrl))
	}

	cl := openai.NewClient(opts...)
	req := o.newRequestToOpenAIRequest(request)

	resp, err := cl.Chat.Completions.New(ctx, req)
	//resp, err := cl.CreateChatCompletion(ctx, req)
	if err != nil {
		return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
	}

	return o.responseToLLMResponse(resp), nil
}
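
// ModelVersion returns a copy of the implementation configured to use the
// given model version.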
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
	return openaiImpl{
		key:     o.key,
		model:   modelVersion,
		baseUrl: o.baseUrl, // preserve any custom base URL when switching models
	}, nil
}
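
For orientation, a minimal usage sketch of this implementation from inside package go_llm (for example in a test file): the openaiImpl literal is constructed directly because its fields are unexported, the model name is only an example, and message construction is elided since the Message type's fields are not shown in this file.

// usage_sketch_test.go (hypothetical) — illustrative only.
package go_llm

import (
	"context"
	"os"
	"testing"
)

func TestChatCompleteSketch(t *testing.T) {
	llm := openaiImpl{
		key:   os.Getenv("OPENAI_API_KEY"),
		model: "gpt-4o-mini", // example model name
	}

	var req Request
	// req.Messages = append(req.Messages, ...) // build messages here before calling ChatComplete

	resp, err := llm.ChatComplete(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}

	if len(resp.Choices) > 0 {
		t.Log(resp.Choices[0].Content)
	}
}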