cbaf41f50c
Introduces an opt-in level-based reasoning toggle (low/medium/high) that each provider translates to its native parameter: - Anthropic: thinking.budget_tokens (1024/8000/24000), with temperature forced to default and MaxTokens auto-grown above the budget. - OpenAI/xAI/Groq via openaicompat: reasoning_effort string, gated by a new Rules.SupportsReasoning predicate so non-reasoning models don't receive the parameter. xAI uses Rules.MapReasoningEffort to remap "medium" to "high" since its API only accepts low|high. - Google: thinking_config.thinking_budget + include_thoughts:true. - DeepSeek: SupportsReasoning=false (reasoner is always-on; the reasoning_content trace was already extracted via openaicompat). Reasoning content is surfaced as Response.Thinking on Complete and as StreamEventThinking deltas during streaming. Provider-side: extracted from Anthropic thinking content blocks, Google's part.Thought=true parts, and the non-standard reasoning_content field that DeepSeek and Groq emit (parsed out of raw JSON since openai-go doesn't type it). Public API: - llm.ReasoningLevel + ReasoningLow/Medium/High constants - llm.WithReasoning(level) request option - Model.WithReasoning(level) for baked-in defaults - provider.Request.Reasoning, provider.Response.Thinking - provider.StreamEventThinking Tests cover Rules-based gating, MapReasoningEffort, reasoning_content extraction (Complete + Stream), Anthropic budget mapping, and temperature suppression when thinking is enabled. Existing behavior is unchanged when Reasoning is the empty string. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
436 lines
11 KiB
Go
436 lines
11 KiB
Go
// Package google implements the go-llm v2 provider interface for Google (Gemini).
|
|
package google
|
|
|
|
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"

	"google.golang.org/genai"
)
|
|
|
|
// Provider implements the provider.Provider interface for Google Gemini.
type Provider struct {
	// apiKey authenticates requests against the Gemini API
	// (genai.BackendGeminiAPI); a fresh genai client is built per call.
	apiKey string
}
|
|
|
|
// New creates a new Google provider.
|
|
func New(apiKey string) *Provider {
|
|
return &Provider{apiKey: apiKey}
|
|
}
|
|
|
|
// Complete performs a non-streaming completion.
|
|
func (p *Provider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
cl, err := genai.NewClient(ctx, &genai.ClientConfig{
|
|
APIKey: p.apiKey,
|
|
Backend: genai.BackendGeminiAPI,
|
|
})
|
|
if err != nil {
|
|
return provider.Response{}, fmt.Errorf("google client error: %w", err)
|
|
}
|
|
|
|
contents, cfg := p.buildRequest(req)
|
|
|
|
resp, err := cl.Models.GenerateContent(ctx, req.Model, contents, cfg)
|
|
if err != nil {
|
|
return provider.Response{}, fmt.Errorf("google completion error: %w", err)
|
|
}
|
|
|
|
return p.convertResponse(resp)
|
|
}
|
|
|
|
// Stream performs a streaming completion.
|
|
func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
|
cl, err := genai.NewClient(ctx, &genai.ClientConfig{
|
|
APIKey: p.apiKey,
|
|
Backend: genai.BackendGeminiAPI,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("google client error: %w", err)
|
|
}
|
|
|
|
contents, cfg := p.buildRequest(req)
|
|
|
|
var fullText strings.Builder
|
|
var fullThinking strings.Builder
|
|
var toolCalls []provider.ToolCall
|
|
var usage *provider.Usage
|
|
|
|
for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
|
|
if err != nil {
|
|
return fmt.Errorf("google stream error: %w", err)
|
|
}
|
|
|
|
// Track usage from the last chunk (final chunk has cumulative counts)
|
|
if resp.UsageMetadata != nil {
|
|
usage = &provider.Usage{
|
|
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
|
|
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
|
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
|
}
|
|
details := map[string]int{}
|
|
if resp.UsageMetadata.CachedContentTokenCount > 0 {
|
|
details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount)
|
|
}
|
|
if resp.UsageMetadata.ThoughtsTokenCount > 0 {
|
|
details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount)
|
|
}
|
|
if len(details) > 0 {
|
|
usage.Details = details
|
|
}
|
|
}
|
|
|
|
for _, c := range resp.Candidates {
|
|
if c.Content == nil {
|
|
continue
|
|
}
|
|
for _, part := range c.Content.Parts {
|
|
if part.Text != "" {
|
|
if part.Thought {
|
|
fullThinking.WriteString(part.Text)
|
|
events <- provider.StreamEvent{
|
|
Type: provider.StreamEventThinking,
|
|
Text: part.Text,
|
|
}
|
|
} else {
|
|
fullText.WriteString(part.Text)
|
|
events <- provider.StreamEvent{
|
|
Type: provider.StreamEventText,
|
|
Text: part.Text,
|
|
}
|
|
}
|
|
}
|
|
if part.FunctionCall != nil {
|
|
args, _ := json.Marshal(part.FunctionCall.Args)
|
|
tc := provider.ToolCall{
|
|
ID: part.FunctionCall.Name,
|
|
Name: part.FunctionCall.Name,
|
|
Arguments: string(args),
|
|
}
|
|
toolCalls = append(toolCalls, tc)
|
|
events <- provider.StreamEvent{
|
|
Type: provider.StreamEventToolStart,
|
|
ToolCall: &tc,
|
|
ToolIndex: len(toolCalls) - 1,
|
|
}
|
|
events <- provider.StreamEvent{
|
|
Type: provider.StreamEventToolEnd,
|
|
ToolCall: &tc,
|
|
ToolIndex: len(toolCalls) - 1,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
events <- provider.StreamEvent{
|
|
Type: provider.StreamEventDone,
|
|
Response: &provider.Response{
|
|
Text: fullText.String(),
|
|
Thinking: fullThinking.String(),
|
|
ToolCalls: toolCalls,
|
|
Usage: usage,
|
|
},
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Provider) buildRequest(req provider.Request) ([]*genai.Content, *genai.GenerateContentConfig) {
|
|
var contents []*genai.Content
|
|
cfg := &genai.GenerateContentConfig{}
|
|
|
|
for _, tool := range req.Tools {
|
|
cfg.Tools = append(cfg.Tools, &genai.Tool{
|
|
FunctionDeclarations: []*genai.FunctionDeclaration{
|
|
{
|
|
Name: tool.Name,
|
|
Description: tool.Description,
|
|
Parameters: schemaToGenai(tool.Schema),
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
if req.Temperature != nil {
|
|
f := float32(*req.Temperature)
|
|
cfg.Temperature = &f
|
|
}
|
|
|
|
if req.MaxTokens != nil {
|
|
cfg.MaxOutputTokens = int32(*req.MaxTokens)
|
|
}
|
|
|
|
if req.TopP != nil {
|
|
f := float32(*req.TopP)
|
|
cfg.TopP = &f
|
|
}
|
|
|
|
if len(req.Stop) > 0 {
|
|
cfg.StopSequences = req.Stop
|
|
}
|
|
|
|
// Extended thinking via thinking_config. Models that don't support
|
|
// thinking ignore this field; budgets here mirror the Anthropic
|
|
// mapping so a single ReasoningLevel produces comparable behavior
|
|
// across providers.
|
|
if budget := thinkingBudget(req.Reasoning); budget > 0 {
|
|
b := int32(budget)
|
|
cfg.ThinkingConfig = &genai.ThinkingConfig{
|
|
ThinkingBudget: &b,
|
|
IncludeThoughts: true,
|
|
}
|
|
}
|
|
|
|
for _, msg := range req.Messages {
|
|
var role genai.Role
|
|
switch msg.Role {
|
|
case "system":
|
|
cfg.SystemInstruction = genai.NewContentFromText(msg.Content, genai.RoleUser)
|
|
continue
|
|
case "assistant":
|
|
role = genai.RoleModel
|
|
case "tool":
|
|
// Tool results go as function responses (Genai uses RoleUser for function responses)
|
|
contents = append(contents, &genai.Content{
|
|
Role: genai.RoleUser,
|
|
Parts: []*genai.Part{
|
|
{
|
|
FunctionResponse: &genai.FunctionResponse{
|
|
Name: msg.ToolCallID,
|
|
Response: map[string]any{
|
|
"result": msg.Content,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
})
|
|
continue
|
|
default:
|
|
role = genai.RoleUser
|
|
}
|
|
|
|
var parts []*genai.Part
|
|
|
|
if msg.Content != "" {
|
|
parts = append(parts, genai.NewPartFromText(msg.Content))
|
|
}
|
|
|
|
// Handle tool calls in assistant messages
|
|
for _, tc := range msg.ToolCalls {
|
|
var args map[string]any
|
|
if tc.Arguments != "" {
|
|
_ = json.Unmarshal([]byte(tc.Arguments), &args)
|
|
}
|
|
parts = append(parts, &genai.Part{
|
|
FunctionCall: &genai.FunctionCall{
|
|
Name: tc.Name,
|
|
Args: args,
|
|
},
|
|
})
|
|
}
|
|
|
|
for _, img := range msg.Images {
|
|
if img.URL != "" {
|
|
// Gemini doesn't support URLs directly; download
|
|
resp, err := http.Get(img.URL)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
data, err := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
mimeType := http.DetectContentType(data)
|
|
parts = append(parts, genai.NewPartFromBytes(data, mimeType))
|
|
} else if img.Base64 != "" {
|
|
data, err := base64.StdEncoding.DecodeString(img.Base64)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
parts = append(parts, genai.NewPartFromBytes(data, img.ContentType))
|
|
}
|
|
}
|
|
|
|
for _, aud := range msg.Audio {
|
|
if aud.URL != "" {
|
|
resp, err := http.Get(aud.URL)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
data, err := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
mimeType := resp.Header.Get("Content-Type")
|
|
if mimeType == "" {
|
|
mimeType = aud.ContentType
|
|
}
|
|
if mimeType == "" {
|
|
mimeType = "audio/wav"
|
|
}
|
|
parts = append(parts, genai.NewPartFromBytes(data, mimeType))
|
|
} else if aud.Base64 != "" {
|
|
data, err := base64.StdEncoding.DecodeString(aud.Base64)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
ct := aud.ContentType
|
|
if ct == "" {
|
|
ct = "audio/wav"
|
|
}
|
|
parts = append(parts, genai.NewPartFromBytes(data, ct))
|
|
}
|
|
}
|
|
|
|
contents = append(contents, genai.NewContentFromParts(parts, role))
|
|
}
|
|
|
|
return contents, cfg
|
|
}
|
|
|
|
func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provider.Response, error) {
|
|
var res provider.Response
|
|
|
|
for _, c := range resp.Candidates {
|
|
if c.Content == nil {
|
|
continue
|
|
}
|
|
for _, part := range c.Content.Parts {
|
|
if part.Text != "" {
|
|
if part.Thought {
|
|
res.Thinking += part.Text
|
|
} else {
|
|
res.Text += part.Text
|
|
}
|
|
}
|
|
if part.FunctionCall != nil {
|
|
args, _ := json.Marshal(part.FunctionCall.Args)
|
|
res.ToolCalls = append(res.ToolCalls, provider.ToolCall{
|
|
ID: part.FunctionCall.Name,
|
|
Name: part.FunctionCall.Name,
|
|
Arguments: string(args),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
if resp.UsageMetadata != nil {
|
|
res.Usage = &provider.Usage{
|
|
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
|
|
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
|
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
|
}
|
|
details := map[string]int{}
|
|
if resp.UsageMetadata.CachedContentTokenCount > 0 {
|
|
details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount)
|
|
}
|
|
if resp.UsageMetadata.ThoughtsTokenCount > 0 {
|
|
details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount)
|
|
}
|
|
if len(details) > 0 {
|
|
res.Usage.Details = details
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// Thinking budgets used by Google for low/medium/high reasoning levels.
// Mirrors the Anthropic mapping so a single go-llm ReasoningLevel produces
// comparable behavior across providers.
const (
	thinkingBudgetLow    = 1024
	thinkingBudgetMedium = 8000
	thinkingBudgetHigh   = 24000
)

// thinkingBudgets maps each go-llm reasoning level string to its
// Gemini thinking_budget token count.
var thinkingBudgets = map[string]int{
	"low":    thinkingBudgetLow,
	"medium": thinkingBudgetMedium,
	"high":   thinkingBudgetHigh,
}

// thinkingBudget returns the genai thinking_budget for a go-llm
// ReasoningLevel, or 0 to disable thinking.
func thinkingBudget(level string) int {
	// Unknown levels (including "") miss the map and yield 0 (disabled).
	return thinkingBudgets[level]
}
|
|
|
|
// schemaToGenai converts a JSON Schema map to a genai.Schema.
|
|
func schemaToGenai(s map[string]any) *genai.Schema {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
|
|
schema := &genai.Schema{}
|
|
|
|
if t, ok := s["type"].(string); ok {
|
|
switch t {
|
|
case "object":
|
|
schema.Type = genai.TypeObject
|
|
case "array":
|
|
schema.Type = genai.TypeArray
|
|
case "string":
|
|
schema.Type = genai.TypeString
|
|
case "integer":
|
|
schema.Type = genai.TypeInteger
|
|
case "number":
|
|
schema.Type = genai.TypeNumber
|
|
case "boolean":
|
|
schema.Type = genai.TypeBoolean
|
|
}
|
|
}
|
|
|
|
if desc, ok := s["description"].(string); ok {
|
|
schema.Description = desc
|
|
}
|
|
|
|
if props, ok := s["properties"].(map[string]any); ok {
|
|
schema.Properties = make(map[string]*genai.Schema)
|
|
for k, v := range props {
|
|
if vm, ok := v.(map[string]any); ok {
|
|
schema.Properties[k] = schemaToGenai(vm)
|
|
}
|
|
}
|
|
}
|
|
|
|
if req, ok := s["required"].([]string); ok {
|
|
schema.Required = req
|
|
} else if req, ok := s["required"].([]any); ok {
|
|
for _, r := range req {
|
|
if rs, ok := r.(string); ok {
|
|
schema.Required = append(schema.Required, rs)
|
|
}
|
|
}
|
|
}
|
|
|
|
if items, ok := s["items"].(map[string]any); ok {
|
|
schema.Items = schemaToGenai(items)
|
|
}
|
|
|
|
if enums, ok := s["enum"].([]string); ok {
|
|
schema.Enum = enums
|
|
} else if enums, ok := s["enum"].([]any); ok {
|
|
for _, e := range enums {
|
|
if es, ok := e.(string); ok {
|
|
schema.Enum = append(schema.Enum, es)
|
|
}
|
|
}
|
|
}
|
|
|
|
return schema
|
|
}
|