Files
go-llm/v2/chat_test.go
Steve Dudenhoeffer 6a7eeef619
All checks were successful
CI / Lint (push) Successful in 9m36s
CI / V2 Module (push) Successful in 11m33s
CI / Root Module (push) Successful in 11m35s
Add comprehensive test suite for v2 module with mock provider
Cover all core library logic (Client, Model, Chat, middleware, streaming,
message conversion, request building) using a configurable mock provider
that avoids real API calls. ~50 tests across 7 files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 22:00:49 -05:00

408 lines
11 KiB
Go

package llm
import (
"context"
"errors"
"sync/atomic"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestChat_Send(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "Hello there!"})
model := newMockModel(mp)
chat := NewChat(model)
text, err := chat.Send(context.Background(), "Hi")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "Hello there!" {
t.Errorf("expected 'Hello there!', got %q", text)
}
}
func TestChat_SendMessage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "reply"})
model := newMockModel(mp)
chat := NewChat(model)
_, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
msgs := chat.Messages()
if len(msgs) != 2 {
t.Fatalf("expected 2 messages (user + assistant), got %d", len(msgs))
}
if msgs[0].Role != RoleUser {
t.Errorf("expected first message role=user, got %v", msgs[0].Role)
}
if msgs[0].Content.Text != "msg1" {
t.Errorf("expected first message text='msg1', got %q", msgs[0].Content.Text)
}
if msgs[1].Role != RoleAssistant {
t.Errorf("expected second message role=assistant, got %v", msgs[1].Role)
}
if msgs[1].Content.Text != "reply" {
t.Errorf("expected second message text='reply', got %q", msgs[1].Content.Text)
}
}
// TestChat_SetSystem exercises SetSystem in three situations:
//
//  1. on an empty chat, it must produce a single system message;
//  2. called again, it must replace that message rather than add a second one;
//  3. called after a conversation turn, the system message must still be the
//     first entry in the history and carry the new text.
func TestChat_SetSystem(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "ok"})
	model := newMockModel(mp)
	chat := NewChat(model)

	chat.SetSystem("You are a bot")
	msgs := chat.Messages()
	if len(msgs) != 1 {
		t.Fatalf("expected 1 message, got %d", len(msgs))
	}
	if msgs[0].Role != RoleSystem {
		t.Errorf("expected role=system, got %v", msgs[0].Role)
	}
	if msgs[0].Content.Text != "You are a bot" {
		t.Errorf("expected system text, got %q", msgs[0].Content.Text)
	}

	// Replace system message
	chat.SetSystem("You are a helpful bot")
	msgs = chat.Messages()
	if len(msgs) != 1 {
		t.Fatalf("expected 1 message after replace, got %d", len(msgs))
	}
	if msgs[0].Content.Text != "You are a helpful bot" {
		t.Errorf("expected replaced system text, got %q", msgs[0].Content.Text)
	}

	// System message stays first even after adding other messages.
	// The remaining checks are only meaningful if the turn actually happened,
	// so a failing Send aborts the test instead of being ignored.
	if _, err := chat.Send(context.Background(), "Hi"); err != nil {
		t.Fatalf("unexpected error from Send: %v", err)
	}
	chat.SetSystem("New system")
	msgs = chat.Messages()
	// Guard the index below so a regression fails cleanly instead of panicking.
	if len(msgs) == 0 {
		t.Fatal("expected messages after Send and SetSystem, got none")
	}
	if msgs[0].Role != RoleSystem {
		t.Errorf("expected system as first message, got %v", msgs[0].Role)
	}
	if msgs[0].Content.Text != "New system" {
		t.Errorf("expected 'New system', got %q", msgs[0].Content.Text)
	}
}
// TestChat_ToolCallLoop checks that Send drives a tool-call round trip on its
// own. The mock provider asks for the "greet" tool on its first invocation and
// returns plain text on the second; the test then expects Send to return that
// text, the provider to have been invoked exactly twice, and the history to
// read user, assistant (tool call), tool result, assistant.
func TestChat_ToolCallLoop(t *testing.T) {
	var callCount int32
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		n := atomic.AddInt32(&callCount, 1)
		if n == 1 {
			// First call: request a tool
			return provider.Response{
				ToolCalls: []provider.ToolCall{
					{ID: "tc1", Name: "greet", Arguments: "{}"},
				},
			}, nil
		}
		// Second call: return text
		return provider.Response{Text: "done"}, nil
	})
	model := newMockModel(mp)
	chat := NewChat(model)

	tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
		return "hello!", nil
	})
	chat.SetTools(NewToolBox(tool))

	text, err := chat.Send(context.Background(), "test")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if text != "done" {
		t.Errorf("expected 'done', got %q", text)
	}

	// The counter is written with atomic operations, so read it the same way
	// and report the value that was actually compared.
	if got := atomic.LoadInt32(&callCount); got != 2 {
		t.Errorf("expected 2 provider calls, got %d", got)
	}

	// Check message history: user, assistant (tool call), tool result, assistant (text)
	msgs := chat.Messages()
	if len(msgs) != 4 {
		t.Fatalf("expected 4 messages, got %d", len(msgs))
	}
	if msgs[0].Role != RoleUser {
		t.Errorf("msg[0]: expected user, got %v", msgs[0].Role)
	}
	if msgs[1].Role != RoleAssistant {
		t.Errorf("msg[1]: expected assistant, got %v", msgs[1].Role)
	}
	if len(msgs[1].ToolCalls) != 1 {
		t.Errorf("msg[1]: expected 1 tool call, got %d", len(msgs[1].ToolCalls))
	}
	if msgs[2].Role != RoleTool {
		t.Errorf("msg[2]: expected tool, got %v", msgs[2].Role)
	}
	if msgs[2].Content.Text != "hello!" {
		t.Errorf("msg[2]: expected 'hello!', got %q", msgs[2].Content.Text)
	}
	if msgs[3].Role != RoleAssistant {
		t.Errorf("msg[3]: expected assistant, got %v", msgs[3].Role)
	}
}
func TestChat_ToolCallLoop_NoTools(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "fake", Arguments: "{}"},
},
})
model := newMockModel(mp)
chat := NewChat(model)
_, err := chat.Send(context.Background(), "test")
if !errors.Is(err, ErrNoToolsConfigured) {
t.Errorf("expected ErrNoToolsConfigured, got %v", err)
}
}
func TestChat_SendRaw(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "raw response",
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "tool1", Arguments: `{"x":1}`},
},
})
model := newMockModel(mp)
chat := NewChat(model)
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "raw response" {
t.Errorf("expected 'raw response', got %q", resp.Text)
}
if !resp.HasToolCalls() {
t.Error("expected HasToolCalls() to be true")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "tool1" {
t.Errorf("expected tool name 'tool1', got %q", resp.ToolCalls[0].Name)
}
}
func TestChat_SendRaw_ManualToolResults(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n == 1 {
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "tool1", Arguments: "{}"},
},
}, nil
}
return provider.Response{Text: "final"}, nil
})
model := newMockModel(mp)
chat := NewChat(model)
// First call returns tool calls
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !resp.HasToolCalls() {
t.Fatal("expected tool calls")
}
// Manually add tool result
chat.AddToolResults(ToolResultMessage("tc1", "tool result"))
// Second call returns text
resp, err = chat.SendRaw(context.Background(), UserMessage("continue"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "final" {
t.Errorf("expected 'final', got %q", resp.Text)
}
// Check the full history
msgs := chat.Messages()
// user, assistant(tool call), tool result, user, assistant(text)
if len(msgs) != 5 {
t.Fatalf("expected 5 messages, got %d", len(msgs))
}
if msgs[2].Role != RoleTool {
t.Errorf("expected msg[2] role=tool, got %v", msgs[2].Role)
}
if msgs[2].ToolCallID != "tc1" {
t.Errorf("expected msg[2] toolCallID=tc1, got %q", msgs[2].ToolCallID)
}
}
// TestChat_Messages checks that Messages returns a copy of the history:
// overwriting an element of the returned slice must not change what a later
// Messages call reports.
func TestChat_Messages(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "ok"})
	model := newMockModel(mp)
	chat := NewChat(model)

	// The copy check needs a populated history, so a failed Send is fatal.
	if _, err := chat.Send(context.Background(), "test"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	msgs := chat.Messages()
	// Guard the index below so an empty history fails cleanly instead of panicking.
	if len(msgs) == 0 {
		t.Fatal("expected at least one message after Send, got none")
	}

	// Verify it's a copy — modifying returned slice shouldn't affect chat
	msgs[0] = Message{}

	original := chat.Messages()
	if len(original) == 0 {
		t.Fatal("expected chat history to be intact after mutating the returned slice, got none")
	}
	if original[0].Role != RoleUser {
		t.Error("Messages() did not return a copy")
	}
}
// TestChat_Reset checks that Reset empties the history: after one successful
// Send the chat must hold messages, and after Reset it must hold none.
func TestChat_Reset(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "ok"})
	model := newMockModel(mp)
	chat := NewChat(model)

	// Report a failing Send as the cause rather than letting it surface as a
	// vaguer failure further down.
	if _, err := chat.Send(context.Background(), "test"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(chat.Messages()) == 0 {
		t.Fatal("expected messages before reset")
	}

	chat.Reset()

	if remaining := len(chat.Messages()); remaining != 0 {
		t.Errorf("expected 0 messages after reset, got %d", remaining)
	}
}
// TestChat_Fork checks that Fork produces a chat that starts with the same
// history as the original and then evolves independently of it, in both
// directions: sending on the fork must leave the original's history length
// untouched, and sending on the original must leave the fork's untouched.
func TestChat_Fork(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "ok"})
	model := newMockModel(mp)
	chat := NewChat(model)
	ctx := context.Background()

	if _, err := chat.Send(ctx, "msg1"); err != nil {
		t.Fatalf("unexpected error sending msg1: %v", err)
	}

	fork := chat.Fork()

	// Fork should have same history
	baseLen := len(chat.Messages())
	if got := len(fork.Messages()); got != baseLen {
		t.Fatalf("fork should have same message count: got %d vs %d", got, baseLen)
	}

	// Adding to fork should not affect original
	if _, err := fork.Send(ctx, "msg2"); err != nil {
		t.Fatalf("unexpected error sending msg2 on fork: %v", err)
	}
	forkLen := len(fork.Messages())
	if forkLen <= baseLen {
		t.Errorf("fork should have more messages after send: had %d, got %d", baseLen, forkLen)
	}
	if got := len(chat.Messages()); got != baseLen {
		t.Errorf("fork messages should be independent of original: original went from %d to %d messages", baseLen, got)
	}

	// Adding to original should not affect fork
	if _, err := chat.Send(ctx, "msg3"); err != nil {
		t.Fatalf("unexpected error sending msg3 on original: %v", err)
	}
	if got := len(chat.Messages()); got <= baseLen {
		t.Errorf("original should have more messages after send: had %d, got %d", baseLen, got)
	}
	// Compare the fork against its own earlier length; checking only that the
	// original grew would say nothing about the fork.
	if got := len(fork.Messages()); got != forkLen {
		t.Errorf("original messages should be independent of fork: fork went from %d to %d messages", forkLen, got)
	}
}
// TestChat_SendWithImages checks that SendWithImages returns the provider's
// text and that the image reaches the provider: the first message of the last
// request the mock recorded must carry exactly one image with the URL given.
func TestChat_SendWithImages(t *testing.T) {
	const imageURL = "https://example.com/image.png"

	mock := newMockProvider(provider.Response{Text: "I see an image"})
	c := NewChat(newMockModel(mock))

	reply, err := c.SendWithImages(context.Background(), "What's in this image?", Image{URL: imageURL})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if reply != "I see an image" {
		t.Errorf("expected 'I see an image', got %q", reply)
	}

	// Verify the image was passed through to the provider
	sent := mock.lastRequest().Messages
	if len(sent) == 0 {
		t.Fatal("expected messages in request")
	}
	images := sent[0].Images
	if n := len(images); n != 1 {
		t.Fatalf("expected 1 image, got %d", n)
	}
	if got := images[0].URL; got != imageURL {
		t.Errorf("expected image URL, got %q", got)
	}
}
// TestChat_MultipleToolCallRounds checks that Send keeps looping while the
// provider keeps asking for tools. The mock requests the "counter" tool on its
// first three invocations and returns text on the fourth; the test expects the
// final text, four provider invocations, and three tool executions.
func TestChat_MultipleToolCallRounds(t *testing.T) {
	var callCount int32
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		n := atomic.AddInt32(&callCount, 1)
		if n <= 3 {
			return provider.Response{
				ToolCalls: []provider.ToolCall{
					// n is 1..3 here, so '0'+n is a single ASCII digit: tc1, tc2, tc3.
					{ID: "tc" + string(rune('0'+n)), Name: "counter", Arguments: "{}"},
				},
			}, nil
		}
		return provider.Response{Text: "all done"}, nil
	})
	model := newMockModel(mp)
	chat := NewChat(model)

	var execCount int32
	tool := DefineSimple("counter", "Counts", func(ctx context.Context) (string, error) {
		atomic.AddInt32(&execCount, 1)
		return "counted", nil
	})
	chat.SetTools(NewToolBox(tool))

	text, err := chat.Send(context.Background(), "count three times")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if text != "all done" {
		t.Errorf("expected 'all done', got %q", text)
	}

	// Both counters are written with atomic operations, so read them the same
	// way and report the value that was actually compared.
	if got := atomic.LoadInt32(&callCount); got != 4 {
		t.Errorf("expected 4 provider calls, got %d", got)
	}
	if got := atomic.LoadInt32(&execCount); got != 3 {
		t.Errorf("expected 3 tool executions, got %d", got)
	}
}
// TestChat_SendError checks error propagation: when the provider function
// fails, Send must return a non-nil error that errors.Is matches against the
// provider's error.
func TestChat_SendError(t *testing.T) {
	sentinel := errors.New("provider failed")
	failing := newMockProviderFunc(func(context.Context, provider.Request) (provider.Response, error) {
		return provider.Response{}, sentinel
	})
	c := NewChat(newMockModel(failing))

	_, err := c.Send(context.Background(), "test")
	switch {
	case err == nil:
		t.Fatal("expected error, got nil")
	case !errors.Is(err, sentinel):
		t.Errorf("expected wrapped provider error, got %v", err)
	}
}
// TestChat_WithRequestOptions checks that request options given to NewChat
// reach the provider: after one Send, the last request recorded by the mock
// must carry temperature 0.5 and a max-token limit of 200.
func TestChat_WithRequestOptions(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "ok"})
	model := newMockModel(mp)
	chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))

	_, err := chat.Send(context.Background(), "test")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := mp.lastRequest()

	// Both fields are pointers. Report nil separately and dereference before
	// formatting, because %v on a non-nil pointer prints its address rather
	// than the value that was sent.
	switch {
	case req.Temperature == nil:
		t.Error("expected temperature 0.5, got nil")
	case *req.Temperature != 0.5:
		t.Errorf("expected temperature 0.5, got %v", *req.Temperature)
	}

	switch {
	case req.MaxTokens == nil:
		t.Error("expected maxTokens 200, got nil")
	case *req.MaxTokens != 200:
		t.Errorf("expected maxTokens 200, got %v", *req.MaxTokens)
	}
}