Steve Dudenhoeffer
0993a8e865
Modify `FunctionCall` struct to handle arguments as strings. Add debugging logs to facilitate error tracing and improve JSON unmarshalling in various functions.
149 lines
3.4 KiB
Go
149 lines
3.4 KiB
Go
package go_llm
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
oai "github.com/sashabaranov/go-openai"
|
|
"strings"
|
|
)
|
|
|
|
// openaiImpl is the OpenAI-backed implementation of LLM. key is the API key
// handed to the OpenAI client, and model is the model name placed in every
// chat completion request it builds.
type openaiImpl struct {
	key, model string
}
|
|
|
|
// Compile-time assertion that the openaiImpl value type (not just a pointer
// to it) implements LLM; the blank identifier keeps it free at runtime.
var (
	_ LLM = openaiImpl{}
)
|
|
|
|
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
|
res := oai.ChatCompletionRequest{
|
|
Model: o.model,
|
|
}
|
|
|
|
for _, msg := range request.Messages {
|
|
m := oai.ChatCompletionMessage{
|
|
Content: msg.Text,
|
|
Role: string(msg.Role),
|
|
Name: msg.Name,
|
|
}
|
|
|
|
for _, img := range msg.Images {
|
|
if img.Base64 != "" {
|
|
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
|
|
Type: "image_url",
|
|
ImageURL: &oai.ChatMessageImageURL{
|
|
URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64),
|
|
},
|
|
})
|
|
} else if img.Url != "" {
|
|
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
|
|
Type: "image_url",
|
|
ImageURL: &oai.ChatMessageImageURL{
|
|
URL: img.Url,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
// openai does not allow Content and MultiContent to be set at the same time, so we need to check
|
|
if len(m.MultiContent) > 0 && m.Content != "" {
|
|
m.MultiContent = append([]oai.ChatMessagePart{{
|
|
Type: "text",
|
|
Text: m.Content,
|
|
}}, m.MultiContent...)
|
|
|
|
m.Content = ""
|
|
}
|
|
|
|
res.Messages = append(res.Messages, m)
|
|
}
|
|
|
|
for _, tool := range request.Toolbox.funcs {
|
|
res.Tools = append(res.Tools, oai.Tool{
|
|
Type: "function",
|
|
Function: &oai.FunctionDefinition{
|
|
Name: tool.Name,
|
|
Description: tool.Description,
|
|
Strict: tool.Strict,
|
|
Parameters: tool.Parameters.Definition(),
|
|
},
|
|
})
|
|
|
|
fmt.Println("tool:", tool.Name, tool.Description, tool.Strict, tool.Parameters.Definition())
|
|
}
|
|
|
|
if request.Temperature != nil {
|
|
res.Temperature = *request.Temperature
|
|
}
|
|
|
|
// is this an o1-* model?
|
|
isO1 := strings.Split(o.model, "-")[0] == "o1"
|
|
|
|
if isO1 {
|
|
// o1 models do not support system messages, so if any messages are system messages, we need to convert them to
|
|
// user messages
|
|
|
|
for i, msg := range res.Messages {
|
|
if msg.Role == "system" {
|
|
res.Messages[i].Role = "user"
|
|
}
|
|
}
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
|
|
res := Response{}
|
|
|
|
for _, choice := range response.Choices {
|
|
var toolCalls []ToolCall
|
|
for _, call := range choice.Message.ToolCalls {
|
|
fmt.Println("responseToLLMResponse: call:", call.Function.Arguments)
|
|
toolCall := ToolCall{
|
|
ID: call.ID,
|
|
FunctionCall: FunctionCall{
|
|
Name: call.Function.Name,
|
|
Arguments: call.Function.Arguments,
|
|
},
|
|
}
|
|
|
|
fmt.Println("toolCall.FunctionCall.Arguments:", toolCall.FunctionCall.Arguments)
|
|
|
|
toolCalls = append(toolCalls, toolCall)
|
|
|
|
}
|
|
res.Choices = append(res.Choices, ResponseChoice{
|
|
Content: choice.Message.Content,
|
|
Role: Role(choice.Message.Role),
|
|
Name: choice.Message.Name,
|
|
Refusal: choice.Message.Refusal,
|
|
Calls: toolCalls,
|
|
})
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
|
cl := oai.NewClient(o.key)
|
|
|
|
req := o.requestToOpenAIRequest(request)
|
|
|
|
resp, err := cl.CreateChatCompletion(ctx, req)
|
|
|
|
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
|
|
|
|
if err != nil {
|
|
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
|
|
}
|
|
|
|
return o.responseToLLMResponse(resp), nil
|
|
}
|
|
|
|
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
|
return openaiImpl{
|
|
key: o.key,
|
|
model: modelVersion,
|
|
}, nil
|
|
}
|