package google

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/majordomo/llm"
)

// captured records what the test server saw on the most recent request:
// the URL path, the raw query string, and the JSON body decoded into a
// generic map.
type captured struct {
	path  string
	query string
	body  map[string]any
}

// serve builds a provider pointed at an httptest server (the SDK's
// documented hermetic hook: HTTPOptions.BaseURL + HTTPClient).
//
// Every request is recorded into the returned *captured before handler is
// invoked. The server is closed automatically via t.Cleanup.
func serve(t *testing.T, handler func(w http.ResponseWriter, r *http.Request)) (*Provider, *captured) {
	t.Helper()
	rec := &captured{}
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rec.path = r.URL.Path
		rec.query = r.URL.RawQuery
		// Capture is best-effort: read/decode errors are ignored on purpose,
		// because tests that inspect the body fail on their own assertions
		// when nothing was captured.
		raw, _ := io.ReadAll(r.Body)
		// Decode into a fresh map each time: json.Unmarshal keeps existing
		// entries when handed a non-nil map, which would leak keys from an
		// earlier request into this one.
		var body map[string]any
		_ = json.Unmarshal(raw, &body)
		rec.body = body
		handler(w, r)
	}))
	t.Cleanup(ts.Close)
	return New(
		WithAPIKey("test-key"),
		WithBaseURL(ts.URL),
		WithHTTPClient(ts.Client()),
	), rec
}

// textResponse returns a generateContent response body with a single text
// part, a STOP finish reason, and fixed usage metadata (7 prompt, 5
// candidate, 2 thought tokens). text is embedded with %q.
func textResponse(text string) string {
	return fmt.Sprintf(`{
		"candidates":[{"content":{"role":"model","parts":[{"text":%q}]},"finishReason":"STOP"}],
		"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":5,"thoughtsTokenCount":2}
	}`, text)
}

// basicRequest returns a minimal request holding one user message, "hi".
func basicRequest() llm.Request {
	return llm.Request{Messages: []llm.Message{llm.UserText("hi")}}
}

// TestGenerateRoundTrip checks the request wire shape (endpoint path, merged
// system instruction, generation config, contents) and the decoded response
// (text, usage, finish reason, model name).
func TestGenerateRoundTrip(t *testing.T) {
	p, rec := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, textResponse("hello from gemini"))
	})
	m, _ := p.Model("gemini-2.5-flash")

	temp := 0.3
	resp, err := m.Generate(context.Background(), llm.Request{
		System:      "be terse",
		Messages:    []llm.Message{llm.SystemText("extra"), llm.UserText("hi")},
		Temperature: &temp,
		MaxTokens:   128,
	})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	if !strings.Contains(rec.path, "models/gemini-2.5-flash:generateContent") {
		t.Errorf("path = %q", rec.path)
	}

	sys := rec.body["systemInstruction"].(map[string]any)
	sysText := sys["parts"].([]any)[0].(map[string]any)["text"]
	if sysText != "be terse\n\nextra" {
		t.Errorf("system = %v", sysText)
	}

	genCfg := rec.body["generationConfig"].(map[string]any)
	if genCfg["temperature"] != 0.3 || genCfg["maxOutputTokens"] != float64(128) {
		t.Errorf("generationConfig = %v", genCfg)
	}

	contents := rec.body["contents"].([]any)
	if len(contents) != 1 {
		t.Fatalf("contents = %v (system must not appear)", contents)
	}

	if resp.Text() != "hello from gemini" {
		t.Errorf("text = %q", resp.Text())
	}
	// 5 candidate tokens + 2 thought tokens are expected to be reported as 7.
	if resp.Usage.InputTokens != 7 || resp.Usage.OutputTokens != 7 {
		t.Errorf("usage = %+v (output must include thoughts)", resp.Usage)
	}
	if resp.FinishReason != llm.FinishStop {
		t.Errorf("finish = %v", resp.FinishReason)
	}
	if resp.Model != "google/gemini-2.5-flash" {
		t.Errorf("model = %q", resp.Model)
	}
}

// TestImageInlineData checks that an image part is sent as an inlineData
// blob with its MIME type and base64-encoded bytes.
func TestImageInlineData(t *testing.T) {
	p, rec := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, textResponse("a png"))
	})
	m, _ := p.Model("gemini-2.5-flash")

	_, err := m.Generate(context.Background(), llm.Request{
		Messages: []llm.Message{llm.UserParts(llm.Text("see"), llm.Image("image/png", []byte{1, 2, 3}))},
	})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	parts := rec.body["contents"].([]any)[0].(map[string]any)["parts"].([]any)
	var foundBlob bool
	for _, pt := range parts {
		blob, ok := pt.(map[string]any)["inlineData"].(map[string]any)
		if !ok {
			continue
		}
		foundBlob = true
		// "AQID" is the base64 encoding of the bytes {1, 2, 3}.
		if blob["mimeType"] != "image/png" || blob["data"] != "AQID" {
			t.Errorf("blob = %v", blob)
		}
	}
	if !foundBlob {
		t.Error("no inlineData part sent")
	}
}

// TestToolsAndFunctionCalls checks that tools are sent as function
// declarations and that a functionCall part in the response surfaces as a
// tool call with a non-empty ID and the tool-calls finish reason.
func TestToolsAndFunctionCalls(t *testing.T) {
	p, rec := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, `{
			"candidates":[{"content":{"role":"model","parts":[
				{"functionCall":{"name":"get_weather","args":{"city":"Tokyo"}}}
			]},"finishReason":"STOP"}]
		}`)
	})
	m, _ := p.Model("gemini-2.5-flash")

	resp, err := m.Generate(context.Background(), basicRequest(), llm.WithTools(llm.Tool{
		Name:        "get_weather",
		Description: "weather",
		Parameters:  json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`),
	}))
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	tools := rec.body["tools"].([]any)
	decls := tools[0].(map[string]any)["functionDeclarations"].([]any)
	decl := decls[0].(map[string]any)
	if decl["name"] != "get_weather" {
		t.Errorf("decl = %v", decl)
	}
	if _, ok := decl["parametersJsonSchema"].(map[string]any); !ok {
		t.Errorf("parametersJsonSchema missing: %v", decl)
	}

	if len(resp.ToolCalls) != 1 {
		t.Fatalf("tool calls = %+v", resp.ToolCalls)
	}
	tc := resp.ToolCalls[0]
	if tc.Name != "get_weather" || tc.ID == "" {
		t.Errorf("call = %+v (id synthesized)", tc)
	}
	var args struct {
		City string `json:"city"`
	}
	if err := json.Unmarshal(tc.Arguments, &args); err != nil || args.City != "Tokyo" {
		t.Errorf("args = %s", tc.Arguments)
	}
	if resp.FinishReason != llm.FinishToolCalls {
		t.Errorf("finish = %v", resp.FinishReason)
	}
}

// TestToolResultsAndHistory checks how conversation history is serialized:
// assistant tool calls become "model" functionCall parts, and tool results
// become functionResponse parts carrying either an "output" or an "error"
// payload.
func TestToolResultsAndHistory(t *testing.T) {
	p, rec := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, textResponse("21C"))
	})
	m, _ := p.Model("gemini-2.5-flash")

	_, err := m.Generate(context.Background(), llm.Request{
		Messages: []llm.Message{
			llm.UserText("weather?"),
			{Role: llm.RoleAssistant, ToolCalls: []llm.ToolCall{
				{ID: "c1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"Tokyo"}`)},
			}},
			llm.ToolResultsMessage(
				llm.ToolResult{ID: "c1", Name: "get_weather", Content: `{"temp":21}`},
				llm.ToolResult{ID: "c2", Name: "broken", Content: "boom", IsError: true},
			),
		},
	})
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	contents := rec.body["contents"].([]any)
	if len(contents) != 3 {
		t.Fatalf("contents = %d, want 3", len(contents))
	}

	model := contents[1].(map[string]any)
	if model["role"] != "model" {
		t.Errorf("assistant role = %v", model["role"])
	}
	fc := model["parts"].([]any)[0].(map[string]any)["functionCall"].(map[string]any)
	if fc["name"] != "get_weather" {
		t.Errorf("functionCall = %v", fc)
	}

	results := contents[2].(map[string]any)
	parts := results["parts"].([]any)
	fr1 := parts[0].(map[string]any)["functionResponse"].(map[string]any)
	if fr1["name"] != "get_weather" {
		t.Errorf("functionResponse = %v", fr1)
	}
	if resp1 := fr1["response"].(map[string]any); resp1["output"] != `{"temp":21}` {
		t.Errorf("response payload = %v", resp1)
	}
	fr2 := parts[1].(map[string]any)["functionResponse"].(map[string]any)
	if resp2 := fr2["response"].(map[string]any); resp2["error"] != "boom" {
		t.Errorf("error payload = %v", resp2)
	}
}

// TestToolChoiceMapping checks the tool-choice translation: "required" and a
// specific tool name both map to mode ANY (the latter also restricting
// allowedFunctionNames), while "none" omits tools from the request.
func TestToolChoiceMapping(t *testing.T) {
	for choice, want := range map[string]string{
		"required":    "ANY",
		"get_weather": "ANY",
	} {
		// One subtest per case so a fatal failure in one choice does not
		// hide the result of the other.
		t.Run(choice, func(t *testing.T) {
			p, rec := serve(t, func(w http.ResponseWriter, _ *http.Request) {
				_, _ = io.WriteString(w, textResponse("x"))
			})
			m, _ := p.Model("g")

			_, err := m.Generate(context.Background(), basicRequest(),
				llm.WithTools(llm.Tool{Name: "get_weather"}), llm.WithToolChoice(choice))
			if err != nil {
				t.Fatalf("Generate(%s): %v", choice, err)
			}

			tc := rec.body["toolConfig"].(map[string]any)["functionCallingConfig"].(map[string]any)
			if tc["mode"] != want {
				t.Errorf("choice %q → mode %v, want %v", choice, tc["mode"], want)
			}
			if choice == "get_weather" {
				// Comma-ok plus a length check: a missing or empty list is
				// reported as a failure instead of panicking.
				allowed, _ := tc["allowedFunctionNames"].([]any)
				if len(allowed) == 0 || allowed[0] != "get_weather" {
					t.Errorf("allowedFunctionNames = %v", allowed)
				}
			}
		})
	}

	t.Run("none drops tools", func(t *testing.T) {
		p, rec := serve(t, func(w http.ResponseWriter, _ *http.Request) {
			_, _ = io.WriteString(w, textResponse("x"))
		})
		m, _ := p.Model("g")

		if _, err := m.Generate(context.Background(), basicRequest(),
			llm.WithTools(llm.Tool{Name: "t"}), llm.WithToolChoice("none")); err != nil {
			t.Fatalf("Generate: %v", err)
		}
		if _, present := rec.body["tools"]; present {
			t.Error("tool_choice none must omit tools")
		}
	})
}

// TestStructuredOutput checks that a schema option sets the JSON response
// MIME type and responseJsonSchema, and that the JSON text is returned as-is.
func TestStructuredOutput(t *testing.T) {
	p, rec := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, textResponse(`{"name":"Ada"}`))
	})
	m, _ := p.Model("g")

	schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
	resp, err := m.Generate(context.Background(), basicRequest(), llm.WithSchema(schema, "person"))
	if err != nil {
		t.Fatalf("Generate: %v", err)
	}

	genCfg := rec.body["generationConfig"].(map[string]any)
	if genCfg["responseMimeType"] != "application/json" {
		t.Errorf("responseMimeType = %v", genCfg["responseMimeType"])
	}
	if _, ok := genCfg["responseJsonSchema"].(map[string]any); !ok {
		t.Errorf("responseJsonSchema = %v", genCfg["responseJsonSchema"])
	}
	if resp.Text() != `{"name":"Ada"}` {
		t.Errorf("text = %q", resp.Text())
	}
}

// TestReasoningEffortMapsToThinkingLevel checks that effort "high" is sent as
// thinkingLevel HIGH and that an unknown effort value is rejected.
func TestReasoningEffortMapsToThinkingLevel(t *testing.T) {
	p, rec := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, textResponse("x"))
	})
	m, _ := p.Model("g")

	if _, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("high")); err != nil {
		t.Fatalf("Generate: %v", err)
	}
	genCfg := rec.body["generationConfig"].(map[string]any)
	thinking := genCfg["thinkingConfig"].(map[string]any)
	if thinking["thinkingLevel"] != "HIGH" {
		t.Errorf("thinkingConfig = %v", thinking)
	}

	if _, err := m.Generate(context.Background(), basicRequest(), llm.WithReasoningEffort("ultra")); err == nil {
		t.Error("invalid effort should error")
	}
}

// TestFinishReasonMapping checks the translation of wire finishReason values
// into llm.FinishReason.
func TestFinishReasonMapping(t *testing.T) {
	for wire, want := range map[string]llm.FinishReason{
		"STOP":                    llm.FinishStop,
		"MAX_TOKENS":              llm.FinishLength,
		"SAFETY":                  llm.FinishContentFilter,
		"PROHIBITED_CONTENT":      llm.FinishContentFilter,
		"MALFORMED_FUNCTION_CALL": llm.FinishOther,
	} {
		// One subtest per wire value so every mapping is exercised and
		// reported even when an earlier one fails fatally.
		t.Run(wire, func(t *testing.T) {
			p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
				_, _ = fmt.Fprintf(w, `{"candidates":[{"content":{"role":"model","parts":[{"text":"x"}]},"finishReason":%q}]}`, wire)
			})
			m, _ := p.Model("g")

			resp, err := m.Generate(context.Background(), basicRequest())
			if err != nil {
				t.Fatalf("Generate(%s): %v", wire, err)
			}
			if resp.FinishReason != want {
				t.Errorf("finish %q = %v, want %v", wire, resp.FinishReason, want)
			}
		})
	}
}

// TestAPIErrorMapping checks that an HTTP 429 error payload surfaces as an
// *llm.APIError with the status and message preserved, and that it is
// classified as transient.
func TestAPIErrorMapping(t *testing.T) {
	p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(429)
		_, _ = io.WriteString(w, `{"error":{"code":429,"message":"quota exhausted","status":"RESOURCE_EXHAUSTED"}}`)
	})
	m, _ := p.Model("g")

	_, err := m.Generate(context.Background(), basicRequest())
	var apiErr *llm.APIError
	if !errors.As(err, &apiErr) {
		t.Fatalf("error = %v (%T), want APIError", err, err)
	}
	if apiErr.Status != 429 || !strings.Contains(apiErr.Message, "quota") {
		t.Errorf("apiErr = %+v", apiErr)
	}
	if llm.Classify(err) != llm.ClassTransient {
		t.Error("429 must classify transient")
	}
}

// TestMissingAPIKey checks that, with no key in the options or in either
// environment variable, Generate fails with an *llm.APIError whose status is
// 401.
func TestMissingAPIKey(t *testing.T) {
	t.Setenv("GOOGLE_API_KEY", "")
	t.Setenv("GEMINI_API_KEY", "")
	p := New(WithAPIKey(""))
	m, _ := p.Model("g")

	_, err := m.Generate(context.Background(), basicRequest())
	var apiErr *llm.APIError
	if !errors.As(err, &apiErr) || apiErr.Status != http.StatusUnauthorized {
		t.Errorf("error = %v, want synthetic 401", err)
	}
}

// TestEnvKeyPrecedence checks that GOOGLE_API_KEY wins over GEMINI_API_KEY
// and that GEMINI_API_KEY is used when GOOGLE_API_KEY is empty.
func TestEnvKeyPrecedence(t *testing.T) {
	t.Setenv("GOOGLE_API_KEY", "g-key")
	t.Setenv("GEMINI_API_KEY", "gem-key")
	if p := New(); p.apiKey != "g-key" {
		t.Errorf("apiKey = %q, want GOOGLE_API_KEY to win", p.apiKey)
	}

	t.Setenv("GOOGLE_API_KEY", "")
	if p := New(); p.apiKey != "gem-key" {
		t.Errorf("apiKey = %q, want GEMINI_API_KEY fallback", p.apiKey)
	}
}

// TestCapabilityEnforcement checks that sending two images to a model capped
// at one image per request fails with llm.ErrUnsupported.
func TestCapabilityEnforcement(t *testing.T) {
	p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, textResponse("x"))
	})
	m, _ := p.Model("g", llm.WithCapabilities(llm.Capabilities{
		MaxImagesPerReq:  1,
		AllowedImageMIME: []string{"image/png"},
	}))

	_, err := m.Generate(context.Background(), llm.Request{Messages: []llm.Message{
		llm.UserParts(llm.Image("image/png", []byte{1}), llm.Image("image/png", []byte{2})),
	}})
	if !errors.Is(err, llm.ErrUnsupported) {
		t.Errorf("error = %v, want ErrUnsupported", err)
	}
}

// TestStreaming feeds three server-sent events (two text deltas, then a
// function call with finish reason and usage) and checks the accumulated
// text, the emitted tool call, and the final response event.
func TestStreaming(t *testing.T) {
	// Each SSE event is terminated by a blank line; the framing is spelled
	// out explicitly so it cannot be lost to whitespace reformatting.
	events := []string{
		`data: {"candidates":[{"content":{"role":"model","parts":[{"text":"Hel"}]}}]}`,
		`data: {"candidates":[{"content":{"role":"model","parts":[{"text":"lo"}]}}]}`,
		`data: {"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"ping","args":{}}}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":3,"candidatesTokenCount":6}}`,
	}
	p, rec := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = io.WriteString(w, strings.Join(events, "\n\n")+"\n\n")
	})
	m, _ := p.Model("gemini-2.5-flash")

	s, err := m.Stream(context.Background(), basicRequest())
	if err != nil {
		t.Fatalf("Stream: %v", err)
	}
	defer s.Close()

	if !strings.Contains(rec.query+rec.path, "streamGenerateContent") {
		t.Errorf("path = %q query = %q, want streaming endpoint", rec.path, rec.query)
	}

	var (
		text  strings.Builder
		calls []llm.ToolCall
		final *llm.Response
	)
	for {
		ev, err := s.Next()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			t.Fatalf("Next: %v", err)
		}
		text.WriteString(ev.TextDelta)
		if ev.ToolCall != nil {
			calls = append(calls, *ev.ToolCall)
		}
		if ev.Response != nil {
			final = ev.Response
		}
	}

	if text.String() != "Hello" {
		t.Errorf("text = %q", text.String())
	}
	if len(calls) != 1 || calls[0].Name != "ping" {
		t.Errorf("calls = %+v", calls)
	}
	if final == nil {
		t.Fatal("no final event")
	}
	if final.Usage.InputTokens != 3 || final.Usage.OutputTokens != 6 {
		t.Errorf("usage = %+v", final.Usage)
	}
	if final.FinishReason != llm.FinishToolCalls {
		t.Errorf("finish = %v", final.FinishReason)
	}
}

// TestStreamCloseEarly checks that a stream can be closed before it is
// drained and that a second Close is harmless.
func TestStreamCloseEarly(t *testing.T) {
	p, _ := serve(t, func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = io.WriteString(w, "data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"x\"}]}}]}\n\n")
	})
	m, _ := p.Model("g")

	s, err := m.Stream(context.Background(), basicRequest())
	if err != nil {
		t.Fatalf("Stream: %v", err)
	}
	if err := s.Close(); err != nil {
		t.Errorf("Close: %v", err)
	}
	_ = s.Close() // idempotent
}