package llm import ( "context" "errors" "testing" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" ) type testPerson struct { Name string `json:"name" description:"The person's name"` Age int `json:"age" description:"The person's age"` } func TestGenerate(t *testing.T) { mp := newMockProvider(provider.Response{ ToolCalls: []provider.ToolCall{ { ID: "call_1", Name: "structured_output", Arguments: `{"name":"Alice","age":30}`, }, }, }) model := newMockModel(mp) result, err := Generate[testPerson](context.Background(), model, "Tell me about Alice") if err != nil { t.Fatalf("unexpected error: %v", err) } if result.Name != "Alice" { t.Errorf("expected name 'Alice', got %q", result.Name) } if result.Age != 30 { t.Errorf("expected age 30, got %d", result.Age) } // Verify the tool was sent in the request req := mp.lastRequest() if len(req.Tools) != 1 { t.Fatalf("expected 1 tool, got %d", len(req.Tools)) } if req.Tools[0].Name != "structured_output" { t.Errorf("expected tool name 'structured_output', got %q", req.Tools[0].Name) } } func TestGenerateWith(t *testing.T) { mp := newMockProvider(provider.Response{ ToolCalls: []provider.ToolCall{ { ID: "call_1", Name: "structured_output", Arguments: `{"name":"Bob","age":25}`, }, }, }) model := newMockModel(mp) messages := []Message{ SystemMessage("You are helpful."), UserMessage("Tell me about Bob"), } result, err := GenerateWith[testPerson](context.Background(), model, messages) if err != nil { t.Fatalf("unexpected error: %v", err) } if result.Name != "Bob" { t.Errorf("expected name 'Bob', got %q", result.Name) } if result.Age != 25 { t.Errorf("expected age 25, got %d", result.Age) } // Verify messages were passed through req := mp.lastRequest() if len(req.Messages) != 2 { t.Fatalf("expected 2 messages, got %d", len(req.Messages)) } if req.Messages[0].Role != "system" { t.Errorf("expected first message role 'system', got %q", req.Messages[0].Role) } } func TestGenerate_NoToolCall(t *testing.T) { mp := newMockProvider(provider.Response{ Text: "I can't use tools right now.", }) model := newMockModel(mp) _, err := Generate[testPerson](context.Background(), model, "Tell me about someone") if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, ErrNoStructuredOutput) { t.Errorf("expected ErrNoStructuredOutput, got %v", err) } } func TestGenerate_InvalidJSON(t *testing.T) { mp := newMockProvider(provider.Response{ ToolCalls: []provider.ToolCall{ { ID: "call_1", Name: "structured_output", Arguments: `{not valid json}`, }, }, }) model := newMockModel(mp) _, err := Generate[testPerson](context.Background(), model, "Tell me about someone") if err == nil { t.Fatal("expected error, got nil") } if errors.Is(err, ErrNoStructuredOutput) { t.Error("expected parse error, not ErrNoStructuredOutput") } } type testAddress struct { Street string `json:"street" description:"Street address"` City string `json:"city" description:"City name"` } type testPersonWithAddress struct { Name string `json:"name" description:"The person's name"` Age int `json:"age" description:"The person's age"` Address testAddress `json:"address" description:"The person's address"` } func TestGenerate_NestedStruct(t *testing.T) { mp := newMockProvider(provider.Response{ ToolCalls: []provider.ToolCall{ { ID: "call_1", Name: "structured_output", Arguments: `{"name":"Carol","age":40,"address":{"street":"123 Main St","city":"Springfield"}}`, }, }, }) model := newMockModel(mp) result, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol") if err != nil { t.Fatalf("unexpected error: %v", err) } if result.Name != "Carol" { t.Errorf("expected name 'Carol', got %q", result.Name) } if result.Address.Street != "123 Main St" { t.Errorf("expected street '123 Main St', got %q", result.Address.Street) } if result.Address.City != "Springfield" { t.Errorf("expected city 'Springfield', got %q", result.Address.City) } } func TestGenerate_WithOptions(t *testing.T) { mp := newMockProvider(provider.Response{ ToolCalls: []provider.ToolCall{ { ID: "call_1", Name: "structured_output", Arguments: `{"name":"Dave","age":35}`, }, }, }) model := newMockModel(mp) _, err := Generate[testPerson](context.Background(), model, "Tell me about Dave", WithTemperature(0.5), WithMaxTokens(200), ) if err != nil { t.Fatalf("unexpected error: %v", err) } req := mp.lastRequest() if req.Temperature == nil || *req.Temperature != 0.5 { t.Errorf("expected temperature 0.5, got %v", req.Temperature) } if req.MaxTokens == nil || *req.MaxTokens != 200 { t.Errorf("expected maxTokens 200, got %v", req.MaxTokens) } } func TestGenerate_WithMiddleware(t *testing.T) { var middlewareCalled bool mw := func(next CompletionFunc) CompletionFunc { return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { middlewareCalled = true return next(ctx, model, messages, cfg) } } mp := newMockProvider(provider.Response{ ToolCalls: []provider.ToolCall{ { ID: "call_1", Name: "structured_output", Arguments: `{"name":"Eve","age":28}`, }, }, }) model := newMockModel(mp).WithMiddleware(mw) result, err := Generate[testPerson](context.Background(), model, "Tell me about Eve") if err != nil { t.Fatalf("unexpected error: %v", err) } if !middlewareCalled { t.Error("middleware was not called") } if result.Name != "Eve" { t.Errorf("expected name 'Eve', got %q", result.Name) } } func TestGenerate_WrongToolName(t *testing.T) { mp := newMockProvider(provider.Response{ ToolCalls: []provider.ToolCall{ { ID: "call_1", Name: "some_other_tool", Arguments: `{"name":"Frank","age":50}`, }, }, }) model := newMockModel(mp) _, err := Generate[testPerson](context.Background(), model, "Tell me about Frank") if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, ErrNoStructuredOutput) { t.Errorf("expected ErrNoStructuredOutput, got %v", err) } }