Cover all core library logic (Client, Model, Chat, middleware, streaming, message conversion, request building) using a configurable mock provider that avoids real API calls. ~50 tests across 7 files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
408 lines
11 KiB
Go
408 lines
11 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
|
)
|
|
|
|
func TestChat_Send(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "Hello there!"})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
text, err := chat.Send(context.Background(), "Hi")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if text != "Hello there!" {
|
|
t.Errorf("expected 'Hello there!', got %q", text)
|
|
}
|
|
}
|
|
|
|
func TestChat_SendMessage(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "reply"})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
_, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
msgs := chat.Messages()
|
|
if len(msgs) != 2 {
|
|
t.Fatalf("expected 2 messages (user + assistant), got %d", len(msgs))
|
|
}
|
|
if msgs[0].Role != RoleUser {
|
|
t.Errorf("expected first message role=user, got %v", msgs[0].Role)
|
|
}
|
|
if msgs[0].Content.Text != "msg1" {
|
|
t.Errorf("expected first message text='msg1', got %q", msgs[0].Content.Text)
|
|
}
|
|
if msgs[1].Role != RoleAssistant {
|
|
t.Errorf("expected second message role=assistant, got %v", msgs[1].Role)
|
|
}
|
|
if msgs[1].Content.Text != "reply" {
|
|
t.Errorf("expected second message text='reply', got %q", msgs[1].Content.Text)
|
|
}
|
|
}
|
|
|
|
func TestChat_SetSystem(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
chat.SetSystem("You are a bot")
|
|
msgs := chat.Messages()
|
|
if len(msgs) != 1 {
|
|
t.Fatalf("expected 1 message, got %d", len(msgs))
|
|
}
|
|
if msgs[0].Role != RoleSystem {
|
|
t.Errorf("expected role=system, got %v", msgs[0].Role)
|
|
}
|
|
if msgs[0].Content.Text != "You are a bot" {
|
|
t.Errorf("expected system text, got %q", msgs[0].Content.Text)
|
|
}
|
|
|
|
// Replace system message
|
|
chat.SetSystem("You are a helpful bot")
|
|
msgs = chat.Messages()
|
|
if len(msgs) != 1 {
|
|
t.Fatalf("expected 1 message after replace, got %d", len(msgs))
|
|
}
|
|
if msgs[0].Content.Text != "You are a helpful bot" {
|
|
t.Errorf("expected replaced system text, got %q", msgs[0].Content.Text)
|
|
}
|
|
|
|
// System message stays first even after adding other messages
|
|
_, _ = chat.Send(context.Background(), "Hi")
|
|
chat.SetSystem("New system")
|
|
msgs = chat.Messages()
|
|
if msgs[0].Role != RoleSystem {
|
|
t.Errorf("expected system as first message, got %v", msgs[0].Role)
|
|
}
|
|
if msgs[0].Content.Text != "New system" {
|
|
t.Errorf("expected 'New system', got %q", msgs[0].Content.Text)
|
|
}
|
|
}
|
|
|
|
func TestChat_ToolCallLoop(t *testing.T) {
|
|
var callCount int32
|
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
n := atomic.AddInt32(&callCount, 1)
|
|
if n == 1 {
|
|
// First call: request a tool
|
|
return provider.Response{
|
|
ToolCalls: []provider.ToolCall{
|
|
{ID: "tc1", Name: "greet", Arguments: "{}"},
|
|
},
|
|
}, nil
|
|
}
|
|
// Second call: return text
|
|
return provider.Response{Text: "done"}, nil
|
|
})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
|
return "hello!", nil
|
|
})
|
|
chat.SetTools(NewToolBox(tool))
|
|
|
|
text, err := chat.Send(context.Background(), "test")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if text != "done" {
|
|
t.Errorf("expected 'done', got %q", text)
|
|
}
|
|
if atomic.LoadInt32(&callCount) != 2 {
|
|
t.Errorf("expected 2 provider calls, got %d", callCount)
|
|
}
|
|
|
|
// Check message history: user, assistant (tool call), tool result, assistant (text)
|
|
msgs := chat.Messages()
|
|
if len(msgs) != 4 {
|
|
t.Fatalf("expected 4 messages, got %d", len(msgs))
|
|
}
|
|
if msgs[0].Role != RoleUser {
|
|
t.Errorf("msg[0]: expected user, got %v", msgs[0].Role)
|
|
}
|
|
if msgs[1].Role != RoleAssistant {
|
|
t.Errorf("msg[1]: expected assistant, got %v", msgs[1].Role)
|
|
}
|
|
if len(msgs[1].ToolCalls) != 1 {
|
|
t.Errorf("msg[1]: expected 1 tool call, got %d", len(msgs[1].ToolCalls))
|
|
}
|
|
if msgs[2].Role != RoleTool {
|
|
t.Errorf("msg[2]: expected tool, got %v", msgs[2].Role)
|
|
}
|
|
if msgs[2].Content.Text != "hello!" {
|
|
t.Errorf("msg[2]: expected 'hello!', got %q", msgs[2].Content.Text)
|
|
}
|
|
if msgs[3].Role != RoleAssistant {
|
|
t.Errorf("msg[3]: expected assistant, got %v", msgs[3].Role)
|
|
}
|
|
}
|
|
|
|
func TestChat_ToolCallLoop_NoTools(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{
|
|
ToolCalls: []provider.ToolCall{
|
|
{ID: "tc1", Name: "fake", Arguments: "{}"},
|
|
},
|
|
})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
_, err := chat.Send(context.Background(), "test")
|
|
if !errors.Is(err, ErrNoToolsConfigured) {
|
|
t.Errorf("expected ErrNoToolsConfigured, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestChat_SendRaw(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{
|
|
Text: "raw response",
|
|
ToolCalls: []provider.ToolCall{
|
|
{ID: "tc1", Name: "tool1", Arguments: `{"x":1}`},
|
|
},
|
|
})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.Text != "raw response" {
|
|
t.Errorf("expected 'raw response', got %q", resp.Text)
|
|
}
|
|
if !resp.HasToolCalls() {
|
|
t.Error("expected HasToolCalls() to be true")
|
|
}
|
|
if len(resp.ToolCalls) != 1 {
|
|
t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
|
|
}
|
|
if resp.ToolCalls[0].Name != "tool1" {
|
|
t.Errorf("expected tool name 'tool1', got %q", resp.ToolCalls[0].Name)
|
|
}
|
|
}
|
|
|
|
func TestChat_SendRaw_ManualToolResults(t *testing.T) {
|
|
var callCount int32
|
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
n := atomic.AddInt32(&callCount, 1)
|
|
if n == 1 {
|
|
return provider.Response{
|
|
ToolCalls: []provider.ToolCall{
|
|
{ID: "tc1", Name: "tool1", Arguments: "{}"},
|
|
},
|
|
}, nil
|
|
}
|
|
return provider.Response{Text: "final"}, nil
|
|
})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
// First call returns tool calls
|
|
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !resp.HasToolCalls() {
|
|
t.Fatal("expected tool calls")
|
|
}
|
|
|
|
// Manually add tool result
|
|
chat.AddToolResults(ToolResultMessage("tc1", "tool result"))
|
|
|
|
// Second call returns text
|
|
resp, err = chat.SendRaw(context.Background(), UserMessage("continue"))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.Text != "final" {
|
|
t.Errorf("expected 'final', got %q", resp.Text)
|
|
}
|
|
|
|
// Check the full history
|
|
msgs := chat.Messages()
|
|
// user, assistant(tool call), tool result, user, assistant(text)
|
|
if len(msgs) != 5 {
|
|
t.Fatalf("expected 5 messages, got %d", len(msgs))
|
|
}
|
|
if msgs[2].Role != RoleTool {
|
|
t.Errorf("expected msg[2] role=tool, got %v", msgs[2].Role)
|
|
}
|
|
if msgs[2].ToolCallID != "tc1" {
|
|
t.Errorf("expected msg[2] toolCallID=tc1, got %q", msgs[2].ToolCallID)
|
|
}
|
|
}
|
|
|
|
func TestChat_Messages(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
_, _ = chat.Send(context.Background(), "test")
|
|
|
|
msgs := chat.Messages()
|
|
// Verify it's a copy — modifying returned slice shouldn't affect chat
|
|
msgs[0] = Message{}
|
|
|
|
original := chat.Messages()
|
|
if original[0].Role != RoleUser {
|
|
t.Error("Messages() did not return a copy")
|
|
}
|
|
}
|
|
|
|
func TestChat_Reset(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
_, _ = chat.Send(context.Background(), "test")
|
|
if len(chat.Messages()) == 0 {
|
|
t.Fatal("expected messages before reset")
|
|
}
|
|
|
|
chat.Reset()
|
|
if len(chat.Messages()) != 0 {
|
|
t.Errorf("expected 0 messages after reset, got %d", len(chat.Messages()))
|
|
}
|
|
}
|
|
|
|
func TestChat_Fork(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
_, _ = chat.Send(context.Background(), "msg1")
|
|
|
|
fork := chat.Fork()
|
|
|
|
// Fork should have same history
|
|
if len(fork.Messages()) != len(chat.Messages()) {
|
|
t.Fatalf("fork should have same message count: got %d vs %d", len(fork.Messages()), len(chat.Messages()))
|
|
}
|
|
|
|
// Adding to fork should not affect original
|
|
_, _ = fork.Send(context.Background(), "msg2")
|
|
if len(fork.Messages()) == len(chat.Messages()) {
|
|
t.Error("fork messages should be independent of original")
|
|
}
|
|
|
|
// Adding to original should not affect fork
|
|
originalLen := len(chat.Messages())
|
|
_, _ = chat.Send(context.Background(), "msg3")
|
|
if len(chat.Messages()) == originalLen {
|
|
t.Error("original should have more messages after send")
|
|
}
|
|
}
|
|
|
|
func TestChat_SendWithImages(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "I see an image"})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
img := Image{URL: "https://example.com/image.png"}
|
|
text, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if text != "I see an image" {
|
|
t.Errorf("expected 'I see an image', got %q", text)
|
|
}
|
|
|
|
// Verify the image was passed through to the provider
|
|
req := mp.lastRequest()
|
|
if len(req.Messages) == 0 {
|
|
t.Fatal("expected messages in request")
|
|
}
|
|
lastUserMsg := req.Messages[0]
|
|
if len(lastUserMsg.Images) != 1 {
|
|
t.Fatalf("expected 1 image, got %d", len(lastUserMsg.Images))
|
|
}
|
|
if lastUserMsg.Images[0].URL != "https://example.com/image.png" {
|
|
t.Errorf("expected image URL, got %q", lastUserMsg.Images[0].URL)
|
|
}
|
|
}
|
|
|
|
func TestChat_MultipleToolCallRounds(t *testing.T) {
|
|
var callCount int32
|
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
n := atomic.AddInt32(&callCount, 1)
|
|
if n <= 3 {
|
|
return provider.Response{
|
|
ToolCalls: []provider.ToolCall{
|
|
{ID: "tc" + string(rune('0'+n)), Name: "counter", Arguments: "{}"},
|
|
},
|
|
}, nil
|
|
}
|
|
return provider.Response{Text: "all done"}, nil
|
|
})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
var execCount int32
|
|
tool := DefineSimple("counter", "Counts", func(ctx context.Context) (string, error) {
|
|
atomic.AddInt32(&execCount, 1)
|
|
return "counted", nil
|
|
})
|
|
chat.SetTools(NewToolBox(tool))
|
|
|
|
text, err := chat.Send(context.Background(), "count three times")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if text != "all done" {
|
|
t.Errorf("expected 'all done', got %q", text)
|
|
}
|
|
if atomic.LoadInt32(&callCount) != 4 {
|
|
t.Errorf("expected 4 provider calls, got %d", callCount)
|
|
}
|
|
if atomic.LoadInt32(&execCount) != 3 {
|
|
t.Errorf("expected 3 tool executions, got %d", execCount)
|
|
}
|
|
}
|
|
|
|
func TestChat_SendError(t *testing.T) {
|
|
wantErr := errors.New("provider failed")
|
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
return provider.Response{}, wantErr
|
|
})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model)
|
|
|
|
_, err := chat.Send(context.Background(), "test")
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if !errors.Is(err, wantErr) {
|
|
t.Errorf("expected wrapped provider error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestChat_WithRequestOptions(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
|
model := newMockModel(mp)
|
|
chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))
|
|
|
|
_, err := chat.Send(context.Background(), "test")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
req := mp.lastRequest()
|
|
if req.Temperature == nil || *req.Temperature != 0.5 {
|
|
t.Errorf("expected temperature 0.5, got %v", req.Temperature)
|
|
}
|
|
if req.MaxTokens == nil || *req.MaxTokens != 200 {
|
|
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
|
|
}
|
|
}
|