2024-10-06 21:02:26 -04:00
|
|
|
package go_llm
|
2024-10-06 20:01:01 -04:00
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
2024-10-31 11:21:03 -04:00
|
|
|
"strings"
|
2024-12-28 20:39:57 -05:00
|
|
|
|
|
|
|
oai "github.com/sashabaranov/go-openai"
|
2024-10-06 20:01:01 -04:00
|
|
|
)
|
|
|
|
|
2024-11-08 20:53:12 -05:00
|
|
|
type openaiImpl struct {
|
2024-10-06 20:01:01 -04:00
|
|
|
key string
|
|
|
|
model string
|
|
|
|
}
|
|
|
|
|
2024-11-08 20:53:12 -05:00
|
|
|
var _ LLM = openaiImpl{}
|
2024-10-06 20:01:01 -04:00
|
|
|
|
2024-11-08 20:53:12 -05:00
|
|
|
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
|
2024-10-06 20:01:01 -04:00
|
|
|
res := oai.ChatCompletionRequest{
|
|
|
|
Model: o.model,
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, msg := range request.Messages {
|
2024-10-06 21:02:26 -04:00
|
|
|
m := oai.ChatCompletionMessage{
|
2024-10-06 20:01:01 -04:00
|
|
|
Content: msg.Text,
|
|
|
|
Role: string(msg.Role),
|
|
|
|
Name: msg.Name,
|
2024-10-06 21:02:26 -04:00
|
|
|
}
|
|
|
|
|
2024-10-07 16:33:57 -04:00
|
|
|
for _, img := range msg.Images {
|
|
|
|
if img.Base64 != "" {
|
|
|
|
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
|
|
|
|
Type: "image_url",
|
|
|
|
ImageURL: &oai.ChatMessageImageURL{
|
|
|
|
URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64),
|
|
|
|
},
|
|
|
|
})
|
|
|
|
} else if img.Url != "" {
|
|
|
|
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
|
|
|
|
Type: "image_url",
|
|
|
|
ImageURL: &oai.ChatMessageImageURL{
|
|
|
|
URL: img.Url,
|
|
|
|
},
|
|
|
|
})
|
2024-10-06 21:02:26 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-10-20 21:50:12 -04:00
|
|
|
// openai does not allow Content and MultiContent to be set at the same time, so we need to check
|
|
|
|
if len(m.MultiContent) > 0 && m.Content != "" {
|
|
|
|
m.MultiContent = append([]oai.ChatMessagePart{{
|
|
|
|
Type: "text",
|
|
|
|
Text: m.Content,
|
|
|
|
}}, m.MultiContent...)
|
|
|
|
|
|
|
|
m.Content = ""
|
|
|
|
}
|
|
|
|
|
2024-10-06 21:02:26 -04:00
|
|
|
res.Messages = append(res.Messages, m)
|
2024-10-06 20:01:01 -04:00
|
|
|
}
|
|
|
|
|
2024-12-28 20:39:57 -05:00
|
|
|
if request.Toolbox != nil {
|
|
|
|
for _, tool := range request.Toolbox.funcs {
|
|
|
|
res.Tools = append(res.Tools, oai.Tool{
|
|
|
|
Type: "function",
|
|
|
|
Function: &oai.FunctionDefinition{
|
|
|
|
Name: tool.Name,
|
|
|
|
Description: tool.Description,
|
|
|
|
Strict: tool.Strict,
|
|
|
|
Parameters: tool.Parameters.Definition(),
|
|
|
|
},
|
|
|
|
})
|
|
|
|
|
|
|
|
fmt.Println("tool:", tool.Name, tool.Description, tool.Strict, tool.Parameters.Definition())
|
|
|
|
}
|
2024-10-06 20:01:01 -04:00
|
|
|
}
|
|
|
|
|
2024-10-06 21:02:26 -04:00
|
|
|
if request.Temperature != nil {
|
|
|
|
res.Temperature = *request.Temperature
|
|
|
|
}
|
|
|
|
|
2024-10-31 11:21:03 -04:00
|
|
|
// is this an o1-* model?
|
|
|
|
isO1 := strings.Split(o.model, "-")[0] == "o1"
|
|
|
|
|
|
|
|
if isO1 {
|
|
|
|
// o1 models do not support system messages, so if any messages are system messages, we need to convert them to
|
|
|
|
// user messages
|
|
|
|
|
|
|
|
for i, msg := range res.Messages {
|
|
|
|
if msg.Role == "system" {
|
|
|
|
res.Messages[i].Role = "user"
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-10-06 20:01:01 -04:00
|
|
|
return res
|
|
|
|
}
|
|
|
|
|
2024-11-08 20:53:12 -05:00
|
|
|
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
|
2024-10-06 20:01:01 -04:00
|
|
|
res := Response{}
|
|
|
|
|
|
|
|
for _, choice := range response.Choices {
|
2024-11-11 00:23:01 -05:00
|
|
|
var toolCalls []ToolCall
|
2024-10-06 20:01:01 -04:00
|
|
|
for _, call := range choice.Message.ToolCalls {
|
2024-11-11 00:23:01 -05:00
|
|
|
fmt.Println("responseToLLMResponse: call:", call.Function.Arguments)
|
2024-10-06 20:01:01 -04:00
|
|
|
toolCall := ToolCall{
|
|
|
|
ID: call.ID,
|
|
|
|
FunctionCall: FunctionCall{
|
|
|
|
Name: call.Function.Name,
|
|
|
|
Arguments: call.Function.Arguments,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
2024-11-11 00:23:01 -05:00
|
|
|
fmt.Println("toolCall.FunctionCall.Arguments:", toolCall.FunctionCall.Arguments)
|
|
|
|
|
|
|
|
toolCalls = append(toolCalls, toolCall)
|
2024-10-06 20:01:01 -04:00
|
|
|
|
|
|
|
}
|
|
|
|
res.Choices = append(res.Choices, ResponseChoice{
|
|
|
|
Content: choice.Message.Content,
|
|
|
|
Role: Role(choice.Message.Role),
|
|
|
|
Name: choice.Message.Name,
|
|
|
|
Refusal: choice.Message.Refusal,
|
2024-11-11 00:23:01 -05:00
|
|
|
Calls: toolCalls,
|
2024-10-06 20:01:01 -04:00
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
return res
|
|
|
|
}
|
|
|
|
|
2024-11-08 20:53:12 -05:00
|
|
|
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
|
2024-10-06 20:01:01 -04:00
|
|
|
cl := oai.NewClient(o.key)
|
|
|
|
|
|
|
|
req := o.requestToOpenAIRequest(request)
|
|
|
|
|
|
|
|
resp, err := cl.CreateChatCompletion(ctx, req)
|
|
|
|
|
|
|
|
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
|
|
|
|
|
|
|
|
if err != nil {
|
2024-11-08 20:53:12 -05:00
|
|
|
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
|
2024-10-06 20:01:01 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
return o.responseToLLMResponse(resp), nil
|
|
|
|
}
|
|
|
|
|
2024-11-08 20:53:12 -05:00
|
|
|
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
|
|
|
|
return openaiImpl{
|
2024-10-06 20:01:01 -04:00
|
|
|
key: o.key,
|
|
|
|
model: modelVersion,
|
|
|
|
}, nil
|
|
|
|
}
|