v2 is a new Go module (v2/) with a dramatically simpler API:

- Unified Message type (no more Input marker interface)
- Define[T] for ergonomic tool creation with standard context.Context
- Chat session with automatic tool-call loop (agent loop)
- Streaming via pull-based StreamReader
- MCP one-call connect (MCPStdioServer, MCPHTTPServer, MCPSSEServer)
- Middleware support (logging, retry, timeout, usage tracking)
- Decoupled JSON Schema (map[string]any, no provider coupling)
- Sample tools: WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP
- Providers: OpenAI, Anthropic, Google (all with streaming)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
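The Google provider in the file below implements the v2 provider.Provider interface (Complete and Stream). A minimal usage sketch, assuming the import path for this package, the provider.Message element type of Request.Messages, and the model name (all assumptions, marked in comments); google.New, Complete, and the Request/Response fields are taken from the source below.

package main

import (
	"context"
	"fmt"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider/google" // assumed import path for the package below
)

func main() {
	// Construct the provider exactly as New is defined in the file below.
	p := google.New("YOUR_GEMINI_API_KEY")

	resp, err := p.Complete(context.Background(), provider.Request{
		Model: "gemini-2.0-flash", // illustrative model name
		Messages: []provider.Message{ // assumed element type; Role/Content match buildRequest below
			{Role: "user", Content: "Say hello in one sentence."},
		},
	})
	if err != nil {
		panic(err)
	}

	fmt.Println(resp.Text)
	if resp.Usage != nil {
		fmt.Println("total tokens:", resp.Usage.TotalTokens)
	}
}

The provider source follows.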
// Package google implements the go-llm v2 provider interface for Google (Gemini).
package google

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"

	"google.golang.org/genai"
)

// Provider implements the provider.Provider interface for Google Gemini.
type Provider struct {
	apiKey string
}

// New creates a new Google provider.
func New(apiKey string) *Provider {
	return &Provider{apiKey: apiKey}
}

// Complete performs a non-streaming completion.
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
	cl, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey:  p.apiKey,
		Backend: genai.BackendGeminiAPI,
	})
	if err != nil {
		return provider.Response{}, fmt.Errorf("google client error: %w", err)
	}

	contents, cfg := p.buildRequest(req)

	resp, err := cl.Models.GenerateContent(ctx, req.Model, contents, cfg)
	if err != nil {
		return provider.Response{}, fmt.Errorf("google completion error: %w", err)
	}

	return p.convertResponse(resp)
}

// Stream performs a streaming completion.
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
	cl, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey:  p.apiKey,
		Backend: genai.BackendGeminiAPI,
	})
	if err != nil {
		return fmt.Errorf("google client error: %w", err)
	}

	contents, cfg := p.buildRequest(req)

	var fullText strings.Builder
	var toolCalls []provider.ToolCall

	for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
		if err != nil {
			return fmt.Errorf("google stream error: %w", err)
		}

		for _, c := range resp.Candidates {
			if c.Content == nil {
				continue
			}
			for _, part := range c.Content.Parts {
				if part.Text != "" {
					fullText.WriteString(part.Text)
					events <- provider.StreamEvent{
						Type: provider.StreamEventText,
						Text: part.Text,
					}
				}
				if part.FunctionCall != nil {
					args, _ := json.Marshal(part.FunctionCall.Args)
					tc := provider.ToolCall{
						ID:        part.FunctionCall.Name,
						Name:      part.FunctionCall.Name,
						Arguments: string(args),
					}
					toolCalls = append(toolCalls, tc)
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolStart,
						ToolCall:  &tc,
						ToolIndex: len(toolCalls) - 1,
					}
					events <- provider.StreamEvent{
						Type:      provider.StreamEventToolEnd,
						ToolCall:  &tc,
						ToolIndex: len(toolCalls) - 1,
					}
				}
			}
		}
	}

	events <- provider.StreamEvent{
		Type: provider.StreamEventDone,
		Response: &provider.Response{
			Text:      fullText.String(),
			ToolCalls: toolCalls,
		},
	}

	return nil
}
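
// A hypothetical consumption sketch, not part of the original file: Stream
// writes events into the caller-supplied channel and returns when the model
// response is exhausted (it does not close the channel), so a caller would
// typically run it in a goroutine and range over the channel. Buffering and
// error handling here are assumptions; the event types match those emitted
// above.
//
//	events := make(chan provider.StreamEvent)
//	go func() {
//		defer close(events)
//		_ = p.Stream(ctx, req, events) // error handling elided in this sketch
//	}()
//	for ev := range events {
//		switch ev.Type {
//		case provider.StreamEventText:
//			fmt.Print(ev.Text)
//		case provider.StreamEventToolStart, provider.StreamEventToolEnd:
//			// ev.ToolCall carries the function name and JSON arguments.
//		case provider.StreamEventDone:
//			// ev.Response aggregates the full text and all tool calls.
//		}
//	}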

// buildRequest converts a provider.Request into genai contents and a generation config.
func (p *Provider) buildRequest(req provider.Request) ([]*genai.Content, *genai.GenerateContentConfig) {
	var contents []*genai.Content
	cfg := &genai.GenerateContentConfig{}

	for _, tool := range req.Tools {
		cfg.Tools = append(cfg.Tools, &genai.Tool{
			FunctionDeclarations: []*genai.FunctionDeclaration{
				{
					Name:        tool.Name,
					Description: tool.Description,
					Parameters:  schemaToGenai(tool.Schema),
				},
			},
		})
	}

	if req.Temperature != nil {
		f := float32(*req.Temperature)
		cfg.Temperature = &f
	}

	if req.MaxTokens != nil {
		cfg.MaxOutputTokens = int32(*req.MaxTokens)
	}

	if req.TopP != nil {
		f := float32(*req.TopP)
		cfg.TopP = &f
	}

	if len(req.Stop) > 0 {
		cfg.StopSequences = req.Stop
	}

	for _, msg := range req.Messages {
		var role genai.Role
		switch msg.Role {
		case "system":
			cfg.SystemInstruction = genai.NewContentFromText(msg.Content, genai.RoleUser)
			continue
		case "assistant":
			role = genai.RoleModel
		case "tool":
			// Tool results go as function responses (Genai uses RoleUser for function responses)
			contents = append(contents, &genai.Content{
				Role: genai.RoleUser,
				Parts: []*genai.Part{
					{
						FunctionResponse: &genai.FunctionResponse{
							Name: msg.ToolCallID,
							Response: map[string]any{
								"result": msg.Content,
							},
						},
					},
				},
			})
			continue
		default:
			role = genai.RoleUser
		}

		var parts []*genai.Part

		if msg.Content != "" {
			parts = append(parts, genai.NewPartFromText(msg.Content))
		}

		// Handle tool calls in assistant messages
		for _, tc := range msg.ToolCalls {
			var args map[string]any
			if tc.Arguments != "" {
				_ = json.Unmarshal([]byte(tc.Arguments), &args)
			}
			parts = append(parts, &genai.Part{
				FunctionCall: &genai.FunctionCall{
					Name: tc.Name,
					Args: args,
				},
			})
		}

		for _, img := range msg.Images {
			if img.URL != "" {
				// Gemini doesn't support URLs directly; download
				resp, err := http.Get(img.URL)
				if err != nil {
					continue
				}
				data, err := io.ReadAll(resp.Body)
				resp.Body.Close()
				if err != nil {
					continue
				}

				mimeType := http.DetectContentType(data)
				parts = append(parts, genai.NewPartFromBytes(data, mimeType))
			} else if img.Base64 != "" {
				data, err := base64.StdEncoding.DecodeString(img.Base64)
				if err != nil {
					continue
				}
				parts = append(parts, genai.NewPartFromBytes(data, img.ContentType))
			}
		}

		contents = append(contents, genai.NewContentFromParts(parts, role))
	}

	return contents, cfg
}
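
// For illustration only (not in the original file): a v2 "tool" message such as
//
//	{Role: "tool", ToolCallID: "get_weather", Content: `{"temp_c": 21}`}
//
// is forwarded to Gemini as a user-role Content holding a FunctionResponse named
// after ToolCallID, with the tool output wrapped under a "result" key, mirroring
// the "tool" case in buildRequest above.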

// convertResponse maps a genai response onto the provider-neutral Response type.
func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provider.Response, error) {
	var res provider.Response

	for _, c := range resp.Candidates {
		if c.Content == nil {
			continue
		}
		for _, part := range c.Content.Parts {
			if part.Text != "" {
				res.Text += part.Text
			}
			if part.FunctionCall != nil {
				args, _ := json.Marshal(part.FunctionCall.Args)
				res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
					ID:        part.FunctionCall.Name,
					Name:      part.FunctionCall.Name,
					Arguments: string(args),
				})
			}
		}
	}

	if resp.UsageMetadata != nil {
		res.Usage = &provider.Usage{
			InputTokens:  int(resp.UsageMetadata.PromptTokenCount),
			OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
			TotalTokens:  int(resp.UsageMetadata.TotalTokenCount),
		}
	}

	return res, nil
}

// schemaToGenai converts a JSON Schema map to a genai.Schema.
func schemaToGenai(s map[string]any) *genai.Schema {
	if s == nil {
		return nil
	}

	schema := &genai.Schema{}

	if t, ok := s["type"].(string); ok {
		switch t {
		case "object":
			schema.Type = genai.TypeObject
		case "array":
			schema.Type = genai.TypeArray
		case "string":
			schema.Type = genai.TypeString
		case "integer":
			schema.Type = genai.TypeInteger
		case "number":
			schema.Type = genai.TypeNumber
		case "boolean":
			schema.Type = genai.TypeBoolean
		}
	}

	if desc, ok := s["description"].(string); ok {
		schema.Description = desc
	}

	if props, ok := s["properties"].(map[string]any); ok {
		schema.Properties = make(map[string]*genai.Schema)
		for k, v := range props {
			if vm, ok := v.(map[string]any); ok {
				schema.Properties[k] = schemaToGenai(vm)
			}
		}
	}

	if req, ok := s["required"].([]string); ok {
		schema.Required = req
	} else if req, ok := s["required"].([]any); ok {
		for _, r := range req {
			if rs, ok := r.(string); ok {
				schema.Required = append(schema.Required, rs)
			}
		}
	}

	if items, ok := s["items"].(map[string]any); ok {
		schema.Items = schemaToGenai(items)
	}

	if enums, ok := s["enum"].([]string); ok {
		schema.Enum = enums
	} else if enums, ok := s["enum"].([]any); ok {
		for _, e := range enums {
			if es, ok := e.(string); ok {
				schema.Enum = append(schema.Enum, es)
			}
		}
	}

	return schema
}
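
// An illustrative schema (not from the original file) showing the map[string]any
// shape schemaToGenai accepts; both []any and []string forms are handled for
// "required" and "enum", matching a schema decoded from raw JSON:
//
//	schemaToGenai(map[string]any{
//		"type":        "object",
//		"description": "web search parameters",
//		"properties": map[string]any{
//			"query": map[string]any{"type": "string", "description": "search terms"},
//			"limit": map[string]any{"type": "integer"},
//		},
//		"required": []any{"query"},
//	})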