package llm import ( "context" "errors" "sync/atomic" "testing" "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" ) func TestChat_Send(t *testing.T) { mp := newMockProvider(provider.Response{Text: "Hello there!"}) model := newMockModel(mp) chat := NewChat(model) text, err := chat.Send(context.Background(), "Hi") if err != nil { t.Fatalf("unexpected error: %v", err) } if text != "Hello there!" { t.Errorf("expected 'Hello there!', got %q", text) } } func TestChat_SendMessage(t *testing.T) { mp := newMockProvider(provider.Response{Text: "reply"}) model := newMockModel(mp) chat := NewChat(model) _, err := chat.SendMessage(context.Background(), UserMessage("msg1")) if err != nil { t.Fatalf("unexpected error: %v", err) } msgs := chat.Messages() if len(msgs) != 2 { t.Fatalf("expected 2 messages (user + assistant), got %d", len(msgs)) } if msgs[0].Role != RoleUser { t.Errorf("expected first message role=user, got %v", msgs[0].Role) } if msgs[0].Content.Text != "msg1" { t.Errorf("expected first message text='msg1', got %q", msgs[0].Content.Text) } if msgs[1].Role != RoleAssistant { t.Errorf("expected second message role=assistant, got %v", msgs[1].Role) } if msgs[1].Content.Text != "reply" { t.Errorf("expected second message text='reply', got %q", msgs[1].Content.Text) } } func TestChat_SetSystem(t *testing.T) { mp := newMockProvider(provider.Response{Text: "ok"}) model := newMockModel(mp) chat := NewChat(model) chat.SetSystem("You are a bot") msgs := chat.Messages() if len(msgs) != 1 { t.Fatalf("expected 1 message, got %d", len(msgs)) } if msgs[0].Role != RoleSystem { t.Errorf("expected role=system, got %v", msgs[0].Role) } if msgs[0].Content.Text != "You are a bot" { t.Errorf("expected system text, got %q", msgs[0].Content.Text) } // Replace system message chat.SetSystem("You are a helpful bot") msgs = chat.Messages() if len(msgs) != 1 { t.Fatalf("expected 1 message after replace, got %d", len(msgs)) } if msgs[0].Content.Text != "You are a helpful bot" { t.Errorf("expected replaced system text, got %q", msgs[0].Content.Text) } // System message stays first even after adding other messages _, _ = chat.Send(context.Background(), "Hi") chat.SetSystem("New system") msgs = chat.Messages() if msgs[0].Role != RoleSystem { t.Errorf("expected system as first message, got %v", msgs[0].Role) } if msgs[0].Content.Text != "New system" { t.Errorf("expected 'New system', got %q", msgs[0].Content.Text) } } func TestChat_ToolCallLoop(t *testing.T) { var callCount int32 mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { n := atomic.AddInt32(&callCount, 1) if n == 1 { // First call: request a tool return provider.Response{ ToolCalls: []provider.ToolCall{ {ID: "tc1", Name: "greet", Arguments: "{}"}, }, }, nil } // Second call: return text return provider.Response{Text: "done"}, nil }) model := newMockModel(mp) chat := NewChat(model) tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) { return "hello!", nil }) chat.SetTools(NewToolBox(tool)) text, err := chat.Send(context.Background(), "test") if err != nil { t.Fatalf("unexpected error: %v", err) } if text != "done" { t.Errorf("expected 'done', got %q", text) } if atomic.LoadInt32(&callCount) != 2 { t.Errorf("expected 2 provider calls, got %d", callCount) } // Check message history: user, assistant (tool call), tool result, assistant (text) msgs := chat.Messages() if len(msgs) != 4 { t.Fatalf("expected 4 messages, got %d", len(msgs)) } if msgs[0].Role != RoleUser { t.Errorf("msg[0]: expected user, got %v", msgs[0].Role) } if msgs[1].Role != RoleAssistant { t.Errorf("msg[1]: expected assistant, got %v", msgs[1].Role) } if len(msgs[1].ToolCalls) != 1 { t.Errorf("msg[1]: expected 1 tool call, got %d", len(msgs[1].ToolCalls)) } if msgs[2].Role != RoleTool { t.Errorf("msg[2]: expected tool, got %v", msgs[2].Role) } if msgs[2].Content.Text != "hello!" { t.Errorf("msg[2]: expected 'hello!', got %q", msgs[2].Content.Text) } if msgs[3].Role != RoleAssistant { t.Errorf("msg[3]: expected assistant, got %v", msgs[3].Role) } } func TestChat_ToolCallLoop_NoTools(t *testing.T) { mp := newMockProvider(provider.Response{ ToolCalls: []provider.ToolCall{ {ID: "tc1", Name: "fake", Arguments: "{}"}, }, }) model := newMockModel(mp) chat := NewChat(model) _, err := chat.Send(context.Background(), "test") if !errors.Is(err, ErrNoToolsConfigured) { t.Errorf("expected ErrNoToolsConfigured, got %v", err) } } func TestChat_SendRaw(t *testing.T) { mp := newMockProvider(provider.Response{ Text: "raw response", ToolCalls: []provider.ToolCall{ {ID: "tc1", Name: "tool1", Arguments: `{"x":1}`}, }, }) model := newMockModel(mp) chat := NewChat(model) resp, err := chat.SendRaw(context.Background(), UserMessage("test")) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp.Text != "raw response" { t.Errorf("expected 'raw response', got %q", resp.Text) } if !resp.HasToolCalls() { t.Error("expected HasToolCalls() to be true") } if len(resp.ToolCalls) != 1 { t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls)) } if resp.ToolCalls[0].Name != "tool1" { t.Errorf("expected tool name 'tool1', got %q", resp.ToolCalls[0].Name) } } func TestChat_SendRaw_ManualToolResults(t *testing.T) { var callCount int32 mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { n := atomic.AddInt32(&callCount, 1) if n == 1 { return provider.Response{ ToolCalls: []provider.ToolCall{ {ID: "tc1", Name: "tool1", Arguments: "{}"}, }, }, nil } return provider.Response{Text: "final"}, nil }) model := newMockModel(mp) chat := NewChat(model) // First call returns tool calls resp, err := chat.SendRaw(context.Background(), UserMessage("test")) if err != nil { t.Fatalf("unexpected error: %v", err) } if !resp.HasToolCalls() { t.Fatal("expected tool calls") } // Manually add tool result chat.AddToolResults(ToolResultMessage("tc1", "tool result")) // Second call returns text resp, err = chat.SendRaw(context.Background(), UserMessage("continue")) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp.Text != "final" { t.Errorf("expected 'final', got %q", resp.Text) } // Check the full history msgs := chat.Messages() // user, assistant(tool call), tool result, user, assistant(text) if len(msgs) != 5 { t.Fatalf("expected 5 messages, got %d", len(msgs)) } if msgs[2].Role != RoleTool { t.Errorf("expected msg[2] role=tool, got %v", msgs[2].Role) } if msgs[2].ToolCallID != "tc1" { t.Errorf("expected msg[2] toolCallID=tc1, got %q", msgs[2].ToolCallID) } } func TestChat_Messages(t *testing.T) { mp := newMockProvider(provider.Response{Text: "ok"}) model := newMockModel(mp) chat := NewChat(model) _, _ = chat.Send(context.Background(), "test") msgs := chat.Messages() // Verify it's a copy — modifying returned slice shouldn't affect chat msgs[0] = Message{} original := chat.Messages() if original[0].Role != RoleUser { t.Error("Messages() did not return a copy") } } func TestChat_Reset(t *testing.T) { mp := newMockProvider(provider.Response{Text: "ok"}) model := newMockModel(mp) chat := NewChat(model) _, _ = chat.Send(context.Background(), "test") if len(chat.Messages()) == 0 { t.Fatal("expected messages before reset") } chat.Reset() if len(chat.Messages()) != 0 { t.Errorf("expected 0 messages after reset, got %d", len(chat.Messages())) } } func TestChat_Fork(t *testing.T) { mp := newMockProvider(provider.Response{Text: "ok"}) model := newMockModel(mp) chat := NewChat(model) _, _ = chat.Send(context.Background(), "msg1") fork := chat.Fork() // Fork should have same history if len(fork.Messages()) != len(chat.Messages()) { t.Fatalf("fork should have same message count: got %d vs %d", len(fork.Messages()), len(chat.Messages())) } // Adding to fork should not affect original _, _ = fork.Send(context.Background(), "msg2") if len(fork.Messages()) == len(chat.Messages()) { t.Error("fork messages should be independent of original") } // Adding to original should not affect fork originalLen := len(chat.Messages()) _, _ = chat.Send(context.Background(), "msg3") if len(chat.Messages()) == originalLen { t.Error("original should have more messages after send") } } func TestChat_SendWithImages(t *testing.T) { mp := newMockProvider(provider.Response{Text: "I see an image"}) model := newMockModel(mp) chat := NewChat(model) img := Image{URL: "https://example.com/image.png"} text, err := chat.SendWithImages(context.Background(), "What's in this image?", img) if err != nil { t.Fatalf("unexpected error: %v", err) } if text != "I see an image" { t.Errorf("expected 'I see an image', got %q", text) } // Verify the image was passed through to the provider req := mp.lastRequest() if len(req.Messages) == 0 { t.Fatal("expected messages in request") } lastUserMsg := req.Messages[0] if len(lastUserMsg.Images) != 1 { t.Fatalf("expected 1 image, got %d", len(lastUserMsg.Images)) } if lastUserMsg.Images[0].URL != "https://example.com/image.png" { t.Errorf("expected image URL, got %q", lastUserMsg.Images[0].URL) } } func TestChat_MultipleToolCallRounds(t *testing.T) { var callCount int32 mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { n := atomic.AddInt32(&callCount, 1) if n <= 3 { return provider.Response{ ToolCalls: []provider.ToolCall{ {ID: "tc" + string(rune('0'+n)), Name: "counter", Arguments: "{}"}, }, }, nil } return provider.Response{Text: "all done"}, nil }) model := newMockModel(mp) chat := NewChat(model) var execCount int32 tool := DefineSimple("counter", "Counts", func(ctx context.Context) (string, error) { atomic.AddInt32(&execCount, 1) return "counted", nil }) chat.SetTools(NewToolBox(tool)) text, err := chat.Send(context.Background(), "count three times") if err != nil { t.Fatalf("unexpected error: %v", err) } if text != "all done" { t.Errorf("expected 'all done', got %q", text) } if atomic.LoadInt32(&callCount) != 4 { t.Errorf("expected 4 provider calls, got %d", callCount) } if atomic.LoadInt32(&execCount) != 3 { t.Errorf("expected 3 tool executions, got %d", execCount) } } func TestChat_SendError(t *testing.T) { wantErr := errors.New("provider failed") mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { return provider.Response{}, wantErr }) model := newMockModel(mp) chat := NewChat(model) _, err := chat.Send(context.Background(), "test") if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, wantErr) { t.Errorf("expected wrapped provider error, got %v", err) } } func TestChat_WithRequestOptions(t *testing.T) { mp := newMockProvider(provider.Response{Text: "ok"}) model := newMockModel(mp) chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200)) _, err := chat.Send(context.Background(), "test") if err != nil { t.Fatalf("unexpected error: %v", err) } req := mp.lastRequest() if req.Temperature == nil || *req.Temperature != 0.5 { t.Errorf("expected temperature 0.5, got %v", req.Temperature) } if req.MaxTokens == nil || *req.MaxTokens != 200 { t.Errorf("expected maxTokens 200, got %v", req.MaxTokens) } }