Add support for integers and tool configuration in schema handling

This update introduces support for `jsonschema.Integer` types and updates the logic to handle nested items in schemas. Added a new default error log for unknown types using `slog.Error`. Also, integrated tool configuration with a `FunctionCallingConfig` when `dontRequireTool` is false.
This commit is contained in:
2025-04-06 01:23:10 -04:00
parent ff5e4ca7b0
commit 7c9eb08cb4
13 changed files with 267 additions and 96 deletions

View File

@@ -5,70 +5,69 @@ import (
"fmt"
"strings"
oai "github.com/sashabaranov/go-openai"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/shared"
)
type openaiImpl struct {
key string
model string
key string
model string
baseUrl string
}
var _ LLM = openaiImpl{}
func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
res := oai.ChatCompletionRequest{
func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams {
res := openai.ChatCompletionNewParams{
Model: o.model,
}
for _, i := range request.Conversation {
res.Messages = append(res.Messages, i.toChatCompletionMessages()...)
res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...)
}
for _, msg := range request.Messages {
res.Messages = append(res.Messages, msg.toChatCompletionMessages()...)
res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
}
if request.Toolbox != nil {
for _, tool := range request.Toolbox.funcs {
res.Tools = append(res.Tools, oai.Tool{
res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
Type: "function",
Function: &oai.FunctionDefinition{
Function: shared.FunctionDefinitionParam{
Name: tool.Name,
Description: tool.Description,
Strict: tool.Strict,
Parameters: tool.Parameters.Definition(),
Description: openai.String(tool.Description),
Strict: openai.Bool(tool.Strict),
Parameters: tool.Parameters.FunctionParameters(),
},
})
}
if !request.Toolbox.dontRequireTool {
res.ToolChoice = "required"
res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: openai.String("required"),
}
}
}
if request.Temperature != nil {
res.Temperature = *request.Temperature
}
// is this an o1-* model?
isO1 := strings.Split(o.model, "-")[0] == "o1"
if isO1 {
// o1 models do not support system messages, so if any messages are system messages, we need to convert them to
// user messages
for i, msg := range res.Messages {
if msg.Role == "system" {
res.Messages[i].Role = "user"
}
}
res.Temperature = openai.Float(*request.Temperature)
}
return res
}
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
res := Response{}
func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response {
var res Response
if response == nil {
return res
}
if len(response.Choices) == 0 {
return res
}
for _, choice := range response.Choices {
var toolCalls []ToolCall
@@ -77,7 +76,7 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
ID: call.ID,
FunctionCall: FunctionCall{
Name: call.Function.Name,
Arguments: call.Function.Arguments,
Arguments: strings.TrimSpace(call.Function.Arguments),
},
}
@@ -87,7 +86,6 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
res.Choices = append(res.Choices, ResponseChoice{
Content: choice.Message.Content,
Role: Role(choice.Message.Role),
Name: choice.Message.Name,
Refusal: choice.Message.Refusal,
Calls: toolCalls,
})
@@ -97,11 +95,20 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) R
}
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
cl := oai.NewClient(o.key)
var opts = []option.RequestOption{
option.WithAPIKey(o.key),
}
if o.baseUrl != "" {
opts = append(opts, option.WithBaseURL(o.baseUrl))
}
cl := openai.NewClient(opts...)
req := o.newRequestToOpenAIRequest(request)
resp, err := cl.CreateChatCompletion(ctx, req)
resp, err := cl.Chat.Completions.New(ctx, req)
//resp, err := cl.CreateChatCompletion(ctx, req)
if err != nil {
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)