Add support for integers and tool configuration in schema handling
This update adds support for `jsonschema.Integer` types, extends the schema logic to handle nested items, and adds a default error log for unknown types via `slog.Error`. It also wires up tool configuration with a `FunctionCallingConfig` when `dontRequireTool` is false.
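Neither of those changes is visible in the `openai.go` diff below, which only covers the client migration; the schema and tool-configuration logic lives elsewhere in the repo. As a rough sketch of the described behavior (all types and names here are illustrative stand-ins, not the repository's actual code):

```go
package schema

import "log/slog"

// Node is an illustrative stand-in for the real jsonschema type.
type Node struct {
	Type  string // "string", "integer", "array", ...
	Items *Node  // nested items for array schemas
}

// convertType maps a JSON-schema type name to a provider type name.
func convertType(n *Node) string {
	switch n.Type {
	case "string":
		return "STRING"
	case "integer": // newly supported: jsonschema.Integer
		return "INTEGER"
	case "number":
		return "NUMBER"
	case "boolean":
		return "BOOLEAN"
	case "array":
		// Recurse into nested items instead of ignoring them.
		if n.Items != nil {
			_ = convertType(n.Items)
		}
		return "ARRAY"
	default:
		// New default branch: log unknown types instead of dropping them silently.
		slog.Error("unknown schema type", "type", n.Type)
		return ""
	}
}
```

`FunctionCallingConfig` is the shape used by Google's `github.com/google/generative-ai-go/genai` package, so the tool-configuration change presumably looks something like this (again a sketch, assuming that library):

```go
import "github.com/google/generative-ai-go/genai"

// requireToolCalls mirrors the OpenAI-side ToolChoice = "required"
// for Gemini models; sketch only, assuming the genai package's types.
func requireToolCalls(model *genai.GenerativeModel, dontRequireTool bool) {
	if !dontRequireTool {
		model.ToolConfig = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingAny, // the model must call some tool
			},
		}
	}
}
```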
openai.go (75 changed lines)
@@ -5,70 +5,69 @@ import (
 	"fmt"
 	"strings"

-	oai "github.com/sashabaranov/go-openai"
+	"github.com/openai/openai-go"
+	"github.com/openai/openai-go/option"
+	"github.com/openai/openai-go/shared"
 )

 type openaiImpl struct {
-	key   string
-	model string
+	key     string
+	model   string
+	baseUrl string
 }

 var _ LLM = openaiImpl{}

-func (o openaiImpl) newRequestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
-	res := oai.ChatCompletionRequest{
+func (o openaiImpl) newRequestToOpenAIRequest(request Request) openai.ChatCompletionNewParams {
+	res := openai.ChatCompletionNewParams{
 		Model: o.model,
 	}

 	for _, i := range request.Conversation {
-		res.Messages = append(res.Messages, i.toChatCompletionMessages()...)
+		res.Messages = append(res.Messages, i.toChatCompletionMessages(o.model)...)
 	}

 	for _, msg := range request.Messages {
-		res.Messages = append(res.Messages, msg.toChatCompletionMessages()...)
+		res.Messages = append(res.Messages, msg.toChatCompletionMessages(o.model)...)
 	}

 	if request.Toolbox != nil {
 		for _, tool := range request.Toolbox.funcs {
-			res.Tools = append(res.Tools, oai.Tool{
+			res.Tools = append(res.Tools, openai.ChatCompletionToolParam{
 				Type: "function",
-				Function: &oai.FunctionDefinition{
+				Function: shared.FunctionDefinitionParam{
 					Name:        tool.Name,
-					Description: tool.Description,
-					Strict:      tool.Strict,
-					Parameters:  tool.Parameters.Definition(),
+					Description: openai.String(tool.Description),
+					Strict:      openai.Bool(tool.Strict),
+					Parameters:  tool.Parameters.FunctionParameters(),
				},
			})
		}

 		if !request.Toolbox.dontRequireTool {
-			res.ToolChoice = "required"
+			res.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
+				OfAuto: openai.String("required"),
+			}
 		}
 	}

 	if request.Temperature != nil {
-		res.Temperature = *request.Temperature
-	}
-
-	// is this an o1-* model?
-	isO1 := strings.Split(o.model, "-")[0] == "o1"
-
-	if isO1 {
-		// o1 models do not support system messages, so if any messages are system messages, we need to convert them to
-		// user messages
-
-		for i, msg := range res.Messages {
-			if msg.Role == "system" {
-				res.Messages[i].Role = "user"
-			}
-		}
+		res.Temperature = openai.Float(*request.Temperature)
 	}

 	return res
 }

-func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
-	res := Response{}
+func (o openaiImpl) responseToLLMResponse(response *openai.ChatCompletion) Response {
+	var res Response
+
+	if response == nil {
+		return res
+	}

 	if len(response.Choices) == 0 {
 		return res
 	}

 	for _, choice := range response.Choices {
 		var toolCalls []ToolCall

@@ -77,7 +76,7 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
 				ID: call.ID,
 				FunctionCall: FunctionCall{
 					Name:      call.Function.Name,
-					Arguments: call.Function.Arguments,
+					Arguments: strings.TrimSpace(call.Function.Arguments),
 				},
 			}

@@ -87,7 +86,6 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
 		res.Choices = append(res.Choices, ResponseChoice{
 			Content: choice.Message.Content,
 			Role:    Role(choice.Message.Role),
-			Name:    choice.Message.Name,
 			Refusal: choice.Message.Refusal,
 			Calls:   toolCalls,
 		})

@@ -97,11 +95,20 @@ func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
 }

 func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
-	cl := oai.NewClient(o.key)
+	var opts = []option.RequestOption{
+		option.WithAPIKey(o.key),
+	}
+
+	if o.baseUrl != "" {
+		opts = append(opts, option.WithBaseURL(o.baseUrl))
+	}
+
+	cl := openai.NewClient(opts...)

 	req := o.newRequestToOpenAIRequest(request)

-	resp, err := cl.CreateChatCompletion(ctx, req)
+	resp, err := cl.Chat.Completions.New(ctx, req)
+	//resp, err := cl.CreateChatCompletion(ctx, req)

 	if err != nil {
 		return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
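Taken together, the migration in this file amounts to the following usage pattern for the new `github.com/openai/openai-go` client. This is a condensed, self-contained sketch of the API surface used above, not code from the repo (the model name, tool name, and schema are placeholders):

```go
package main

import (
	"context"
	"fmt"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

func main() {
	// Functional options replace the old positional API-key argument.
	client := openai.NewClient(
		option.WithAPIKey("sk-placeholder"),
		// option.WithBaseURL("https://example.com/v1"), // optional override
	)

	params := openai.ChatCompletionNewParams{
		Model: "gpt-4o", // placeholder model name
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage("What is the weather in Berlin?"),
		},
		Tools: []openai.ChatCompletionToolParam{{
			Function: shared.FunctionDefinitionParam{
				Name:        "get_weather", // placeholder tool
				Description: openai.String("Look up the current weather"),
				Strict:      openai.Bool(true),
				// FunctionParameters is a raw JSON-schema map.
				Parameters: shared.FunctionParameters{
					"type": "object",
					"properties": map[string]any{
						"city": map[string]any{"type": "string"},
					},
					"required": []string{"city"},
				},
			},
		}},
		// Equivalent of the old ToolChoice = "required".
		ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: openai.String("required"),
		},
	}

	resp, err := client.Chat.Completions.New(context.Background(), params)
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)
}
```

Note that optional scalar fields (`Description`, `Strict`, `Temperature`) now take wrapped values via `openai.String`, `openai.Bool`, and `openai.Float` rather than bare Go values, which is why the diff threads those helpers through.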