go-llm/google.go

191 lines
4.3 KiB
Go

package go_llm
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
)
// google talks to the Gemini API. key is the API key handed to
// genai.NewClient and model is the model name handed to
// Client.GenerativeModel.
type google struct {
	key, model string
}
// ModelVersion returns a copy of g that targets the model named modelVersion.
// Because google is passed by value, the receiver itself is not modified.
// The error result is always nil.
func (g google) ModelVersion(modelVersion string) (ChatCompletion, error) {
	updated := g
	updated.model = modelVersion
	return updated, nil
}
// requestToChatHistory turns in into a Gemini chat session started from a
// shallow copy of model.
//
// One genai.Tool is attached per toolbox function. When the toolbox requires a
// tool call, function calling is forced with genai.FunctionCallingAny.
//
// Every message except the last is appended to the session history. The parts
// of the last message are returned separately so the caller can pass them to
// ChatSession.SendMessage. If in has no messages, the returned parts are nil.
//
// Image attachments become genai.Blob parts:
//   - images with a URL are downloaded, because Gemini does not accept URLs;
//   - all other images are decoded from base64 and tagged with img.ContentType.
//
// The signature has no error result, so a failure while converting an image
// panics with a descriptive string message.
func (g google) requestToChatHistory(in Request, model *genai.GenerativeModel) (*genai.GenerativeModel, *genai.ChatSession, []genai.Part) {
	res := *model
	for _, tool := range in.Toolbox.functions {
		res.Tools = append(res.Tools, &genai.Tool{
			FunctionDeclarations: []*genai.FunctionDeclaration{{
				Name:        tool.Name,
				Description: tool.Description,
				Parameters:  tool.Parameters.GoogleParameters(),
			}},
		})
	}
	// Force a function call only when the toolbox demands one. Otherwise
	// ToolConfig stays unset and the API's default mode applies. Forcing
	// FunctionCallingAny with no function declarations would be invalid.
	if in.Toolbox.RequiresTool() {
		res.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
			Mode: genai.FunctionCallingAny,
		}}
	}

	cs := res.StartChat()
	for i, msg := range in.Messages {
		// NewUserContent defaults the role to "user". That default is kept for
		// any role not matched below.
		content := genai.NewUserContent(genai.Text(msg.Text))
		switch msg.Role {
		case RoleAssistant, RoleSystem:
			// NOTE(review): system messages are sent as model turns rather than
			// via GenerativeModel.SystemInstruction. Confirm this is intended.
			content.Role = "model"
		case RoleUser:
			content.Role = "user"
		}

		for _, img := range msg.Images {
			if img.Url == "" {
				data, err := base64.StdEncoding.DecodeString(img.Base64)
				if err != nil {
					panic(fmt.Sprintf("error decoding base64: %v", err))
				}
				content.Parts = append(content.Parts, genai.Blob{
					MIMEType: img.ContentType,
					Data:     data,
				})
				continue
			}
			blob, err := g.fetchImageBlob(img.Url)
			if err != nil {
				panic(err.Error())
			}
			content.Parts = append(content.Parts, blob)
		}

		// The final message is not part of the history. Its parts are what
		// the caller sends.
		if i == len(in.Messages)-1 {
			return &res, cs, content.Parts
		}
		cs.History = append(cs.History, content)
	}
	return &res, cs, nil
}

// fetchImageBlob downloads the image at imageURL and wraps it in a genai.Blob.
//
// It returns an error in these cases:
//   - the request fails or returns a non-2xx status;
//   - the body exceeds 20 MiB, whether declared in Content-Length or actually
//     read;
//   - http.DetectContentType does not classify the data as JPEG, PNG, GIF or
//     WebP.
//
// NOTE(review): this uses http.DefaultClient, which has no timeout, and cannot
// observe the caller's context. Consider threading ctx through.
func (g google) fetchImageBlob(imageURL string) (genai.Blob, error) {
	const maxImageBytes = 20 * 1024 * 1024

	resp, err := http.Get(imageURL)
	if err != nil {
		return genai.Blob{}, fmt.Errorf("error downloading image: %w", err)
	}
	// Deferring here, not in the caller's loop, closes each body as soon as
	// its image has been read.
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return genai.Blob{}, fmt.Errorf("error downloading image: unexpected status %s", resp.Status)
	}
	if resp.ContentLength > maxImageBytes {
		return genai.Blob{}, fmt.Errorf("image size exceeds 20MB: %d bytes", resp.ContentLength)
	}
	// Content-Length may be missing (-1) or wrong, so also cap the actual read.
	// The extra byte lets an oversized body be detected.
	data, err := io.ReadAll(io.LimitReader(resp.Body, maxImageBytes+1))
	if err != nil {
		return genai.Blob{}, fmt.Errorf("error reading image data: %w", err)
	}
	if len(data) > maxImageBytes {
		return genai.Blob{}, fmt.Errorf("image size exceeds 20MB: more than %d bytes", maxImageBytes)
	}

	mimeType := http.DetectContentType(data)
	switch mimeType {
	case "image/jpeg", "image/png", "image/gif", "image/webp":
		// Supported image type.
	default:
		return genai.Blob{}, fmt.Errorf("unsupported image MIME type: %s", mimeType)
	}
	return genai.Blob{MIMEType: mimeType, Data: data}, nil
}
// responseToLLMResponse converts a Gemini response into a Response.
//
// Each candidate that yields at least one text or function-call part becomes
// one ResponseChoice with Role RoleAssistant. Candidates without content are
// skipped.
//
// Within a candidate:
//   - text parts are concatenated in order into Content;
//   - each function call becomes a ToolCall with JSON-encoded arguments;
//   - Gemini assigns no call IDs here, so the function name is used as the ID.
//
// Any other part type, or a failure to marshal call arguments, is returned as
// an error.
func (g google) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
	if in == nil {
		return Response{}, fmt.Errorf("nil response from gemini")
	}
	res := Response{}
	for _, c := range in.Candidates {
		if c == nil || c.Content == nil {
			continue
		}
		var choice ResponseChoice
		set := false
		for _, p := range c.Content.Parts {
			switch v := p.(type) {
			case genai.Text:
				// A candidate may split its answer across several text parts.
				// Keep all of them instead of only the last one.
				choice.Content += string(v)
				set = true
			case genai.FunctionCall:
				b, err := json.Marshal(v.Args)
				if err != nil {
					return Response{}, fmt.Errorf("error marshalling args: %w", err)
				}
				choice.Calls = append(choice.Calls, ToolCall{
					ID: v.Name,
					FunctionCall: FunctionCall{
						Name:      v.Name,
						Arguments: string(b),
					},
				})
				set = true
			default:
				return Response{}, fmt.Errorf("unknown part type: %T", p)
			}
		}
		if set {
			choice.Role = RoleAssistant
			res.Choices = append(res.Choices, choice)
		}
	}
	return res, nil
}
// ChatComplete sends req to the configured Gemini model and converts the reply
// into a Response.
//
// A genai client is created from g.key for this call and closed before
// returning. All messages except the last become chat history; the last
// message is sent with ChatSession.SendMessage. It returns an error in these
// cases:
//   - the client cannot be created;
//   - an image attachment cannot be converted;
//   - there is no message to send;
//   - the API call fails.
func (g google) ChatComplete(ctx context.Context, req Request) (Response, error) {
	cl, err := genai.NewClient(ctx, option.WithAPIKey(g.key))
	if err != nil {
		return Response{}, fmt.Errorf("error creating genai client: %w", err)
	}
	// Release the client's underlying resources once this request finishes.
	// A Close error is not actionable after the response has been obtained.
	defer cl.Close()

	cs, parts, err := g.chatSessionFor(req, cl.GenerativeModel(g.model))
	if err != nil {
		return Response{}, err
	}
	if len(parts) == 0 {
		return Response{}, fmt.Errorf("no messages to send")
	}

	resp, err := cs.SendMessage(ctx, parts...)
	if err != nil {
		return Response{}, fmt.Errorf("error generating content: %w", err)
	}
	return g.responseToLLMResponse(resp)
}

// chatSessionFor calls requestToChatHistory and converts the string panics it
// raises for bad image attachments into an ordinary error. This keeps a bad
// attachment from crashing the caller. Panics with any non-string value, such
// as runtime errors from programming bugs, are re-raised unchanged.
func (g google) chatSessionFor(req Request, model *genai.GenerativeModel) (cs *genai.ChatSession, parts []genai.Part, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		msg, ok := r.(string)
		if !ok {
			panic(r)
		}
		cs, parts, err = nil, nil, fmt.Errorf("error converting request: %s", msg)
	}()
	_, cs, parts = g.requestToChatHistory(req, model)
	return cs, parts, nil
}