package llm

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"google.golang.org/genai"
)

// maxImageBytes is the largest image payload (20MB) that will be downloaded
// and forwarded to the model.
const maxImageBytes = 20 * 1024 * 1024

// googleImpl is the LLM implementation that talks to the Gemini API through
// the genai SDK.
type googleImpl struct {
	key   string // API key passed to the genai client
	model string // model name passed to GenerateContent
}

// Compile-time check that googleImpl satisfies LLM.
var _ LLM = googleImpl{}

// ModelVersion returns a copy of g configured to use modelVersion. The
// receiver is a value, so the original googleImpl is left untouched.
func (g googleImpl) ModelVersion(modelVersion string) (ChatCompletion, error) {
	g.model = modelVersion
	return g, nil
}

// requestToContents converts in into genai contents and a generation config.
//
// It keeps the original panic-on-failure contract for existing in-package
// callers; new code should prefer buildContents, which reports failures as
// errors and honours a context.
func (g googleImpl) requestToContents(in Request) ([]*genai.Content, *genai.GenerateContentConfig) {
	contents, cfg, err := g.buildContents(context.Background(), in)
	if err != nil {
		// Panic with a string, as the original implementation did.
		panic(err.Error())
	}
	return contents, cfg
}

// buildContents converts in into the genai message list and generation
// config. Images referenced by URL are downloaded (bounded by ctx and
// maxImageBytes); inline images are base64-decoded. Any failure is returned
// as an error rather than a panic.
func (g googleImpl) buildContents(ctx context.Context, in Request) ([]*genai.Content, *genai.GenerateContentConfig, error) {
	cfg := &genai.GenerateContentConfig{}

	// Each function is registered as its own genai.Tool.
	for _, tool := range in.Toolbox.Functions() {
		cfg.Tools = append(cfg.Tools, &genai.Tool{
			FunctionDeclarations: []*genai.FunctionDeclaration{
				{
					Name:        tool.Name,
					Description: tool.Description,
					Parameters:  tool.Parameters.GoogleParameters(),
				},
			},
		})
	}
	if in.Toolbox.RequiresTool() {
		cfg.ToolConfig = &genai.ToolConfig{
			FunctionCallingConfig: &genai.FunctionCallingConfig{
				Mode: genai.FunctionCallingConfigModeAny,
			},
		}
	}

	var contents []*genai.Content
	for _, c := range in.Messages {
		// NOTE(review): system messages are sent with the model role, and any
		// other unrecognised role leaves role at its zero value — confirm this
		// is intended.
		var role genai.Role
		switch c.Role {
		case RoleAssistant, RoleSystem:
			role = genai.RoleModel
		case RoleUser:
			role = genai.RoleUser
		}

		var parts []*genai.Part
		if c.Text != "" {
			parts = append(parts, genai.NewPartFromText(c.Text))
		}
		for _, img := range c.Images {
			part, err := imagePart(ctx, img.Url, img.Base64, img.ContentType)
			if err != nil {
				return nil, nil, err
			}
			parts = append(parts, part)
		}
		contents = append(contents, genai.NewContentFromParts(parts, role))
	}
	return contents, cfg, nil
}

// imagePart builds an inline-bytes part for a single image. When rawURL is
// non-empty the image is downloaded and its MIME type sniffed from the data;
// otherwise b64 is decoded and contentType is used as given.
func imagePart(ctx context.Context, rawURL, b64, contentType string) (*genai.Part, error) {
	if rawURL != "" {
		// Gemini does not support URLs, so the image is downloaded and sent
		// as a blob.
		data, mimeType, err := fetchImage(ctx, rawURL)
		if err != nil {
			return nil, err
		}
		return genai.NewPartFromBytes(data, mimeType), nil
	}

	data, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return nil, fmt.Errorf("error decoding base64: %w", err)
	}
	return genai.NewPartFromBytes(data, contentType), nil
}

// fetchImage downloads rawURL and returns its bytes together with the sniffed
// MIME type. It fails on non-2xx responses, on bodies larger than
// maxImageBytes, and on MIME types other than JPEG, PNG or GIF.
//
// Living in its own function means the deferred Close runs once per image
// instead of piling up until the whole request has been converted.
func fetchImage(ctx context.Context, rawURL string) ([]byte, string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
	if err != nil {
		return nil, "", fmt.Errorf("error downloading image: %w", err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("error downloading image: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, "", fmt.Errorf("error downloading image: unexpected status %s", resp.Status)
	}
	if resp.ContentLength > maxImageBytes {
		return nil, "", fmt.Errorf("image size exceeds 20MB: %d bytes", resp.ContentLength)
	}

	// ContentLength is -1 when unknown and is not trustworthy anyway, so cap
	// the read itself. One extra byte is requested to detect an oversized body.
	data, err := io.ReadAll(io.LimitReader(resp.Body, maxImageBytes+1))
	if err != nil {
		return nil, "", fmt.Errorf("error reading image data: %w", err)
	}
	if len(data) > maxImageBytes {
		return nil, "", fmt.Errorf("image size exceeds 20MB")
	}

	mimeType := http.DetectContentType(data)
	switch mimeType {
	case "image/jpeg", "image/png", "image/gif":
		// MIME type is valid.
	default:
		return nil, "", fmt.Errorf("unsupported image MIME type: %s", mimeType)
	}
	return data, mimeType, nil
}

// responseToLLMResponse converts a genai response into a Response. Each
// candidate that contains at least one text part or function call becomes one
// assistant choice; candidates with neither are dropped.
func (g googleImpl) responseToLLMResponse(in *genai.GenerateContentResponse) (Response, error) {
	res := Response{}
	for _, c := range in.Candidates {
		if c == nil || c.Content == nil {
			continue
		}

		var (
			choice ResponseChoice
			text   string
			set    bool
		)
		for _, p := range c.Content.Parts {
			if p == nil {
				continue
			}
			switch {
			case p.Text != "":
				// A candidate may carry several text parts; concatenate them
				// instead of keeping only the last one.
				text += p.Text
				set = true
			case p.FunctionCall != nil:
				v := p.FunctionCall
				b, err := json.Marshal(v.Args)
				if err != nil {
					return Response{}, fmt.Errorf("error marshalling args: %w", err)
				}
				choice.Calls = append(choice.Calls, ToolCall{
					// The function name doubles as the call ID.
					ID: v.Name,
					FunctionCall: FunctionCall{
						Name:      v.Name,
						Arguments: string(b),
					},
				})
				set = true
			}
		}
		if !set {
			continue
		}

		// Leave Content at its zero value for tool-call-only choices.
		if text != "" {
			choice.Content = text
		}
		choice.Role = RoleAssistant
		res.Choices = append(res.Choices, choice)
	}
	return res, nil
}

// ChatComplete sends req to the configured Gemini model and returns the
// converted response. Failures while preparing the request (image download,
// base64 decoding), creating the client, or generating content are returned
// as errors.
func (g googleImpl) ChatComplete(ctx context.Context, req Request) (Response, error) {
	cl, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey:  g.key,
		Backend: genai.BackendGeminiAPI,
	})
	if err != nil {
		return Response{}, fmt.Errorf("error creating genai client: %w", err)
	}

	contents, cfg, err := g.buildContents(ctx, req)
	if err != nil {
		return Response{}, fmt.Errorf("error building request: %w", err)
	}

	resp, err := cl.Models.GenerateContent(ctx, g.model, contents, cfg)
	if err != nil {
		return Response{}, fmt.Errorf("error generating content: %w", err)
	}
	return g.responseToLLMResponse(resp)
}