feat: Google (Gemini) provider on the official Gen AI SDK

Phase 4: provider/google on google.golang.org/genai v1.59.0 — lazy cached
client, FunctionResponse tool loop, raw-JSON-schema tools and structured
output, ThinkingLevel reasoning mapping, iter.Pull2 streaming, hermetic
httptest suite via HTTPOptions.BaseURL. Registry wires google + gemini
schemes to the real client; stub machinery deleted (all built-ins real).
ADR-0011; README matrix + CLAUDE.md synced.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-10 13:04:28 +02:00
parent 043249e0e1
commit 1ca607906d
11 changed files with 1245 additions and 59 deletions
+416
View File
@@ -0,0 +1,416 @@
// Package google implements majordomo's provider contract for Google's
// Gemini models on the official Google Gen AI Go SDK
// (google.golang.org/genai, the approved third-party dependency per
// ADR-0007; the legacy github.com/google/generative-ai-go SDK is
// deprecated and not used).
//
// Targeted SDK surface (verified against genai v1.59.0 source, June 2026):
// Models.GenerateContent / GenerateContentStream (iter.Seq2), Content/Part
// with InlineData blobs for images, FunctionDeclaration.ParametersJsonSchema
// for raw JSON-schema tools, FunctionCall/FunctionResponse parts for the
// tool loop, GenerateContentConfig.ResponseJsonSchema + JSON MIME for
// structured output, ThinkingConfig.ThinkingLevel for reasoning effort, and
// HTTPOptions.BaseURL + HTTPClient for hermetic tests.
package google
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"sync"
"google.golang.org/genai"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
// defaultCapabilities reflects the published Gemini API limits (June 2026):
// png/jpeg/webp/heic/heif input; inline payloads bounded by a 20MB total
// request budget. MaxImagesPerReq is capped at a practical 100 (the
// published 3,600-file limit assumes the Files API, which majordomo does
// not use).
var defaultCapabilities = llm.Capabilities{
SupportsTools: true,
SupportsStructured: true,
SupportsStreaming: true,
MaxImagesPerReq: 100,
MaxImageBytes: 15 << 20,
AllowedImageMIME: []string{"image/jpeg", "image/png", "image/webp", "image/heic", "image/heif"},
}
// Provider is a Gemini provider over the official SDK.
type Provider struct {
name string
apiKey string
baseURL string
httpClient *http.Client
caps llm.Capabilities
mu sync.Mutex
client *genai.Client
}
// Option configures the provider.
type Option func(*Provider)
// WithName overrides the registry name (default "google").
func WithName(name string) Option { return func(p *Provider) { p.name = name } }
// WithAPIKey sets the API key (default: GOOGLE_API_KEY, then
// GEMINI_API_KEY, matching the SDK's own precedence).
func WithAPIKey(key string) Option { return func(p *Provider) { p.apiKey = key } }
// WithBaseURL overrides the API endpoint (tests, proxies).
func WithBaseURL(u string) Option {
return func(p *Provider) { p.baseURL = strings.TrimRight(u, "/") }
}
// WithHTTPClient overrides the HTTP client.
func WithHTTPClient(c *http.Client) Option { return func(p *Provider) { p.httpClient = c } }
// WithDefaultCapabilities overrides the provider-wide default capabilities.
func WithDefaultCapabilities(caps llm.Capabilities) Option {
return func(p *Provider) { p.caps = caps }
}
// New creates the provider. Construction never fails: a missing key
// surfaces as an auth error at request time (and chains can fail over).
func New(opts ...Option) *Provider {
p := &Provider{
name: "google",
caps: defaultCapabilities,
}
if key := os.Getenv("GOOGLE_API_KEY"); key != "" {
p.apiKey = key
} else if key := os.Getenv("GEMINI_API_KEY"); key != "" {
p.apiKey = key
}
for _, opt := range opts {
opt(p)
}
return p
}
// Name implements llm.Provider.
func (p *Provider) Name() string { return p.name }
// Model implements llm.Provider; the id passes through verbatim.
func (p *Provider) Model(id string, opts ...llm.ModelOption) (llm.Model, error) {
cfg := llm.ApplyModelOptions(opts)
caps := p.caps
if cfg.Capabilities != nil {
caps = *cfg.Capabilities
}
return &model{provider: p, id: id, caps: caps}, nil
}
// genaiClient builds (once) and returns the SDK client. The SDK's
// NewClient does no network I/O for the API-key backend; failures here are
// configuration errors, returned per call and retried on the next.
func (p *Provider) genaiClient(ctx context.Context) (*genai.Client, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.client != nil {
return p.client, nil
}
if p.apiKey == "" {
return nil, &llm.APIError{
Provider: p.name, Status: http.StatusUnauthorized,
Code: "missing_api_key",
Message: "no API key configured (set GOOGLE_API_KEY/GEMINI_API_KEY or use WithAPIKey)",
}
}
cc := &genai.ClientConfig{
APIKey: p.apiKey,
Backend: genai.BackendGeminiAPI,
}
if p.baseURL != "" {
cc.HTTPOptions = genai.HTTPOptions{BaseURL: p.baseURL}
}
if p.httpClient != nil {
cc.HTTPClient = p.httpClient
}
client, err := genai.NewClient(ctx, cc)
if err != nil {
return nil, fmt.Errorf("google: create client: %w", err)
}
p.client = client
return client, nil
}
type model struct {
provider *Provider
id string
caps llm.Capabilities
}
func (m *model) Capabilities() llm.Capabilities { return m.caps }
func (m *model) qualified() string { return m.provider.name + "/" + m.id }
// enforceCapabilities is the provider backstop (ADR-0009); the media layer
// normalizes before requests get here.
func (m *model) enforceCapabilities(req llm.Request) error {
count := 0
for _, msg := range req.Messages {
for _, part := range msg.Parts {
img, ok := part.(llm.ImagePart)
if !ok {
continue
}
count++
if !m.caps.SupportsImages() {
return fmt.Errorf("%w: %s does not accept image input", llm.ErrUnsupported, m.qualified())
}
if !m.caps.MIMEAllowed(img.MIME) {
return fmt.Errorf("%w: %s does not accept %s images", llm.ErrUnsupported, m.qualified(), img.MIME)
}
if m.caps.MaxImageBytes > 0 && len(img.Data) > m.caps.MaxImageBytes {
return fmt.Errorf("%w: image of %d bytes exceeds %s limit of %d",
llm.ErrUnsupported, len(img.Data), m.qualified(), m.caps.MaxImageBytes)
}
}
}
if count > m.caps.MaxImagesPerReq && m.caps.MaxImagesPerReq > 0 {
return fmt.Errorf("%w: %d images exceed %s limit of %d",
llm.ErrUnsupported, count, m.qualified(), m.caps.MaxImagesPerReq)
}
if len(req.Tools) > 0 && !m.caps.SupportsTools {
return fmt.Errorf("%w: %s does not support tools", llm.ErrUnsupported, m.qualified())
}
if len(req.Schema) > 0 && !m.caps.SupportsStructured {
return fmt.Errorf("%w: %s does not support structured output", llm.ErrUnsupported, m.qualified())
}
return nil
}
// buildContents maps canonical messages onto SDK contents, and collects
// the system prompt (Request.System + folded RoleSystem messages).
func (m *model) buildContents(req llm.Request) (string, []*genai.Content, error) {
var sys []string
if req.System != "" {
sys = append(sys, req.System)
}
var contents []*genai.Content
for _, msg := range req.Messages {
switch msg.Role {
case llm.RoleSystem:
if t := msg.Text(); t != "" {
sys = append(sys, t)
}
case llm.RoleTool:
parts := make([]*genai.Part, 0, len(msg.ToolResults))
for _, res := range msg.ToolResults {
payload := map[string]any{"output": res.Content}
if res.IsError {
payload = map[string]any{"error": res.Content}
}
parts = append(parts, &genai.Part{FunctionResponse: &genai.FunctionResponse{
ID: res.ID, Name: res.Name, Response: payload,
}})
}
contents = append(contents, &genai.Content{Role: genai.RoleUser, Parts: parts})
default:
role := genai.RoleUser
if msg.Role == llm.RoleAssistant {
role = genai.RoleModel
}
var parts []*genai.Part
for _, part := range msg.Parts {
switch v := part.(type) {
case llm.TextPart:
parts = append(parts, genai.NewPartFromText(v.Text))
case llm.ImagePart:
parts = append(parts, genai.NewPartFromBytes(v.Data, v.MIME))
}
}
for _, tc := range msg.ToolCalls {
args := map[string]any{}
if len(tc.Arguments) > 0 {
if err := json.Unmarshal(tc.Arguments, &args); err != nil {
return "", nil, fmt.Errorf("google: tool call %q arguments: %w", tc.Name, err)
}
}
parts = append(parts, &genai.Part{FunctionCall: &genai.FunctionCall{
ID: tc.ID, Name: tc.Name, Args: args,
}})
}
if len(parts) == 0 {
continue
}
contents = append(contents, &genai.Content{Role: role, Parts: parts})
}
}
return strings.Join(sys, "\n\n"), contents, nil
}
// buildConfig maps request knobs onto the SDK config.
func (m *model) buildConfig(req llm.Request, system string) (*genai.GenerateContentConfig, error) {
cfg := &genai.GenerateContentConfig{}
if system != "" {
cfg.SystemInstruction = genai.NewContentFromText(system, genai.RoleUser)
}
if req.Temperature != nil {
cfg.Temperature = new(float32)
*cfg.Temperature = float32(*req.Temperature)
}
if req.TopP != nil {
cfg.TopP = new(float32)
*cfg.TopP = float32(*req.TopP)
}
if req.MaxTokens > 0 {
cfg.MaxOutputTokens = int32(req.MaxTokens)
}
cfg.StopSequences = req.StopSequences
if len(req.Tools) > 0 && req.ToolChoice != "none" {
decls := make([]*genai.FunctionDeclaration, 0, len(req.Tools))
for _, t := range req.Tools {
decl := &genai.FunctionDeclaration{Name: t.Name, Description: t.Description}
if len(t.Parameters) > 0 {
var schema map[string]any
if err := json.Unmarshal(t.Parameters, &schema); err != nil {
return nil, fmt.Errorf("google: tool %q parameters: %w", t.Name, err)
}
decl.ParametersJsonSchema = schema
}
decls = append(decls, decl)
}
cfg.Tools = []*genai.Tool{{FunctionDeclarations: decls}}
switch req.ToolChoice {
case "", "auto":
// SDK default.
case "required":
cfg.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingConfigModeAny,
}}
default:
cfg.ToolConfig = &genai.ToolConfig{FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingConfigModeAny, AllowedFunctionNames: []string{req.ToolChoice},
}}
}
}
if len(req.Schema) > 0 {
var schema map[string]any
if err := json.Unmarshal(req.Schema, &schema); err != nil {
return nil, fmt.Errorf("google: output schema: %w", err)
}
cfg.ResponseJsonSchema = schema
cfg.ResponseMIMEType = "application/json"
}
switch req.ReasoningEffort {
case "":
case "low":
cfg.ThinkingConfig = &genai.ThinkingConfig{ThinkingLevel: genai.ThinkingLevelLow}
case "medium":
cfg.ThinkingConfig = &genai.ThinkingConfig{ThinkingLevel: genai.ThinkingLevelMedium}
case "high":
cfg.ThinkingConfig = &genai.ThinkingConfig{ThinkingLevel: genai.ThinkingLevelHigh}
default:
return nil, fmt.Errorf("google: invalid reasoning effort %q (want low/medium/high)", req.ReasoningEffort)
}
return cfg, nil
}
// mapError converts SDK errors into majordomo's classification shapes.
func (m *model) mapError(err error) error {
if apiErr, ok := errors.AsType[genai.APIError](err); ok {
return &llm.APIError{
Provider: m.provider.name, Model: m.id,
Status: apiErr.Code, Code: apiErr.Status, Message: apiErr.Message,
}
}
return fmt.Errorf("google %s: %w", m.qualified(), err)
}
// Generate implements llm.Model.
func (m *model) Generate(ctx context.Context, req llm.Request, opts ...llm.Option) (*llm.Response, error) {
req = req.Apply(opts...)
if err := m.enforceCapabilities(req); err != nil {
return nil, err
}
client, err := m.provider.genaiClient(ctx)
if err != nil {
return nil, err
}
system, contents, err := m.buildContents(req)
if err != nil {
return nil, err
}
cfg, err := m.buildConfig(req, system)
if err != nil {
return nil, err
}
resp, err := client.Models.GenerateContent(ctx, m.id, contents, cfg)
if err != nil {
return nil, m.mapError(err)
}
return m.toResponse(resp), nil
}
// toResponse converts an SDK response into the canonical shape.
func (m *model) toResponse(resp *genai.GenerateContentResponse) *llm.Response {
out := &llm.Response{Model: m.qualified(), Raw: resp}
if resp.UsageMetadata != nil {
out.Usage = llm.Usage{
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount + resp.UsageMetadata.ThoughtsTokenCount),
}
}
if len(resp.Candidates) == 0 {
out.FinishReason = llm.FinishOther
return out
}
cand := resp.Candidates[0]
if cand.Content != nil {
for _, part := range cand.Content.Parts {
if part == nil {
continue
}
if part.Text != "" && !part.Thought {
out.Parts = append(out.Parts, llm.Text(part.Text))
}
if fc := part.FunctionCall; fc != nil {
id := fc.ID
if id == "" {
id = "call_" + strconv.Itoa(len(out.ToolCalls))
}
args, err := json.Marshal(fc.Args)
if err != nil || len(fc.Args) == 0 {
args = json.RawMessage("{}")
}
out.ToolCalls = append(out.ToolCalls, llm.ToolCall{ID: id, Name: fc.Name, Arguments: args})
}
}
}
out.FinishReason = mapFinish(cand.FinishReason, len(out.ToolCalls) > 0)
return out
}
func mapFinish(fr genai.FinishReason, hasToolCalls bool) llm.FinishReason {
if hasToolCalls {
return llm.FinishToolCalls
}
switch fr {
case genai.FinishReasonStop, genai.FinishReasonUnspecified, "":
return llm.FinishStop
case genai.FinishReasonMaxTokens:
return llm.FinishLength
case genai.FinishReasonSafety, genai.FinishReasonRecitation, genai.FinishReasonBlocklist,
genai.FinishReasonProhibitedContent, genai.FinishReasonSPII, genai.FinishReasonImageSafety:
return llm.FinishContentFilter
default:
return llm.FinishOther
}
}
+457
View File
@@ -0,0 +1,457 @@
package google
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
type captured struct {
path string
query string
body map[string]any
}
// serve builds a provider pointed at an httptest server (the SDK's
// documented hermetic hook: HTTPOptions.BaseURL + HTTPClient).
func serve(t *testing.T, handler func(w http.ResponseWriter, r *http.Request)) (*Provider, *captured) {
t.Helper()
cap := &captured{}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cap.path = r.URL.Path
cap.query = r.URL.RawQuery
raw, _ := io.ReadAll(r.Body)
_ = json.Unmarshal(raw, &cap.body)
handler(w, r)
}))
t.Cleanup(ts.Close)
return New(
WithAPIKey("test-key"),
WithBaseURL(ts.URL),
WithHTTPClient(ts.Client()),
), cap
}
func textResponse(text string) string {
return fmt.Sprintf(`{
"candidates":[{"content":{"role":"model","parts":[{"text":%q}]},"finishReason":"STOP"}],
"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":5,"thoughtsTokenCount":2}
}`, text)
}
func basicRequest() llm.Request {
return llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
}
func TestGenerateRoundTrip(t *testing.T) {
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, textResponse("hello from gemini"))
})
m, _ := p.Model("gemini-2.5-flash")
temp := 0.3
resp, err := m.Generate(context.Background(), llm.Request{
System: "be terse",
Messages: []llm.Message{llm.SystemText("extra"), llm.UserText("hi")},
Temperature: &temp,
MaxTokens: 128,
})
if err != nil {
t.Fatalf("Generate: %v", err)
}
if !strings.Contains(cap.path, "models/gemini-2.5-flash:generateContent") {
t.Errorf("path = %q", cap.path)
}
sys := cap.body["systemInstruction"].(map[string]any)
sysText := sys["parts"].([]any)[0].(map[string]any)["text"]
if sysText != "be terse\n\nextra" {
t.Errorf("system = %v", sysText)
}
genCfg := cap.body["generationConfig"].(map[string]any)
if genCfg["temperature"] != 0.3 || genCfg["maxOutputTokens"] != float64(128) {
t.Errorf("generationConfig = %v", genCfg)
}
contents := cap.body["contents"].([]any)
if len(contents) != 1 {
t.Fatalf("contents = %v (system must not appear)", contents)
}
if resp.Text() != "hello from gemini" {
t.Errorf("text = %q", resp.Text())
}
if resp.Usage.InputTokens != 7 || resp.Usage.OutputTokens != 7 {
t.Errorf("usage = %+v (output must include thoughts)", resp.Usage)
}
if resp.FinishReason != llm.FinishStop {
t.Errorf("finish = %v", resp.FinishReason)
}
if resp.Model != "google/gemini-2.5-flash" {
t.Errorf("model = %q", resp.Model)
}
}
func TestImageInlineData(t *testing.T) {
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, textResponse("a png"))
})
m, _ := p.Model("gemini-2.5-flash")
_, err := m.Generate(context.Background(), llm.Request{
Messages: []llm.Message{llm.UserParts(llm.Text("see"), llm.Image("image/png", []byte{1, 2, 3}))},
})
if err != nil {
t.Fatalf("Generate: %v", err)
}
parts := cap.body["contents"].([]any)[0].(map[string]any)["parts"].([]any)
var foundBlob bool
for _, pt := range parts {
if blob, ok := pt.(map[string]any)["inlineData"].(map[string]any); ok {
foundBlob = true
if blob["mimeType"] != "image/png" || blob["data"] != "AQID" {
t.Errorf("blob = %v", blob)
}
}
}
if !foundBlob {
t.Error("no inlineData part sent")
}
}
func TestToolsAndFunctionCalls(t *testing.T) {
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, `{
"candidates":[{"content":{"role":"model","parts":[
{"functionCall":{"name":"get_weather","args":{"city":"Tokyo"}}}
]},"finishReason":"STOP"}]
}`)
})
m, _ := p.Model("gemini-2.5-flash")
resp, err := m.Generate(context.Background(), basicRequest(), llm.WithTools(llm.Tool{
Name: "get_weather", Description: "weather",
Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
}))
if err != nil {
t.Fatalf("Generate: %v", err)
}
tools := cap.body["tools"].([]any)
decls := tools[0].(map[string]any)["functionDeclarations"].([]any)
decl := decls[0].(map[string]any)
if decl["name"] != "get_weather" {
t.Errorf("decl = %v", decl)
}
if _, ok := decl["parametersJsonSchema"].(map[string]any); !ok {
t.Errorf("parametersJsonSchema missing: %v", decl)
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("tool calls = %+v", resp.ToolCalls)
}
tc := resp.ToolCalls[0]
if tc.Name != "get_weather" || tc.ID == "" {
t.Errorf("call = %+v (id synthesized)", tc)
}
var args struct {
City string `json:"city"`
}
if err := json.Unmarshal(tc.Arguments, &args); err != nil || args.City != "Tokyo" {
t.Errorf("args = %s", tc.Arguments)
}
if resp.FinishReason != llm.FinishToolCalls {
t.Errorf("finish = %v", resp.FinishReason)
}
}
func TestToolResultsAndHistory(t *testing.T) {
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, textResponse("21C"))
})
m, _ := p.Model("gemini-2.5-flash")
_, err := m.Generate(context.Background(), llm.Request{
Messages: []llm.Message{
llm.UserText("weather?"),
{Role: llm.RoleAssistant, ToolCalls: []llm.ToolCall{
{ID: "c1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Tokyo"}`)},
}},
llm.ToolResultsMessage(
llm.ToolResult{ID: "c1", Name: "get_weather", Content: `{"temp":21}`},
llm.ToolResult{ID: "c2", Name: "broken", Content: "boom", IsError: true},
),
},
})
if err != nil {
t.Fatalf("Generate: %v", err)
}
contents := cap.body["contents"].([]any)
if len(contents) != 3 {
t.Fatalf("contents = %d, want 3", len(contents))
}
model := contents[1].(map[string]any)
if model["role"] != "model" {
t.Errorf("assistant role = %v", model["role"])
}
fc := model["parts"].([]any)[0].(map[string]any)["functionCall"].(map[string]any)
if fc["name"] != "get_weather" {
t.Errorf("functionCall = %v", fc)
}
results := contents[2].(map[string]any)
parts := results["parts"].([]any)
fr1 := parts[0].(map[string]any)["functionResponse"].(map[string]any)
if fr1["name"] != "get_weather" {
t.Errorf("functionResponse = %v", fr1)
}
if resp1 := fr1["response"].(map[string]any); resp1["output"] != `{"temp":21}` {
t.Errorf("response payload = %v", resp1)
}
fr2 := parts[1].(map[string]any)["functionResponse"].(map[string]any)
if resp2 := fr2["response"].(map[string]any); resp2["error"] != "boom" {
t.Errorf("error payload = %v", resp2)
}
}
func TestToolChoiceMapping(t *testing.T) {
for choice, want := range map[string]string{
"required": "ANY",
"get_weather": "ANY",
} {
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, textResponse("x"))
})
m, _ := p.Model("g")
_, err := m.Generate(context.Background(), basicRequest(),
llm.WithTools(llm.Tool{Name: "get_weather"}), llm.WithToolChoice(choice))
if err != nil {
t.Fatalf("Generate(%s): %v", choice, err)
}
tc := cap.body["toolConfig"].(map[string]any)["functionCallingConfig"].(map[string]any)
if tc["mode"] != want {
t.Errorf("choice %q → mode %v, want %v", choice, tc["mode"], want)
}
if choice == "get_weather" {
allowed := tc["allowedFunctionNames"].([]any)
if allowed[0] != "get_weather" {
t.Errorf("allowedFunctionNames = %v", allowed)
}
}
}
t.Run("none drops tools", func(t *testing.T) {
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, textResponse("x"))
})
m, _ := p.Model("g")
if _, err := m.Generate(context.Background(), basicRequest(),
llm.WithTools(llm.Tool{Name: "t"}), llm.WithToolChoice("none")); err != nil {
t.Fatalf("Generate: %v", err)
}
if _, present := cap.body["tools"]; present {
t.Error("tool_choice none must omit tools")
}
})
}
func TestStructuredOutput(t *testing.T) {
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, textResponse(`{"name":"Ada"}`))
})
m, _ := p.Model("g")
schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
resp, err := m.Generate(context.Background(), basicRequest(), llm.WithSchema(schema, "person"))
if err != nil {
t.Fatalf("Generate: %v", err)
}
genCfg := cap.body["generationConfig"].(map[string]any)
if genCfg["responseMimeType"] != "application/json" {
t.Errorf("responseMimeType = %v", genCfg["responseMimeType"])
}
if _, ok := genCfg["responseJsonSchema"].(map[string]any); !ok {
t.Errorf("responseJsonSchema = %v", genCfg["responseJsonSchema"])
}
if resp.Text() != `{"name":"Ada"}` {
t.Errorf("text = %q", resp.Text())
}
}
func TestReasoningEffortMapsToThinkingLevel(t *testing.T) {
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, textResponse("x"))
})
m, _ := p.Model("g")
if _, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("high")); err != nil {
t.Fatalf("Generate: %v", err)
}
genCfg := cap.body["generationConfig"].(map[string]any)
thinking := genCfg["thinkingConfig"].(map[string]any)
if thinking["thinkingLevel"] != "HIGH" {
t.Errorf("thinkingConfig = %v", thinking)
}
if _, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("ultra")); err == nil {
t.Error("invalid effort should error")
}
}
func TestFinishReasonMapping(t *testing.T) {
for wire, want := range map[string]llm.FinishReason{
"STOP": llm.FinishStop,
"MAX_TOKENS": llm.FinishLength,
"SAFETY": llm.FinishContentFilter,
"PROHIBITED_CONTENT": llm.FinishContentFilter,
"MALFORMED_FUNCTION_CALL": llm.FinishOther,
} {
p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprintf(w, `{"candidates":[{"content":{"role":"model","parts":[{"text":"x"}]},"finishReason":%q}]}`, wire)
})
m, _ := p.Model("g")
resp, err := m.Generate(context.Background(), basicRequest())
if err != nil {
t.Fatalf("Generate(%s): %v", wire, err)
}
if resp.FinishReason != want {
t.Errorf("finish %q = %v, want %v", wire, resp.FinishReason, want)
}
}
}
func TestAPIErrorMapping(t *testing.T) {
p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
// no response written below; status set in the closure
})
_ = p
p2, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(429)
_, _ = io.WriteString(w, `{"error":{"code":429,"message":"quota exhausted","status":"RESOURCE_EXHAUSTED"}}`)
})
m, _ := p2.Model("g")
_, err := m.Generate(context.Background(), basicRequest())
var apiErr *llm.APIError
if !errors.As(err, &apiErr) {
t.Fatalf("error = %v (%T), want APIError", err, err)
}
if apiErr.Status != 429 || !strings.Contains(apiErr.Message, "quota") {
t.Errorf("apiErr = %+v", apiErr)
}
if llm.Classify(err) != llm.ClassTransient {
t.Error("429 must classify transient")
}
}
func TestMissingAPIKey(t *testing.T) {
t.Setenv("GOOGLE_API_KEY", "")
t.Setenv("GEMINI_API_KEY", "")
p := New(WithAPIKey(""))
m, _ := p.Model("g")
_, err := m.Generate(context.Background(), basicRequest())
var apiErr *llm.APIError
if !errors.As(err, &apiErr) || apiErr.Status != http.StatusUnauthorized {
t.Errorf("error = %v, want synthetic 401", err)
}
}
func TestEnvKeyPrecedence(t *testing.T) {
t.Setenv("GOOGLE_API_KEY", "g-key")
t.Setenv("GEMINI_API_KEY", "gem-key")
if p := New(); p.apiKey != "g-key" {
t.Errorf("apiKey = %q, want GOOGLE_API_KEY to win", p.apiKey)
}
t.Setenv("GOOGLE_API_KEY", "")
if p := New(); p.apiKey != "gem-key" {
t.Errorf("apiKey = %q, want GEMINI_API_KEY fallback", p.apiKey)
}
}
func TestCapabilityEnforcement(t *testing.T) {
p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, textResponse("x"))
})
m, _ := p.Model("g", llm.WithCapabilities(llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/png"}}))
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{
llm.UserParts(llm.Image("image/png", []byte{1}), llm.Image("image/png", []byte{2})),
}})
if !errors.Is(err, llm.ErrUnsupported) {
t.Errorf("error = %v, want ErrUnsupported", err)
}
}
func TestStreaming(t *testing.T) {
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = io.WriteString(w, `data: {"candidates":[{"content":{"role":"model","parts":[{"text":"Hel"}]}}]}
data: {"candidates":[{"content":{"role":"model","parts":[{"text":"lo"}]}}]}
data: {"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"ping","args":{}}}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":3,"candidatesTokenCount":6}}
`)
})
m, _ := p.Model("gemini-2.5-flash")
s, err := m.Stream(context.Background(), basicRequest())
if err != nil {
t.Fatalf("Stream: %v", err)
}
defer s.Close()
if !strings.Contains(cap.query+cap.path, "streamGenerateContent") {
t.Errorf("path = %q query = %q, want streaming endpoint", cap.path, cap.query)
}
var text strings.Builder
var calls []llm.ToolCall
var final *llm.Response
for {
ev, err := s.Next()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
t.Fatalf("Next: %v", err)
}
text.WriteString(ev.TextDelta)
if ev.ToolCall != nil {
calls = append(calls, *ev.ToolCall)
}
if ev.Response != nil {
final = ev.Response
}
}
if text.String() != "Hello" {
t.Errorf("text = %q", text.String())
}
if len(calls) != 1 || calls[0].Name != "ping" {
t.Errorf("calls = %+v", calls)
}
if final == nil {
t.Fatal("no final event")
}
if final.Usage.InputTokens != 3 || final.Usage.OutputTokens != 6 {
t.Errorf("usage = %+v", final.Usage)
}
if final.FinishReason != llm.FinishToolCalls {
t.Errorf("finish = %v", final.FinishReason)
}
}
func TestStreamCloseEarly(t *testing.T) {
p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = io.WriteString(w, "data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"x\"}]}}]}\n\n")
})
m, _ := p.Model("g")
s, err := m.Stream(context.Background(), basicRequest())
if err != nil {
t.Fatalf("Stream: %v", err)
}
if err := s.Close(); err != nil {
t.Errorf("Close: %v", err)
}
_ = s.Close() // idempotent
}
+140
View File
@@ -0,0 +1,140 @@
package google
import (
"context"
"encoding/json"
"io"
"iter"
"strconv"
"sync"
"google.golang.org/genai"
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)
// Stream implements llm.Model over the SDK's range-over-func stream
// (iter.Seq2), adapted to majordomo's pull-based Stream via iter.Pull2.
func (m *model) Stream(ctx context.Context, req llm.Request, opts ...llm.Option) (llm.Stream, error) {
req = req.Apply(opts...)
if err := m.enforceCapabilities(req); err != nil {
return nil, err
}
client, err := m.provider.genaiClient(ctx)
if err != nil {
return nil, err
}
system, contents, err := m.buildContents(req)
if err != nil {
return nil, err
}
cfg, err := m.buildConfig(req, system)
if err != nil {
return nil, err
}
seq := client.Models.GenerateContentStream(ctx, m.id, contents, cfg)
next, stop := iter.Pull2(iter.Seq2[*genai.GenerateContentResponse, error](seq))
return &stream{model: m, next: next, stop: stop}, nil
}
type stream struct {
model *model
next func() (*genai.GenerateContentResponse, error, bool)
stop func()
mu sync.Mutex
closeOnce sync.Once
finished bool
pending []llm.StreamEvent
text []byte
toolCalls []llm.ToolCall
usage llm.Usage
finish genai.FinishReason
}
func (s *stream) Next() (llm.StreamEvent, error) {
s.mu.Lock()
defer s.mu.Unlock()
for {
if len(s.pending) > 0 {
ev := s.pending[0]
s.pending = s.pending[1:]
return ev, nil
}
if s.finished {
return llm.StreamEvent{}, io.EOF
}
chunk, err, ok := s.next()
if !ok {
s.queueFinal()
continue
}
if err != nil {
return llm.StreamEvent{}, s.model.mapError(err)
}
if chunk.UsageMetadata != nil {
s.usage = llm.Usage{
InputTokens: int(chunk.UsageMetadata.PromptTokenCount),
OutputTokens: int(chunk.UsageMetadata.CandidatesTokenCount + chunk.UsageMetadata.ThoughtsTokenCount),
}
}
if len(chunk.Candidates) == 0 {
continue
}
cand := chunk.Candidates[0]
if cand.FinishReason != "" {
s.finish = cand.FinishReason
}
if cand.Content == nil {
continue
}
for _, part := range cand.Content.Parts {
if part == nil {
continue
}
if part.Text != "" && !part.Thought {
s.text = append(s.text, part.Text...)
s.pending = append(s.pending, llm.StreamEvent{TextDelta: part.Text})
}
// Function calls arrive whole per chunk in the Gemini stream.
if fc := part.FunctionCall; fc != nil {
id := fc.ID
if id == "" {
id = "call_" + strconv.Itoa(len(s.toolCalls))
}
args, err := json.Marshal(fc.Args)
if err != nil || len(fc.Args) == 0 {
args = json.RawMessage("{}")
}
call := llm.ToolCall{ID: id, Name: fc.Name, Arguments: args}
s.toolCalls = append(s.toolCalls, call)
s.pending = append(s.pending, llm.StreamEvent{ToolCall: &call})
}
}
}
}
func (s *stream) queueFinal() {
resp := &llm.Response{
Model: s.model.qualified(),
Usage: s.usage,
FinishReason: mapFinish(s.finish, len(s.toolCalls) > 0),
}
if len(s.text) > 0 {
resp.Parts = append(resp.Parts, llm.Text(string(s.text)))
}
if len(s.toolCalls) > 0 {
resp.ToolCalls = append([]llm.ToolCall(nil), s.toolCalls...)
}
s.pending = append(s.pending, llm.StreamEvent{Response: resp})
s.finished = true
}
func (s *stream) Close() error {
s.closeOnce.Do(s.stop)
return nil
}