go-llm/openai.go

143 lines
3.1 KiB
Go
Raw Normal View History

2024-10-06 21:02:26 -04:00
package go_llm
2024-10-06 20:01:01 -04:00
import (
"context"
"fmt"
oai "github.com/sashabaranov/go-openai"
"strings"
2024-10-06 20:01:01 -04:00
)
// openaiImpl is the OpenAI-backed implementation used by this package.
// It carries the credentials and the model name that requests are sent with.
type openaiImpl struct {
	// key is passed to the go-openai client as the API token.
	// model is sent as the Model field of every chat completion request.
	key, model string
}
// Compile-time check that openaiImpl (by value) satisfies the LLM interface.
var _ LLM = openaiImpl{}
2024-10-06 20:01:01 -04:00
func (o openaiImpl) requestToOpenAIRequest(request Request) oai.ChatCompletionRequest {
2024-10-06 20:01:01 -04:00
res := oai.ChatCompletionRequest{
Model: o.model,
}
for _, msg := range request.Messages {
2024-10-06 21:02:26 -04:00
m := oai.ChatCompletionMessage{
2024-10-06 20:01:01 -04:00
Content: msg.Text,
Role: string(msg.Role),
Name: msg.Name,
2024-10-06 21:02:26 -04:00
}
2024-10-07 16:33:57 -04:00
for _, img := range msg.Images {
if img.Base64 != "" {
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
Type: "image_url",
ImageURL: &oai.ChatMessageImageURL{
URL: fmt.Sprintf("data:%s;base64,%s", img.ContentType, img.Base64),
},
})
} else if img.Url != "" {
m.MultiContent = append(m.MultiContent, oai.ChatMessagePart{
Type: "image_url",
ImageURL: &oai.ChatMessageImageURL{
URL: img.Url,
},
})
2024-10-06 21:02:26 -04:00
}
}
// openai does not allow Content and MultiContent to be set at the same time, so we need to check
if len(m.MultiContent) > 0 && m.Content != "" {
m.MultiContent = append([]oai.ChatMessagePart{{
Type: "text",
Text: m.Content,
}}, m.MultiContent...)
m.Content = ""
}
2024-10-06 21:02:26 -04:00
res.Messages = append(res.Messages, m)
2024-10-06 20:01:01 -04:00
}
for _, tool := range request.Toolbox {
res.Tools = append(res.Tools, oai.Tool{
Type: "function",
Function: &oai.FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
Strict: tool.Strict,
Parameters: tool.Parameters,
},
})
}
2024-10-06 21:02:26 -04:00
if request.Temperature != nil {
res.Temperature = *request.Temperature
}
// is this an o1-* model?
isO1 := strings.Split(o.model, "-")[0] == "o1"
if isO1 {
// o1 models do not support system messages, so if any messages are system messages, we need to convert them to
// user messages
for i, msg := range res.Messages {
if msg.Role == "system" {
res.Messages[i].Role = "user"
}
}
}
2024-10-06 20:01:01 -04:00
return res
}
func (o openaiImpl) responseToLLMResponse(response oai.ChatCompletionResponse) Response {
2024-10-06 20:01:01 -04:00
res := Response{}
for _, choice := range response.Choices {
var tools []ToolCall
for _, call := range choice.Message.ToolCalls {
toolCall := ToolCall{
ID: call.ID,
FunctionCall: FunctionCall{
Name: call.Function.Name,
Arguments: call.Function.Arguments,
},
}
tools = append(tools, toolCall)
}
res.Choices = append(res.Choices, ResponseChoice{
Content: choice.Message.Content,
Role: Role(choice.Message.Role),
Name: choice.Message.Name,
Refusal: choice.Message.Refusal,
})
}
return res
}
func (o openaiImpl) ChatComplete(ctx context.Context, request Request) (Response, error) {
2024-10-06 20:01:01 -04:00
cl := oai.NewClient(o.key)
req := o.requestToOpenAIRequest(request)
resp, err := cl.CreateChatCompletion(ctx, req)
fmt.Println("resp:", fmt.Sprintf("%#v", resp))
if err != nil {
return Response{}, fmt.Errorf("unhandled openaiImpl error: %w", err)
2024-10-06 20:01:01 -04:00
}
return o.responseToLLMResponse(resp), nil
}
func (o openaiImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
return openaiImpl{
2024-10-06 20:01:01 -04:00
key: o.key,
model: modelVersion,
}, nil
}