1ca607906d
Phase 4: provider/google on google.golang.org/genai v1.59.0 — lazy cached client, FunctionResponse tool loop, raw-JSON-schema tools and structured output, ThinkingLevel reasoning mapping, iter.Pull2 streaming, hermetic httptest suite via HTTPOptions.BaseURL. Registry wires google + gemini schemes to the real client; stub machinery deleted (all built-ins real). ADR-0011; README matrix + CLAUDE.md synced. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
458 lines
14 KiB
Go
458 lines
14 KiB
Go
package google
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
|
)
|
|
|
|
// captured records what the fake server in serve observed about a request:
// its URL path, its raw query string, and its JSON-decoded body.
type captured struct {
	path, query string         // r.URL.Path and r.URL.RawQuery
	body        map[string]any // request body decoded as a generic JSON object
}
|
|
|
|
// serve builds a provider pointed at an httptest server (the SDK's
|
|
// documented hermetic hook: HTTPOptions.BaseURL + HTTPClient).
|
|
func serve(t *testing.T, handler func(w http.ResponseWriter, r *http.Request)) (*Provider, *captured) {
|
|
t.Helper()
|
|
cap := &captured{}
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
cap.path = r.URL.Path
|
|
cap.query = r.URL.RawQuery
|
|
raw, _ := io.ReadAll(r.Body)
|
|
_ = json.Unmarshal(raw, &cap.body)
|
|
handler(w, r)
|
|
}))
|
|
t.Cleanup(ts.Close)
|
|
return New(
|
|
WithAPIKey("test-key"),
|
|
WithBaseURL(ts.URL),
|
|
WithHTTPClient(ts.Client()),
|
|
), cap
|
|
}
|
|
|
|
// textResponse returns a non-streaming generateContent response body with a
// single model candidate whose only part is text. The candidate finishes
// with STOP, and usage reports 7 prompt, 5 candidate and 2 thought tokens.
func textResponse(text string) string {
	// Encode with JSON rules rather than Go's %q: %q emits escapes such as
	// \x01, \a or \U0001f600 that are not valid JSON string escapes.
	quoted, _ := json.Marshal(text) // marshaling a plain string cannot fail
	return fmt.Sprintf(`{
	"candidates":[{"content":{"role":"model","parts":[{"text":%s}]},"finishReason":"STOP"}],
	"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":5,"thoughtsTokenCount":2}
}`, quoted)
}
|
|
|
|
func basicRequest() llm.Request {
|
|
return llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
|
|
}
|
|
|
|
func TestGenerateRoundTrip(t *testing.T) {
|
|
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
_, _ = io.WriteString(w, textResponse("hello from gemini"))
|
|
})
|
|
|
|
m, _ := p.Model("gemini-2.5-flash")
|
|
temp := 0.3
|
|
resp, err := m.Generate(context.Background(), llm.Request{
|
|
System: "be terse",
|
|
Messages: []llm.Message{llm.SystemText("extra"), llm.UserText("hi")},
|
|
Temperature: &temp,
|
|
MaxTokens: 128,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
|
|
if !strings.Contains(cap.path, "models/gemini-2.5-flash:generateContent") {
|
|
t.Errorf("path = %q", cap.path)
|
|
}
|
|
sys := cap.body["systemInstruction"].(map[string]any)
|
|
sysText := sys["parts"].([]any)[0].(map[string]any)["text"]
|
|
if sysText != "be terse\n\nextra" {
|
|
t.Errorf("system = %v", sysText)
|
|
}
|
|
genCfg := cap.body["generationConfig"].(map[string]any)
|
|
if genCfg["temperature"] != 0.3 || genCfg["maxOutputTokens"] != float64(128) {
|
|
t.Errorf("generationConfig = %v", genCfg)
|
|
}
|
|
contents := cap.body["contents"].([]any)
|
|
if len(contents) != 1 {
|
|
t.Fatalf("contents = %v (system must not appear)", contents)
|
|
}
|
|
|
|
if resp.Text() != "hello from gemini" {
|
|
t.Errorf("text = %q", resp.Text())
|
|
}
|
|
if resp.Usage.InputTokens != 7 || resp.Usage.OutputTokens != 7 {
|
|
t.Errorf("usage = %+v (output must include thoughts)", resp.Usage)
|
|
}
|
|
if resp.FinishReason != llm.FinishStop {
|
|
t.Errorf("finish = %v", resp.FinishReason)
|
|
}
|
|
if resp.Model != "google/gemini-2.5-flash" {
|
|
t.Errorf("model = %q", resp.Model)
|
|
}
|
|
}
|
|
|
|
func TestImageInlineData(t *testing.T) {
|
|
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
_, _ = io.WriteString(w, textResponse("a png"))
|
|
})
|
|
m, _ := p.Model("gemini-2.5-flash")
|
|
_, err := m.Generate(context.Background(), llm.Request{
|
|
Messages: []llm.Message{llm.UserParts(llm.Text("see"), llm.Image("image/png", []byte{1, 2, 3}))},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
parts := cap.body["contents"].([]any)[0].(map[string]any)["parts"].([]any)
|
|
var foundBlob bool
|
|
for _, pt := range parts {
|
|
if blob, ok := pt.(map[string]any)["inlineData"].(map[string]any); ok {
|
|
foundBlob = true
|
|
if blob["mimeType"] != "image/png" || blob["data"] != "AQID" {
|
|
t.Errorf("blob = %v", blob)
|
|
}
|
|
}
|
|
}
|
|
if !foundBlob {
|
|
t.Error("no inlineData part sent")
|
|
}
|
|
}
|
|
|
|
func TestToolsAndFunctionCalls(t *testing.T) {
|
|
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
_, _ = io.WriteString(w, `{
|
|
"candidates":[{"content":{"role":"model","parts":[
|
|
{"functionCall":{"name":"get_weather","args":{"city":"Tokyo"}}}
|
|
]},"finishReason":"STOP"}]
|
|
}`)
|
|
})
|
|
m, _ := p.Model("gemini-2.5-flash")
|
|
resp, err := m.Generate(context.Background(), basicRequest(), llm.WithTools(llm.Tool{
|
|
Name: "get_weather", Description: "weather",
|
|
Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
|
|
}))
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
|
|
tools := cap.body["tools"].([]any)
|
|
decls := tools[0].(map[string]any)["functionDeclarations"].([]any)
|
|
decl := decls[0].(map[string]any)
|
|
if decl["name"] != "get_weather" {
|
|
t.Errorf("decl = %v", decl)
|
|
}
|
|
if _, ok := decl["parametersJsonSchema"].(map[string]any); !ok {
|
|
t.Errorf("parametersJsonSchema missing: %v", decl)
|
|
}
|
|
|
|
if len(resp.ToolCalls) != 1 {
|
|
t.Fatalf("tool calls = %+v", resp.ToolCalls)
|
|
}
|
|
tc := resp.ToolCalls[0]
|
|
if tc.Name != "get_weather" || tc.ID == "" {
|
|
t.Errorf("call = %+v (id synthesized)", tc)
|
|
}
|
|
var args struct {
|
|
City string `json:"city"`
|
|
}
|
|
if err := json.Unmarshal(tc.Arguments, &args); err != nil || args.City != "Tokyo" {
|
|
t.Errorf("args = %s", tc.Arguments)
|
|
}
|
|
if resp.FinishReason != llm.FinishToolCalls {
|
|
t.Errorf("finish = %v", resp.FinishReason)
|
|
}
|
|
}
|
|
|
|
func TestToolResultsAndHistory(t *testing.T) {
|
|
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
_, _ = io.WriteString(w, textResponse("21C"))
|
|
})
|
|
m, _ := p.Model("gemini-2.5-flash")
|
|
_, err := m.Generate(context.Background(), llm.Request{
|
|
Messages: []llm.Message{
|
|
llm.UserText("weather?"),
|
|
{Role: llm.RoleAssistant, ToolCalls: []llm.ToolCall{
|
|
{ID: "c1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Tokyo"}`)},
|
|
}},
|
|
llm.ToolResultsMessage(
|
|
llm.ToolResult{ID: "c1", Name: "get_weather", Content: `{"temp":21}`},
|
|
llm.ToolResult{ID: "c2", Name: "broken", Content: "boom", IsError: true},
|
|
),
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
contents := cap.body["contents"].([]any)
|
|
if len(contents) != 3 {
|
|
t.Fatalf("contents = %d, want 3", len(contents))
|
|
}
|
|
model := contents[1].(map[string]any)
|
|
if model["role"] != "model" {
|
|
t.Errorf("assistant role = %v", model["role"])
|
|
}
|
|
fc := model["parts"].([]any)[0].(map[string]any)["functionCall"].(map[string]any)
|
|
if fc["name"] != "get_weather" {
|
|
t.Errorf("functionCall = %v", fc)
|
|
}
|
|
results := contents[2].(map[string]any)
|
|
parts := results["parts"].([]any)
|
|
fr1 := parts[0].(map[string]any)["functionResponse"].(map[string]any)
|
|
if fr1["name"] != "get_weather" {
|
|
t.Errorf("functionResponse = %v", fr1)
|
|
}
|
|
if resp1 := fr1["response"].(map[string]any); resp1["output"] != `{"temp":21}` {
|
|
t.Errorf("response payload = %v", resp1)
|
|
}
|
|
fr2 := parts[1].(map[string]any)["functionResponse"].(map[string]any)
|
|
if resp2 := fr2["response"].(map[string]any); resp2["error"] != "boom" {
|
|
t.Errorf("error payload = %v", resp2)
|
|
}
|
|
}
|
|
|
|
func TestToolChoiceMapping(t *testing.T) {
|
|
for choice, want := range map[string]string{
|
|
"required": "ANY",
|
|
"get_weather": "ANY",
|
|
} {
|
|
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
_, _ = io.WriteString(w, textResponse("x"))
|
|
})
|
|
m, _ := p.Model("g")
|
|
_, err := m.Generate(context.Background(), basicRequest(),
|
|
llm.WithTools(llm.Tool{Name: "get_weather"}), llm.WithToolChoice(choice))
|
|
if err != nil {
|
|
t.Fatalf("Generate(%s): %v", choice, err)
|
|
}
|
|
tc := cap.body["toolConfig"].(map[string]any)["functionCallingConfig"].(map[string]any)
|
|
if tc["mode"] != want {
|
|
t.Errorf("choice %q → mode %v, want %v", choice, tc["mode"], want)
|
|
}
|
|
if choice == "get_weather" {
|
|
allowed := tc["allowedFunctionNames"].([]any)
|
|
if allowed[0] != "get_weather" {
|
|
t.Errorf("allowedFunctionNames = %v", allowed)
|
|
}
|
|
}
|
|
}
|
|
|
|
t.Run("none drops tools", func(t *testing.T) {
|
|
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
_, _ = io.WriteString(w, textResponse("x"))
|
|
})
|
|
m, _ := p.Model("g")
|
|
if _, err := m.Generate(context.Background(), basicRequest(),
|
|
llm.WithTools(llm.Tool{Name: "t"}), llm.WithToolChoice("none")); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
if _, present := cap.body["tools"]; present {
|
|
t.Error("tool_choice none must omit tools")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestStructuredOutput(t *testing.T) {
|
|
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
_, _ = io.WriteString(w, textResponse(`{"name":"Ada"}`))
|
|
})
|
|
m, _ := p.Model("g")
|
|
schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
|
|
resp, err := m.Generate(context.Background(), basicRequest(), llm.WithSchema(schema, "person"))
|
|
if err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
genCfg := cap.body["generationConfig"].(map[string]any)
|
|
if genCfg["responseMimeType"] != "application/json" {
|
|
t.Errorf("responseMimeType = %v", genCfg["responseMimeType"])
|
|
}
|
|
if _, ok := genCfg["responseJsonSchema"].(map[string]any); !ok {
|
|
t.Errorf("responseJsonSchema = %v", genCfg["responseJsonSchema"])
|
|
}
|
|
if resp.Text() != `{"name":"Ada"}` {
|
|
t.Errorf("text = %q", resp.Text())
|
|
}
|
|
}
|
|
|
|
func TestReasoningEffortMapsToThinkingLevel(t *testing.T) {
|
|
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
_, _ = io.WriteString(w, textResponse("x"))
|
|
})
|
|
m, _ := p.Model("g")
|
|
if _, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("high")); err != nil {
|
|
t.Fatalf("Generate: %v", err)
|
|
}
|
|
genCfg := cap.body["generationConfig"].(map[string]any)
|
|
thinking := genCfg["thinkingConfig"].(map[string]any)
|
|
if thinking["thinkingLevel"] != "HIGH" {
|
|
t.Errorf("thinkingConfig = %v", thinking)
|
|
}
|
|
|
|
if _, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("ultra")); err == nil {
|
|
t.Error("invalid effort should error")
|
|
}
|
|
}
|
|
|
|
func TestFinishReasonMapping(t *testing.T) {
|
|
for wire, want := range map[string]llm.FinishReason{
|
|
"STOP": llm.FinishStop,
|
|
"MAX_TOKENS": llm.FinishLength,
|
|
"SAFETY": llm.FinishContentFilter,
|
|
"PROHIBITED_CONTENT": llm.FinishContentFilter,
|
|
"MALFORMED_FUNCTION_CALL": llm.FinishOther,
|
|
} {
|
|
p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
fmt.Fprintf(w, `{"candidates":[{"content":{"role":"model","parts":[{"text":"x"}]},"finishReason":%q}]}`, wire)
|
|
})
|
|
m, _ := p.Model("g")
|
|
resp, err := m.Generate(context.Background(), basicRequest())
|
|
if err != nil {
|
|
t.Fatalf("Generate(%s): %v", wire, err)
|
|
}
|
|
if resp.FinishReason != want {
|
|
t.Errorf("finish %q = %v, want %v", wire, resp.FinishReason, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestAPIErrorMapping(t *testing.T) {
|
|
p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
// no response written below; status set in the closure
|
|
})
|
|
_ = p
|
|
p2, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(429)
|
|
_, _ = io.WriteString(w, `{"error":{"code":429,"message":"quota exhausted","status":"RESOURCE_EXHAUSTED"}}`)
|
|
})
|
|
m, _ := p2.Model("g")
|
|
_, err := m.Generate(context.Background(), basicRequest())
|
|
var apiErr *llm.APIError
|
|
if !errors.As(err, &apiErr) {
|
|
t.Fatalf("error = %v (%T), want APIError", err, err)
|
|
}
|
|
if apiErr.Status != 429 || !strings.Contains(apiErr.Message, "quota") {
|
|
t.Errorf("apiErr = %+v", apiErr)
|
|
}
|
|
if llm.Classify(err) != llm.ClassTransient {
|
|
t.Error("429 must classify transient")
|
|
}
|
|
}
|
|
|
|
func TestMissingAPIKey(t *testing.T) {
|
|
t.Setenv("GOOGLE_API_KEY", "")
|
|
t.Setenv("GEMINI_API_KEY", "")
|
|
p := New(WithAPIKey(""))
|
|
m, _ := p.Model("g")
|
|
_, err := m.Generate(context.Background(), basicRequest())
|
|
var apiErr *llm.APIError
|
|
if !errors.As(err, &apiErr) || apiErr.Status != http.StatusUnauthorized {
|
|
t.Errorf("error = %v, want synthetic 401", err)
|
|
}
|
|
}
|
|
|
|
func TestEnvKeyPrecedence(t *testing.T) {
|
|
t.Setenv("GOOGLE_API_KEY", "g-key")
|
|
t.Setenv("GEMINI_API_KEY", "gem-key")
|
|
if p := New(); p.apiKey != "g-key" {
|
|
t.Errorf("apiKey = %q, want GOOGLE_API_KEY to win", p.apiKey)
|
|
}
|
|
t.Setenv("GOOGLE_API_KEY", "")
|
|
if p := New(); p.apiKey != "gem-key" {
|
|
t.Errorf("apiKey = %q, want GEMINI_API_KEY fallback", p.apiKey)
|
|
}
|
|
}
|
|
|
|
func TestCapabilityEnforcement(t *testing.T) {
|
|
p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
_, _ = io.WriteString(w, textResponse("x"))
|
|
})
|
|
m, _ := p.Model("g", llm.WithCapabilities(llm.Capabilities{MaxImagesPerReq: 1, AllowedImageMIME: []string{"image/png"}}))
|
|
_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{
|
|
llm.UserParts(llm.Image("image/png", []byte{1}), llm.Image("image/png", []byte{2})),
|
|
}})
|
|
if !errors.Is(err, llm.ErrUnsupported) {
|
|
t.Errorf("error = %v, want ErrUnsupported", err)
|
|
}
|
|
}
|
|
|
|
func TestStreaming(t *testing.T) {
|
|
p, cap := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
_, _ = io.WriteString(w, `data: {"candidates":[{"content":{"role":"model","parts":[{"text":"Hel"}]}}]}
|
|
|
|
data: {"candidates":[{"content":{"role":"model","parts":[{"text":"lo"}]}}]}
|
|
|
|
data: {"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"ping","args":{}}}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":3,"candidatesTokenCount":6}}
|
|
|
|
`)
|
|
})
|
|
|
|
m, _ := p.Model("gemini-2.5-flash")
|
|
s, err := m.Stream(context.Background(), basicRequest())
|
|
if err != nil {
|
|
t.Fatalf("Stream: %v", err)
|
|
}
|
|
defer s.Close()
|
|
|
|
if !strings.Contains(cap.query+cap.path, "streamGenerateContent") {
|
|
t.Errorf("path = %q query = %q, want streaming endpoint", cap.path, cap.query)
|
|
}
|
|
|
|
var text strings.Builder
|
|
var calls []llm.ToolCall
|
|
var final *llm.Response
|
|
for {
|
|
ev, err := s.Next()
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Next: %v", err)
|
|
}
|
|
text.WriteString(ev.TextDelta)
|
|
if ev.ToolCall != nil {
|
|
calls = append(calls, *ev.ToolCall)
|
|
}
|
|
if ev.Response != nil {
|
|
final = ev.Response
|
|
}
|
|
}
|
|
if text.String() != "Hello" {
|
|
t.Errorf("text = %q", text.String())
|
|
}
|
|
if len(calls) != 1 || calls[0].Name != "ping" {
|
|
t.Errorf("calls = %+v", calls)
|
|
}
|
|
if final == nil {
|
|
t.Fatal("no final event")
|
|
}
|
|
if final.Usage.InputTokens != 3 || final.Usage.OutputTokens != 6 {
|
|
t.Errorf("usage = %+v", final.Usage)
|
|
}
|
|
if final.FinishReason != llm.FinishToolCalls {
|
|
t.Errorf("finish = %v", final.FinishReason)
|
|
}
|
|
}
|
|
|
|
func TestStreamCloseEarly(t *testing.T) {
|
|
p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
_, _ = io.WriteString(w, "data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"x\"}]}}]}\n\n")
|
|
})
|
|
m, _ := p.Model("g")
|
|
s, err := m.Stream(context.Background(), basicRequest())
|
|
if err != nil {
|
|
t.Fatalf("Stream: %v", err)
|
|
}
|
|
if err := s.Close(); err != nil {
|
|
t.Errorf("Close: %v", err)
|
|
}
|
|
_ = s.Close() // idempotent
|
|
}
|