go-llm/v2/google/google.go
Steve Dudenhoeffer 7e1705c385
feat: add audio input support to v2 providers
Add Audio struct alongside Image for sending audio attachments to
multimodal LLMs. OpenAI uses input_audio content parts (wav/mp3),
Google Gemini uses genai.NewPartFromBytes, and Anthropic skips
audio gracefully since it's not supported.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 21:00:56 -05:00
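
For orientation, here is a hedged caller-side sketch of the audio path this commit adds. It is not part of the repository: the element type names (provider.Message, provider.Audio), the google import path, and the model name are assumptions inferred from the fields and module path that google.go itself references.

// Hypothetical usage sketch, not part of google.go. Type names
// provider.Message and provider.Audio and the google import path are
// assumed from what google.go references; the model name is a placeholder.
package main

import (
	"context"
	"encoding/base64"
	"fmt"
	"log"
	"os"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/google"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

func main() {
	wav, err := os.ReadFile("clip.wav")
	if err != nil {
		log.Fatal(err)
	}

	req := provider.Request{
		Model: "gemini-2.0-flash", // placeholder model name
		Messages: []provider.Message{{
			Role:    "user",
			Content: "Transcribe this clip.",
			Audio: []provider.Audio{{
				Base64:      base64.StdEncoding.EncodeToString(wav),
				ContentType: "audio/wav", // buildRequest falls back to audio/wav if empty
			}},
		}},
	}

	resp, err := google.New(os.Getenv("GEMINI_API_KEY")).Complete(context.Background(), req)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Text)
}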


// Package google implements the go-llm v2 provider interface for Google (Gemini).
package google

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
	"google.golang.org/genai"
)
// Provider implements the provider.Provider interface for Google Gemini.
type Provider struct {
	apiKey string
}

// New creates a new Google provider.
func New(apiKey string) *Provider {
	return &Provider{apiKey: apiKey}
}
// Complete performs a non-streaming completion.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	cl, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey:  p.apiKey,
		Backend: genai.BackendGeminiAPI,
	})
	if err != nil {
		return provider.Response{}, fmt.Errorf("google client error: %w", err)
	}

	contents, cfg := p.buildRequest(req)

	resp, err := cl.Models.GenerateContent(ctx, req.Model, contents, cfg)
	if err != nil {
		return provider.Response{}, fmt.Errorf("google completion error: %w", err)
	}

	return p.convertResponse(resp)
}
// Stream performs a streaming completion.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	cl, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey:  p.apiKey,
		Backend: genai.BackendGeminiAPI,
	})
	if err != nil {
		return fmt.Errorf("google client error: %w", err)
	}

	contents, cfg := p.buildRequest(req)

	var fullText strings.Builder
	var toolCalls []provider.ToolCall

	for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
		if err != nil {
			return fmt.Errorf("google stream error: %w", err)
		}

		for _, c := range resp.Candidates {
			if c.Content == nil {
				continue
			}

			for _, part := range c.Content.Parts {
				if part.Text != "" {
					fullText.WriteString(part.Text)
					events <- provider.StreamEvent{
						Type: provider.StreamEventText,
						Text: part.Text,
					}
				}

				if part.FunctionCall != nil {
					// Gemini does not return tool-call IDs, so the function
					// name doubles as the ID.
					args, _ := json.Marshal(part.FunctionCall.Args)
					tc := provider.ToolCall{
						ID:        part.FunctionCall.Name,
						Name:      part.FunctionCall.Name,
						Arguments: string(args),
					}
					toolCalls = append(toolCalls, tc)

					// Gemini delivers function calls whole, so emit start and
					// end events back to back.
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolStart,
						ToolCall:  &tc,
						ToolIndex: len(toolCalls) - 1,
					}
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolEnd,
						ToolCall:  &tc,
						ToolIndex: len(toolCalls) - 1,
					}
				}
			}
		}
	}

	events <- provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			ToolCalls: toolCalls,
		},
	}

	return nil
}
// buildRequest converts a provider.Request into genai contents and a
// generation config.
func (p *Provider) buildRequest(req provider.Request) ([]*genai.Content, *genai.GenerateContentConfig) {
	var contents []*genai.Content
	cfg := &genai.GenerateContentConfig{}

	for _, tool := range req.Tools {
		cfg.Tools = append(cfg.Tools, &genai.Tool{
			FunctionDeclarations: []*genai.FunctionDeclaration{
				{
					Name:        tool.Name,
					Description: tool.Description,
					Parameters:  schemaToGenai(tool.Schema),
				},
			},
		})
	}

	if req.Temperature != nil {
		f := float32(*req.Temperature)
		cfg.Temperature = &f
	}
	if req.MaxTokens != nil {
		cfg.MaxOutputTokens = int32(*req.MaxTokens)
	}
	if req.TopP != nil {
		f := float32(*req.TopP)
		cfg.TopP = &f
	}
	if len(req.Stop) > 0 {
		cfg.StopSequences = req.Stop
	}

	for _, msg := range req.Messages {
		var role genai.Role

		switch msg.Role {
		case "system":
			cfg.SystemInstruction = genai.NewContentFromText(msg.Content, genai.RoleUser)
			continue
		case "assistant":
			role = genai.RoleModel
		case "tool":
			// Tool results go as function responses (genai uses RoleUser for
			// function responses).
			contents = append(contents, &genai.Content{
				Role: genai.RoleUser,
				Parts: []*genai.Part{
					{
						FunctionResponse: &genai.FunctionResponse{
							Name: msg.ToolCallID,
							Response: map[string]any{
								"result": msg.Content,
							},
						},
					},
				},
			})
			continue
		default:
			role = genai.RoleUser
		}

		var parts []*genai.Part
		if msg.Content != "" {
			parts = append(parts, genai.NewPartFromText(msg.Content))
		}

		// Handle tool calls in assistant messages.
		for _, tc := range msg.ToolCalls {
			var args map[string]any
			if tc.Arguments != "" {
				_ = json.Unmarshal([]byte(tc.Arguments), &args)
			}
			parts = append(parts, &genai.Part{
				FunctionCall: &genai.FunctionCall{
					Name: tc.Name,
					Args: args,
				},
			})
		}

		for _, img := range msg.Images {
			if img.URL != "" {
				// Gemini doesn't support URLs directly; download the image and
				// inline the bytes. Failures skip the attachment rather than
				// failing the whole request.
				resp, err := http.Get(img.URL)
				if err != nil {
					continue
				}
				data, err := io.ReadAll(resp.Body)
				resp.Body.Close()
				if err != nil {
					continue
				}
				mimeType := http.DetectContentType(data)
				parts = append(parts, genai.NewPartFromBytes(data, mimeType))
			} else if img.Base64 != "" {
				data, err := base64.StdEncoding.DecodeString(img.Base64)
				if err != nil {
					continue
				}
				parts = append(parts, genai.NewPartFromBytes(data, img.ContentType))
			}
		}

		for _, aud := range msg.Audio {
			if aud.URL != "" {
				resp, err := http.Get(aud.URL)
				if err != nil {
					continue
				}
				data, err := io.ReadAll(resp.Body)
				resp.Body.Close()
				if err != nil {
					continue
				}
				// Prefer the server-reported MIME type, then the caller-supplied
				// content type, then fall back to audio/wav.
				mimeType := resp.Header.Get("Content-Type")
				if mimeType == "" {
					mimeType = aud.ContentType
				}
				if mimeType == "" {
					mimeType = "audio/wav"
				}
				parts = append(parts, genai.NewPartFromBytes(data, mimeType))
			} else if aud.Base64 != "" {
				data, err := base64.StdEncoding.DecodeString(aud.Base64)
				if err != nil {
					continue
				}
				ct := aud.ContentType
				if ct == "" {
					ct = "audio/wav"
				}
				parts = append(parts, genai.NewPartFromBytes(data, ct))
			}
		}

		contents = append(contents, genai.NewContentFromParts(parts, role))
	}

	return contents, cfg
}
// convertResponse flattens a Gemini response into a provider.Response.
func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provider.Response, error) {
	var res provider.Response

	for _, c := range resp.Candidates {
		if c.Content == nil {
			continue
		}
		for _, part := range c.Content.Parts {
			if part.Text != "" {
				res.Text += part.Text
			}
			if part.FunctionCall != nil {
				args, _ := json.Marshal(part.FunctionCall.Args)
				res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
					ID:        part.FunctionCall.Name,
					Name:      part.FunctionCall.Name,
					Arguments: string(args),
				})
			}
		}
	}

	if resp.UsageMetadata != nil {
		res.Usage = &provider.Usage{
			InputTokens:  int(resp.UsageMetadata.PromptTokenCount),
			OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
			TotalTokens:  int(resp.UsageMetadata.TotalTokenCount),
		}
	}

	return res, nil
}
// schemaToGenai converts a JSON Schema map to a genai.Schema.
func schemaToGenai(s map[string]any) *genai.Schema {
	if s == nil {
		return nil
	}

	schema := &genai.Schema{}

	if t, ok := s["type"].(string); ok {
		switch t {
		case "object":
			schema.Type = genai.TypeObject
		case "array":
			schema.Type = genai.TypeArray
		case "string":
			schema.Type = genai.TypeString
		case "integer":
			schema.Type = genai.TypeInteger
		case "number":
			schema.Type = genai.TypeNumber
		case "boolean":
			schema.Type = genai.TypeBoolean
		}
	}

	if desc, ok := s["description"].(string); ok {
		schema.Description = desc
	}

	if props, ok := s["properties"].(map[string]any); ok {
		schema.Properties = make(map[string]*genai.Schema)
		for k, v := range props {
			if vm, ok := v.(map[string]any); ok {
				schema.Properties[k] = schemaToGenai(vm)
			}
		}
	}

	// "required" may be a []string (hand-built schema) or []any (unmarshaled JSON).
	if req, ok := s["required"].([]string); ok {
		schema.Required = req
	} else if req, ok := s["required"].([]any); ok {
		for _, r := range req {
			if rs, ok := r.(string); ok {
				schema.Required = append(schema.Required, rs)
			}
		}
	}

	if items, ok := s["items"].(map[string]any); ok {
		schema.Items = schemaToGenai(items)
	}

	// Same dual handling for "enum".
	if enums, ok := s["enum"].([]string); ok {
		schema.Enum = enums
	} else if enums, ok := s["enum"].([]any); ok {
		for _, e := range enums {
			if es, ok := e.(string); ok {
				schema.Enum = append(schema.Enum, es)
			}
		}
	}

	return schema
}
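
Finally, a hedged sketch of driving the streaming path with a tool attached. It is written only against what this file emits and reads (StreamEventText, StreamEventToolStart, StreamEventDone, ToolCall.Name/Arguments, Tool.Name/Description/Schema); the provider.Tool element type name, the tool definition, and the channel buffer size are assumptions for illustration. Because Stream never closes the events channel, the caller closes it once Stream returns so the range loop can terminate.

// Hypothetical streaming sketch, not part of google.go. It assumes the
// req.Tools element type is provider.Tool; the tool schema is a plain
// JSON-Schema map, which schemaToGenai converts for Gemini.
func streamExample(ctx context.Context, p *Provider, req provider.Request) (string, error) {
	req.Tools = append(req.Tools, provider.Tool{
		Name:        "get_weather", // hypothetical tool for illustration
		Description: "Look up the current weather for a city.",
		Schema: map[string]any{
			"type": "object",
			"properties": map[string]any{
				"city": map[string]any{"type": "string", "description": "City name"},
			},
			"required": []string{"city"},
		},
	})

	events := make(chan provider.StreamEvent, 64)
	errc := make(chan error, 1)
	go func() {
		defer close(events) // Stream never closes the channel; the owner does
		errc <- p.Stream(ctx, req, events)
	}()

	var text strings.Builder
	for ev := range events {
		switch ev.Type {
		case provider.StreamEventText:
			text.WriteString(ev.Text)
		case provider.StreamEventToolStart:
			fmt.Println("tool call:", ev.ToolCall.Name, ev.ToolCall.Arguments)
		case provider.StreamEventDone:
			// ev.Response carries the aggregated text and tool calls.
		}
	}
	return text.String(), <-errc
}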