// Package google implements the go-llm v2 provider interface for Google (Gemini). package google import ( "context" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "strings" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" "google.golang.org/genai" ) // Provider implements the provider.Provider interface for Google Gemini. type Provider struct { apiKey string } // New creates a new Google provider. func New(apiKey string) *Provider { return &Provider{apiKey: apiKey} } // Complete performs a non-streaming completion. func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { cl, err := genai.NewClient(ctx, &genai.ClientConfig{ APIKey: p.apiKey, Backend: genai.BackendGeminiAPI, }) if err != nil { return provider.Response{}, fmt.Errorf("google client error: %w", err) } contents, cfg := p.buildRequest(req) resp, err := cl.Models.GenerateContent(ctx, req.Model, contents, cfg) if err != nil { return provider.Response{}, fmt.Errorf("google completion error: %w", err) } return p.convertResponse(resp) } // Stream performs a streaming completion. func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { cl, err := genai.NewClient(ctx, &genai.ClientConfig{ APIKey: p.apiKey, Backend: genai.BackendGeminiAPI, }) if err != nil { return fmt.Errorf("google client error: %w", err) } contents, cfg := p.buildRequest(req) var fullText strings.Builder var toolCalls []provider.ToolCall for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) { if err != nil { return fmt.Errorf("google stream error: %w", err) } for _, c := range resp.Candidates { if c.Content == nil { continue } for _, part := range c.Content.Parts { if part.Text != "" { fullText.WriteString(part.Text) events <- provider.StreamEvent{ Type: provider.StreamEventText, Text: part.Text, } } if part.FunctionCall != nil { args, _ := json.Marshal(part.FunctionCall.Args) tc := provider.ToolCall{ ID: part.FunctionCall.Name, Name: part.FunctionCall.Name, Arguments: string(args), } toolCalls = append(toolCalls, tc) events <- provider.StreamEvent{ Type: provider.StreamEventToolStart, ToolCall: &tc, ToolIndex: len(toolCalls) - 1, } events <- provider.StreamEvent{ Type: provider.StreamEventToolEnd, ToolCall: &tc, ToolIndex: len(toolCalls) - 1, } } } } } events <- provider.StreamEvent{ Type: provider.StreamEventDone, Response: &provider.Response{ Text: fullText.String(), ToolCalls: toolCalls, }, } return nil } func (p *Provider) buildRequest(req provider.Request) ([]*genai.Content, *genai.GenerateContentConfig) { var contents []*genai.Content cfg := &genai.GenerateContentConfig{} for _, tool := range req.Tools { cfg.Tools = append(cfg.Tools, &genai.Tool{ FunctionDeclarations: []*genai.FunctionDeclaration{ { Name: tool.Name, Description: tool.Description, Parameters: schemaToGenai(tool.Schema), }, }, }) } if req.Temperature != nil { f := float32(*req.Temperature) cfg.Temperature = &f } if req.MaxTokens != nil { cfg.MaxOutputTokens = int32(*req.MaxTokens) } if req.TopP != nil { f := float32(*req.TopP) cfg.TopP = &f } if len(req.Stop) > 0 { cfg.StopSequences = req.Stop } for _, msg := range req.Messages { var role genai.Role switch msg.Role { case "system": cfg.SystemInstruction = genai.NewContentFromText(msg.Content, genai.RoleUser) continue case "assistant": role = genai.RoleModel case "tool": // Tool results go as function responses (Genai uses RoleUser for function responses) contents = append(contents, &genai.Content{ Role: genai.RoleUser, Parts: []*genai.Part{ { FunctionResponse: &genai.FunctionResponse{ Name: msg.ToolCallID, Response: map[string]any{ "result": msg.Content, }, }, }, }, }) continue default: role = genai.RoleUser } var parts []*genai.Part if msg.Content != "" { parts = append(parts, genai.NewPartFromText(msg.Content)) } // Handle tool calls in assistant messages for _, tc := range msg.ToolCalls { var args map[string]any if tc.Arguments != "" { _ = json.Unmarshal([]byte(tc.Arguments), &args) } parts = append(parts, &genai.Part{ FunctionCall: &genai.FunctionCall{ Name: tc.Name, Args: args, }, }) } for _, img := range msg.Images { if img.URL != "" { // Gemini doesn't support URLs directly; download resp, err := http.Get(img.URL) if err != nil { continue } data, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { continue } mimeType := http.DetectContentType(data) parts = append(parts, genai.NewPartFromBytes(data, mimeType)) } else if img.Base64 != "" { data, err := base64.StdEncoding.DecodeString(img.Base64) if err != nil { continue } parts = append(parts, genai.NewPartFromBytes(data, img.ContentType)) } } contents = append(contents, genai.NewContentFromParts(parts, role)) } return contents, cfg } func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provider.Response, error) { var res provider.Response for _, c := range resp.Candidates { if c.Content == nil { continue } for _, part := range c.Content.Parts { if part.Text != "" { res.Text += part.Text } if part.FunctionCall != nil { args, _ := json.Marshal(part.FunctionCall.Args) res.ToolCalls = append(res.ToolCalls, provider.ToolCall{ ID: part.FunctionCall.Name, Name: part.FunctionCall.Name, Arguments: string(args), }) } } } if resp.UsageMetadata != nil { res.Usage = &provider.Usage{ InputTokens: int(resp.UsageMetadata.PromptTokenCount), OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount), TotalTokens: int(resp.UsageMetadata.TotalTokenCount), } } return res, nil } // schemaToGenai converts a JSON Schema map to a genai.Schema. func schemaToGenai(s map[string]any) *genai.Schema { if s == nil { return nil } schema := &genai.Schema{} if t, ok := s["type"].(string); ok { switch t { case "object": schema.Type = genai.TypeObject case "array": schema.Type = genai.TypeArray case "string": schema.Type = genai.TypeString case "integer": schema.Type = genai.TypeInteger case "number": schema.Type = genai.TypeNumber case "boolean": schema.Type = genai.TypeBoolean } } if desc, ok := s["description"].(string); ok { schema.Description = desc } if props, ok := s["properties"].(map[string]any); ok { schema.Properties = make(map[string]*genai.Schema) for k, v := range props { if vm, ok := v.(map[string]any); ok { schema.Properties[k] = schemaToGenai(vm) } } } if req, ok := s["required"].([]string); ok { schema.Required = req } else if req, ok := s["required"].([]any); ok { for _, r := range req { if rs, ok := r.(string); ok { schema.Required = append(schema.Required, rs) } } } if items, ok := s["items"].(map[string]any); ok { schema.Items = schemaToGenai(items) } if enums, ok := s["enum"].([]string); ok { schema.Enum = enums } else if enums, ok := s["enum"].([]any); ok { for _, e := range enums { if es, ok := e.(string); ok { schema.Enum = append(schema.Enum, es) } } } return schema }