Add comprehensive test suite for v2 module with mock provider
All checks were successful
CI / Lint (push) Successful in 9m36s
CI / V2 Module (push) Successful in 11m33s
CI / Root Module (push) Successful in 11m35s

Cover all core library logic (Client, Model, Chat, middleware, streaming,
message conversion, request building) using a configurable mock provider
that avoids real API calls. ~50 tests across 7 files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-07 22:00:49 -05:00
parent cbe340ced0
commit 6a7eeef619
7 changed files with 1678 additions and 0 deletions

407
v2/chat_test.go Normal file
View File

@@ -0,0 +1,407 @@
package llm
import (
"context"
"errors"
"sync/atomic"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestChat_Send(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "Hello there!"})
model := newMockModel(mp)
chat := NewChat(model)
text, err := chat.Send(context.Background(), "Hi")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "Hello there!" {
t.Errorf("expected 'Hello there!', got %q", text)
}
}
func TestChat_SendMessage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "reply"})
model := newMockModel(mp)
chat := NewChat(model)
_, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
msgs := chat.Messages()
if len(msgs) != 2 {
t.Fatalf("expected 2 messages (user + assistant), got %d", len(msgs))
}
if msgs[0].Role != RoleUser {
t.Errorf("expected first message role=user, got %v", msgs[0].Role)
}
if msgs[0].Content.Text != "msg1" {
t.Errorf("expected first message text='msg1', got %q", msgs[0].Content.Text)
}
if msgs[1].Role != RoleAssistant {
t.Errorf("expected second message role=assistant, got %v", msgs[1].Role)
}
if msgs[1].Content.Text != "reply" {
t.Errorf("expected second message text='reply', got %q", msgs[1].Content.Text)
}
}
func TestChat_SetSystem(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model)
chat.SetSystem("You are a bot")
msgs := chat.Messages()
if len(msgs) != 1 {
t.Fatalf("expected 1 message, got %d", len(msgs))
}
if msgs[0].Role != RoleSystem {
t.Errorf("expected role=system, got %v", msgs[0].Role)
}
if msgs[0].Content.Text != "You are a bot" {
t.Errorf("expected system text, got %q", msgs[0].Content.Text)
}
// Replace system message
chat.SetSystem("You are a helpful bot")
msgs = chat.Messages()
if len(msgs) != 1 {
t.Fatalf("expected 1 message after replace, got %d", len(msgs))
}
if msgs[0].Content.Text != "You are a helpful bot" {
t.Errorf("expected replaced system text, got %q", msgs[0].Content.Text)
}
// System message stays first even after adding other messages
_, _ = chat.Send(context.Background(), "Hi")
chat.SetSystem("New system")
msgs = chat.Messages()
if msgs[0].Role != RoleSystem {
t.Errorf("expected system as first message, got %v", msgs[0].Role)
}
if msgs[0].Content.Text != "New system" {
t.Errorf("expected 'New system', got %q", msgs[0].Content.Text)
}
}
func TestChat_ToolCallLoop(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n == 1 {
// First call: request a tool
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "greet", Arguments: "{}"},
},
}, nil
}
// Second call: return text
return provider.Response{Text: "done"}, nil
})
model := newMockModel(mp)
chat := NewChat(model)
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
return "hello!", nil
})
chat.SetTools(NewToolBox(tool))
text, err := chat.Send(context.Background(), "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "done" {
t.Errorf("expected 'done', got %q", text)
}
if atomic.LoadInt32(&callCount) != 2 {
t.Errorf("expected 2 provider calls, got %d", callCount)
}
// Check message history: user, assistant (tool call), tool result, assistant (text)
msgs := chat.Messages()
if len(msgs) != 4 {
t.Fatalf("expected 4 messages, got %d", len(msgs))
}
if msgs[0].Role != RoleUser {
t.Errorf("msg[0]: expected user, got %v", msgs[0].Role)
}
if msgs[1].Role != RoleAssistant {
t.Errorf("msg[1]: expected assistant, got %v", msgs[1].Role)
}
if len(msgs[1].ToolCalls) != 1 {
t.Errorf("msg[1]: expected 1 tool call, got %d", len(msgs[1].ToolCalls))
}
if msgs[2].Role != RoleTool {
t.Errorf("msg[2]: expected tool, got %v", msgs[2].Role)
}
if msgs[2].Content.Text != "hello!" {
t.Errorf("msg[2]: expected 'hello!', got %q", msgs[2].Content.Text)
}
if msgs[3].Role != RoleAssistant {
t.Errorf("msg[3]: expected assistant, got %v", msgs[3].Role)
}
}
func TestChat_ToolCallLoop_NoTools(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "fake", Arguments: "{}"},
},
})
model := newMockModel(mp)
chat := NewChat(model)
_, err := chat.Send(context.Background(), "test")
if !errors.Is(err, ErrNoToolsConfigured) {
t.Errorf("expected ErrNoToolsConfigured, got %v", err)
}
}
func TestChat_SendRaw(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "raw response",
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "tool1", Arguments: `{"x":1}`},
},
})
model := newMockModel(mp)
chat := NewChat(model)
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "raw response" {
t.Errorf("expected 'raw response', got %q", resp.Text)
}
if !resp.HasToolCalls() {
t.Error("expected HasToolCalls() to be true")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "tool1" {
t.Errorf("expected tool name 'tool1', got %q", resp.ToolCalls[0].Name)
}
}
func TestChat_SendRaw_ManualToolResults(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n == 1 {
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "tool1", Arguments: "{}"},
},
}, nil
}
return provider.Response{Text: "final"}, nil
})
model := newMockModel(mp)
chat := NewChat(model)
// First call returns tool calls
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !resp.HasToolCalls() {
t.Fatal("expected tool calls")
}
// Manually add tool result
chat.AddToolResults(ToolResultMessage("tc1", "tool result"))
// Second call returns text
resp, err = chat.SendRaw(context.Background(), UserMessage("continue"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "final" {
t.Errorf("expected 'final', got %q", resp.Text)
}
// Check the full history
msgs := chat.Messages()
// user, assistant(tool call), tool result, user, assistant(text)
if len(msgs) != 5 {
t.Fatalf("expected 5 messages, got %d", len(msgs))
}
if msgs[2].Role != RoleTool {
t.Errorf("expected msg[2] role=tool, got %v", msgs[2].Role)
}
if msgs[2].ToolCallID != "tc1" {
t.Errorf("expected msg[2] toolCallID=tc1, got %q", msgs[2].ToolCallID)
}
}
func TestChat_Messages(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model)
_, _ = chat.Send(context.Background(), "test")
msgs := chat.Messages()
// Verify it's a copy — modifying returned slice shouldn't affect chat
msgs[0] = Message{}
original := chat.Messages()
if original[0].Role != RoleUser {
t.Error("Messages() did not return a copy")
}
}
func TestChat_Reset(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model)
_, _ = chat.Send(context.Background(), "test")
if len(chat.Messages()) == 0 {
t.Fatal("expected messages before reset")
}
chat.Reset()
if len(chat.Messages()) != 0 {
t.Errorf("expected 0 messages after reset, got %d", len(chat.Messages()))
}
}
func TestChat_Fork(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model)
_, _ = chat.Send(context.Background(), "msg1")
fork := chat.Fork()
// Fork should have same history
if len(fork.Messages()) != len(chat.Messages()) {
t.Fatalf("fork should have same message count: got %d vs %d", len(fork.Messages()), len(chat.Messages()))
}
// Adding to fork should not affect original
_, _ = fork.Send(context.Background(), "msg2")
if len(fork.Messages()) == len(chat.Messages()) {
t.Error("fork messages should be independent of original")
}
// Adding to original should not affect fork
originalLen := len(chat.Messages())
_, _ = chat.Send(context.Background(), "msg3")
if len(chat.Messages()) == originalLen {
t.Error("original should have more messages after send")
}
}
func TestChat_SendWithImages(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "I see an image"})
model := newMockModel(mp)
chat := NewChat(model)
img := Image{URL: "https://example.com/image.png"}
text, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "I see an image" {
t.Errorf("expected 'I see an image', got %q", text)
}
// Verify the image was passed through to the provider
req := mp.lastRequest()
if len(req.Messages) == 0 {
t.Fatal("expected messages in request")
}
lastUserMsg := req.Messages[0]
if len(lastUserMsg.Images) != 1 {
t.Fatalf("expected 1 image, got %d", len(lastUserMsg.Images))
}
if lastUserMsg.Images[0].URL != "https://example.com/image.png" {
t.Errorf("expected image URL, got %q", lastUserMsg.Images[0].URL)
}
}
func TestChat_MultipleToolCallRounds(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n <= 3 {
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc" + string(rune('0'+n)), Name: "counter", Arguments: "{}"},
},
}, nil
}
return provider.Response{Text: "all done"}, nil
})
model := newMockModel(mp)
chat := NewChat(model)
var execCount int32
tool := DefineSimple("counter", "Counts", func(ctx context.Context) (string, error) {
atomic.AddInt32(&execCount, 1)
return "counted", nil
})
chat.SetTools(NewToolBox(tool))
text, err := chat.Send(context.Background(), "count three times")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "all done" {
t.Errorf("expected 'all done', got %q", text)
}
if atomic.LoadInt32(&callCount) != 4 {
t.Errorf("expected 4 provider calls, got %d", callCount)
}
if atomic.LoadInt32(&execCount) != 3 {
t.Errorf("expected 3 tool executions, got %d", execCount)
}
}
func TestChat_SendError(t *testing.T) {
wantErr := errors.New("provider failed")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, wantErr
})
model := newMockModel(mp)
chat := NewChat(model)
_, err := chat.Send(context.Background(), "test")
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, wantErr) {
t.Errorf("expected wrapped provider error, got %v", err)
}
}
func TestChat_WithRequestOptions(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))
_, err := chat.Send(context.Background(), "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
req := mp.lastRequest()
if req.Temperature == nil || *req.Temperature != 0.5 {
t.Errorf("expected temperature 0.5, got %v", req.Temperature)
}
if req.MaxTokens == nil || *req.MaxTokens != 200 {
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
}
}

212
v2/message_test.go Normal file
View File

@@ -0,0 +1,212 @@
package llm
import (
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestUserMessage(t *testing.T) {
msg := UserMessage("hello")
if msg.Role != RoleUser {
t.Errorf("expected role=user, got %v", msg.Role)
}
if msg.Content.Text != "hello" {
t.Errorf("expected text='hello', got %q", msg.Content.Text)
}
if len(msg.Content.Images) != 0 {
t.Errorf("expected no images, got %d", len(msg.Content.Images))
}
}
func TestUserMessageWithImages(t *testing.T) {
img1 := Image{URL: "https://example.com/1.png"}
img2 := Image{Base64: "abc123", ContentType: "image/png"}
msg := UserMessageWithImages("describe", img1, img2)
if msg.Role != RoleUser {
t.Errorf("expected role=user, got %v", msg.Role)
}
if msg.Content.Text != "describe" {
t.Errorf("expected text='describe', got %q", msg.Content.Text)
}
if len(msg.Content.Images) != 2 {
t.Fatalf("expected 2 images, got %d", len(msg.Content.Images))
}
if msg.Content.Images[0].URL != "https://example.com/1.png" {
t.Errorf("expected image[0] URL, got %q", msg.Content.Images[0].URL)
}
if msg.Content.Images[1].Base64 != "abc123" {
t.Errorf("expected image[1] base64='abc123', got %q", msg.Content.Images[1].Base64)
}
if msg.Content.Images[1].ContentType != "image/png" {
t.Errorf("expected image[1] contentType='image/png', got %q", msg.Content.Images[1].ContentType)
}
}
func TestSystemMessage(t *testing.T) {
msg := SystemMessage("Be helpful")
if msg.Role != RoleSystem {
t.Errorf("expected role=system, got %v", msg.Role)
}
if msg.Content.Text != "Be helpful" {
t.Errorf("expected text='Be helpful', got %q", msg.Content.Text)
}
}
func TestAssistantMessage(t *testing.T) {
msg := AssistantMessage("Sure thing")
if msg.Role != RoleAssistant {
t.Errorf("expected role=assistant, got %v", msg.Role)
}
if msg.Content.Text != "Sure thing" {
t.Errorf("expected text='Sure thing', got %q", msg.Content.Text)
}
}
func TestToolResultMessage(t *testing.T) {
msg := ToolResultMessage("tc-123", "result data")
if msg.Role != RoleTool {
t.Errorf("expected role=tool, got %v", msg.Role)
}
if msg.ToolCallID != "tc-123" {
t.Errorf("expected toolCallID='tc-123', got %q", msg.ToolCallID)
}
if msg.Content.Text != "result data" {
t.Errorf("expected text='result data', got %q", msg.Content.Text)
}
}
func TestConvertMessages(t *testing.T) {
msgs := []Message{
SystemMessage("system prompt"),
UserMessageWithImages("look at this", Image{URL: "https://example.com/img.png"}),
{
Role: RoleAssistant,
Content: Content{Text: "I'll use a tool"},
ToolCalls: []ToolCall{
{ID: "tc1", Name: "search", Arguments: `{"q":"test"}`},
},
},
ToolResultMessage("tc1", "found it"),
}
converted := convertMessages(msgs)
if len(converted) != 4 {
t.Fatalf("expected 4 converted messages, got %d", len(converted))
}
// System message
if converted[0].Role != "system" {
t.Errorf("msg[0]: expected role='system', got %q", converted[0].Role)
}
if converted[0].Content != "system prompt" {
t.Errorf("msg[0]: expected content='system prompt', got %q", converted[0].Content)
}
// User message with images
if converted[1].Role != "user" {
t.Errorf("msg[1]: expected role='user', got %q", converted[1].Role)
}
if len(converted[1].Images) != 1 {
t.Fatalf("msg[1]: expected 1 image, got %d", len(converted[1].Images))
}
if converted[1].Images[0].URL != "https://example.com/img.png" {
t.Errorf("msg[1]: expected image URL, got %q", converted[1].Images[0].URL)
}
// Assistant message with tool calls
if converted[2].Role != "assistant" {
t.Errorf("msg[2]: expected role='assistant', got %q", converted[2].Role)
}
if len(converted[2].ToolCalls) != 1 {
t.Fatalf("msg[2]: expected 1 tool call, got %d", len(converted[2].ToolCalls))
}
if converted[2].ToolCalls[0].ID != "tc1" {
t.Errorf("msg[2]: expected tool call ID='tc1', got %q", converted[2].ToolCalls[0].ID)
}
if converted[2].ToolCalls[0].Name != "search" {
t.Errorf("msg[2]: expected tool call name='search', got %q", converted[2].ToolCalls[0].Name)
}
if converted[2].ToolCalls[0].Arguments != `{"q":"test"}` {
t.Errorf("msg[2]: expected tool call arguments, got %q", converted[2].ToolCalls[0].Arguments)
}
// Tool result message
if converted[3].Role != "tool" {
t.Errorf("msg[3]: expected role='tool', got %q", converted[3].Role)
}
if converted[3].ToolCallID != "tc1" {
t.Errorf("msg[3]: expected toolCallID='tc1', got %q", converted[3].ToolCallID)
}
if converted[3].Content != "found it" {
t.Errorf("msg[3]: expected content='found it', got %q", converted[3].Content)
}
}
func TestConvertProviderResponse(t *testing.T) {
t.Run("text only", func(t *testing.T) {
resp := convertProviderResponse(provider.Response{
Text: "hello",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
if resp.Text != "hello" {
t.Errorf("expected text='hello', got %q", resp.Text)
}
if resp.HasToolCalls() {
t.Error("expected no tool calls")
}
if resp.Usage == nil {
t.Fatal("expected usage")
}
if resp.Usage.InputTokens != 10 {
t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
}
msg := resp.Message()
if msg.Role != RoleAssistant {
t.Errorf("expected role=assistant, got %v", msg.Role)
}
if msg.Content.Text != "hello" {
t.Errorf("expected message text='hello', got %q", msg.Content.Text)
}
})
t.Run("with tool calls", func(t *testing.T) {
resp := convertProviderResponse(provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "search", Arguments: `{"q":"go"}`},
{ID: "tc2", Name: "calc", Arguments: `{"a":1}`},
},
})
if !resp.HasToolCalls() {
t.Fatal("expected tool calls")
}
if len(resp.ToolCalls) != 2 {
t.Fatalf("expected 2 tool calls, got %d", len(resp.ToolCalls))
}
if resp.ToolCalls[0].ID != "tc1" || resp.ToolCalls[0].Name != "search" {
t.Errorf("unexpected tool call[0]: %+v", resp.ToolCalls[0])
}
if resp.ToolCalls[1].ID != "tc2" || resp.ToolCalls[1].Name != "calc" {
t.Errorf("unexpected tool call[1]: %+v", resp.ToolCalls[1])
}
msg := resp.Message()
if len(msg.ToolCalls) != 2 {
t.Errorf("expected 2 tool calls in message, got %d", len(msg.ToolCalls))
}
})
t.Run("nil usage", func(t *testing.T) {
resp := convertProviderResponse(provider.Response{Text: "ok"})
if resp.Usage != nil {
t.Errorf("expected nil usage, got %+v", resp.Usage)
}
})
}

282
v2/middleware_test.go Normal file
View File

@@ -0,0 +1,282 @@
package llm
import (
"context"
"errors"
"log/slog"
"sync"
"sync/atomic"
"testing"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestWithRetry_Success(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp).WithMiddleware(
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "ok" {
t.Errorf("expected 'ok', got %q", resp.Text)
}
if len(mp.Requests) != 1 {
t.Errorf("expected 1 request (no retries needed), got %d", len(mp.Requests))
}
}
func TestWithRetry_EventualSuccess(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n <= 2 {
return provider.Response{}, errors.New("transient error")
}
return provider.Response{Text: "success"}, nil
})
model := newMockModel(mp).WithMiddleware(
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "success" {
t.Errorf("expected 'success', got %q", resp.Text)
}
if atomic.LoadInt32(&callCount) != 3 {
t.Errorf("expected 3 calls, got %d", callCount)
}
}
func TestWithRetry_AllFail(t *testing.T) {
providerErr := errors.New("persistent error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, providerErr
})
model := newMockModel(mp).WithMiddleware(
WithRetry(2, func(attempt int) time.Duration { return time.Millisecond }),
)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, providerErr) {
t.Errorf("expected wrapped persistent error, got %v", err)
}
if len(mp.Requests) != 3 {
t.Errorf("expected 3 requests (1 initial + 2 retries), got %d", len(mp.Requests))
}
}
func TestWithRetry_ContextCancelled(t *testing.T) {
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, errors.New("fail")
})
model := newMockModel(mp).WithMiddleware(
WithRetry(10, func(attempt int) time.Duration { return 5 * time.Second }),
)
ctx, cancel := context.WithCancel(context.Background())
// Cancel after a short delay
go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()
_, err := model.Complete(ctx, []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, context.Canceled) {
t.Errorf("expected context.Canceled, got %v", err)
}
}
func TestWithTimeout(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "fast"})
model := newMockModel(mp).WithMiddleware(WithTimeout(5 * time.Second))
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "fast" {
t.Errorf("expected 'fast', got %q", resp.Text)
}
}
func TestWithTimeout_Exceeded(t *testing.T) {
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
select {
case <-ctx.Done():
return provider.Response{}, ctx.Err()
case <-time.After(5 * time.Second):
return provider.Response{Text: "slow"}, nil
}
})
model := newMockModel(mp).WithMiddleware(WithTimeout(50 * time.Millisecond))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("expected DeadlineExceeded, got %v", err)
}
}
func TestWithUsageTracking(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "ok",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
// Make two requests
for i := 0; i < 2; i++ {
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error on call %d: %v", i, err)
}
}
input, output, requests := tracker.Summary()
if input != 20 {
t.Errorf("expected total input 20, got %d", input)
}
if output != 10 {
t.Errorf("expected total output 10, got %d", output)
}
if requests != 2 {
t.Errorf("expected 2 requests, got %d", requests)
}
}
func TestWithUsageTracking_NilUsage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "no usage"})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
input, output, requests := tracker.Summary()
if input != 0 || output != 0 {
t.Errorf("expected 0 tokens with nil usage, got input=%d output=%d", input, output)
}
// Add(nil) returns early without incrementing TotalRequests
if requests != 0 {
t.Errorf("expected 0 requests (nil usage skips Add), got %d", requests)
}
}
func TestUsageTracker_Concurrent(t *testing.T) {
tracker := &UsageTracker{}
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
tracker.Add(&Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
})
}()
}
wg.Wait()
input, output, requests := tracker.Summary()
if input != 1000 {
t.Errorf("expected total input 1000, got %d", input)
}
if output != 500 {
t.Errorf("expected total output 500, got %d", output)
}
if requests != 100 {
t.Errorf("expected 100 requests, got %d", requests)
}
}
func TestMiddleware_Chaining(t *testing.T) {
var order []string
mw1 := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
order = append(order, "mw1-before")
resp, err := next(ctx, model, messages, cfg)
order = append(order, "mw1-after")
return resp, err
}
}
mw2 := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
order = append(order, "mw2-before")
resp, err := next(ctx, model, messages, cfg)
order = append(order, "mw2-after")
return resp, err
}
}
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp).WithMiddleware(mw1, mw2)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"}
if len(order) != len(expected) {
t.Fatalf("expected %d middleware calls, got %d: %v", len(expected), len(order), order)
}
for i, v := range expected {
if order[i] != v {
t.Errorf("order[%d]: expected %q, got %q", i, v, order[i])
}
}
}
func TestWithLogging(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "logged"})
logger := slog.Default()
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "logged" {
t.Errorf("expected 'logged', got %q", resp.Text)
}
}
func TestWithLogging_Error(t *testing.T) {
providerErr := errors.New("log this error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, providerErr
})
logger := slog.Default()
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if !errors.Is(err, providerErr) {
t.Errorf("expected provider error, got %v", err)
}
}

87
v2/mock_provider_test.go Normal file
View File

@@ -0,0 +1,87 @@
package llm
import (
"context"
"sync"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// mockProvider is a configurable mock implementation of provider.Provider for testing.
type mockProvider struct {
CompleteFunc func(ctx context.Context, req provider.Request) (provider.Response, error)
StreamFunc func(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error
// mu guards Requests
mu sync.Mutex
Requests []provider.Request
}
func (m *mockProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
m.mu.Lock()
m.Requests = append(m.Requests, req)
m.mu.Unlock()
return m.CompleteFunc(ctx, req)
}
func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
m.mu.Lock()
m.Requests = append(m.Requests, req)
m.mu.Unlock()
if m.StreamFunc != nil {
return m.StreamFunc(ctx, req, events)
}
close(events)
return nil
}
// lastRequest returns the most recent request recorded by the mock.
func (m *mockProvider) lastRequest() provider.Request {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.Requests) == 0 {
return provider.Request{}
}
return m.Requests[len(m.Requests)-1]
}
// newMockProvider creates a mock that always returns the given response.
func newMockProvider(resp provider.Response) *mockProvider {
return &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return resp, nil
},
}
}
// newMockProviderFunc creates a mock with a custom Complete function.
func newMockProviderFunc(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *mockProvider {
return &mockProvider{CompleteFunc: fn}
}
// newMockStreamProvider creates a mock that streams the given events.
func newMockStreamProvider(events []provider.StreamEvent) *mockProvider {
return &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, nil
},
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
for _, ev := range events {
select {
case ch <- ev:
case <-ctx.Done():
return ctx.Err()
}
}
return nil
},
}
}
// newMockModel creates a *Model backed by the given mock provider.
func newMockModel(p *mockProvider) *Model {
return &Model{
provider: p,
model: "mock-model",
}
}

215
v2/model_test.go Normal file
View File

@@ -0,0 +1,215 @@
package llm
import (
"context"
"errors"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestModel_Complete(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "Hello!",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "Hello!" {
t.Errorf("expected text 'Hello!', got %q", resp.Text)
}
if resp.Usage == nil {
t.Fatal("expected usage, got nil")
}
if resp.Usage.InputTokens != 10 {
t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens)
}
if resp.Usage.OutputTokens != 5 {
t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens)
}
if resp.Usage.TotalTokens != 15 {
t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens)
}
}
func TestModel_Complete_WithOptions(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp)
temp := 0.7
maxTok := 100
topP := 0.9
_, err := model.Complete(context.Background(), []Message{UserMessage("test")},
WithTemperature(temp),
WithMaxTokens(maxTok),
WithTopP(topP),
WithStop("STOP", "END"),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
req := mp.lastRequest()
if req.Temperature == nil || *req.Temperature != temp {
t.Errorf("expected temperature %v, got %v", temp, req.Temperature)
}
if req.MaxTokens == nil || *req.MaxTokens != maxTok {
t.Errorf("expected maxTokens %v, got %v", maxTok, req.MaxTokens)
}
if req.TopP == nil || *req.TopP != topP {
t.Errorf("expected topP %v, got %v", topP, req.TopP)
}
if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" {
t.Errorf("expected stop [STOP END], got %v", req.Stop)
}
}
func TestModel_Complete_Error(t *testing.T) {
wantErr := errors.New("provider error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, wantErr
})
model := newMockModel(mp)
_, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, wantErr) {
t.Errorf("expected error %v, got %v", wantErr, err)
}
}
func TestModel_Complete_WithTools(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "done"})
model := newMockModel(mp)
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
return "hello", nil
})
tb := NewToolBox(tool)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")}, WithTools(tb))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
req := mp.lastRequest()
if len(req.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
}
if req.Tools[0].Name != "greet" {
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
}
if req.Tools[0].Description != "Says hello" {
t.Errorf("expected tool description 'Says hello', got %q", req.Tools[0].Description)
}
}
func TestClient_Model(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "hi"})
client := newClient(mp)
model := client.Model("test-model")
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "hi" {
t.Errorf("expected 'hi', got %q", resp.Text)
}
req := mp.lastRequest()
if req.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", req.Model)
}
}
func TestClient_WithMiddleware(t *testing.T) {
var called bool
mw := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
called = true
return next(ctx, model, messages, cfg)
}
}
mp := newMockProvider(provider.Response{Text: "ok"})
client := newClient(mp).WithMiddleware(mw)
model := client.Model("test-model")
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !called {
t.Error("middleware was not called")
}
}
func TestModel_WithMiddleware(t *testing.T) {
var order []string
mw1 := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
order = append(order, "mw1")
return next(ctx, model, messages, cfg)
}
}
mw2 := func(next CompletionFunc) CompletionFunc {
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
order = append(order, "mw2")
return next(ctx, model, messages, cfg)
}
}
mp := newMockProvider(provider.Response{Text: "ok"})
model := newMockModel(mp).WithMiddleware(mw1).WithMiddleware(mw2)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(order) != 2 || order[0] != "mw1" || order[1] != "mw2" {
t.Errorf("expected middleware order [mw1 mw2], got %v", order)
}
}
func TestModel_Complete_NoUsage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "no usage"})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Usage != nil {
t.Errorf("expected nil usage, got %+v", resp.Usage)
}
}
func TestModel_Complete_ResponseMessage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "response text"})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
msg := resp.Message()
if msg.Role != RoleAssistant {
t.Errorf("expected role assistant, got %v", msg.Role)
}
if msg.Content.Text != "response text" {
t.Errorf("expected text 'response text', got %q", msg.Content.Text)
}
}

137
v2/request_test.go Normal file
View File

@@ -0,0 +1,137 @@
package llm
import (
"context"
"testing"
)
func TestWithTemperature(t *testing.T) {
cfg := &requestConfig{}
WithTemperature(0.7)(cfg)
if cfg.temperature == nil || *cfg.temperature != 0.7 {
t.Errorf("expected temperature 0.7, got %v", cfg.temperature)
}
}
func TestWithMaxTokens(t *testing.T) {
cfg := &requestConfig{}
WithMaxTokens(256)(cfg)
if cfg.maxTokens == nil || *cfg.maxTokens != 256 {
t.Errorf("expected maxTokens 256, got %v", cfg.maxTokens)
}
}
func TestWithTopP(t *testing.T) {
cfg := &requestConfig{}
WithTopP(0.95)(cfg)
if cfg.topP == nil || *cfg.topP != 0.95 {
t.Errorf("expected topP 0.95, got %v", cfg.topP)
}
}
func TestWithStop(t *testing.T) {
cfg := &requestConfig{}
WithStop("END", "STOP", "###")(cfg)
if len(cfg.stop) != 3 {
t.Fatalf("expected 3 stop sequences, got %d", len(cfg.stop))
}
if cfg.stop[0] != "END" || cfg.stop[1] != "STOP" || cfg.stop[2] != "###" {
t.Errorf("unexpected stop sequences: %v", cfg.stop)
}
}
func TestWithTools(t *testing.T) {
tool := DefineSimple("test", "A test tool", func(ctx context.Context) (string, error) {
return "ok", nil
})
tb := NewToolBox(tool)
cfg := &requestConfig{}
WithTools(tb)(cfg)
if cfg.tools == nil {
t.Fatal("expected tools to be set")
}
if len(cfg.tools.AllTools()) != 1 {
t.Errorf("expected 1 tool, got %d", len(cfg.tools.AllTools()))
}
}
func TestBuildProviderRequest(t *testing.T) {
tool := DefineSimple("greet", "Greets", func(ctx context.Context) (string, error) {
return "hi", nil
})
tb := NewToolBox(tool)
temp := 0.8
maxTok := 512
topP := 0.9
cfg := &requestConfig{
tools: tb,
temperature: &temp,
maxTokens: &maxTok,
topP: &topP,
stop: []string{"END"},
}
msgs := []Message{
SystemMessage("be nice"),
UserMessage("hello"),
}
req := buildProviderRequest("test-model", msgs, cfg)
if req.Model != "test-model" {
t.Errorf("expected model 'test-model', got %q", req.Model)
}
if len(req.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(req.Messages))
}
if req.Messages[0].Role != "system" {
t.Errorf("expected first message role='system', got %q", req.Messages[0].Role)
}
if req.Messages[1].Role != "user" {
t.Errorf("expected second message role='user', got %q", req.Messages[1].Role)
}
if req.Temperature == nil || *req.Temperature != 0.8 {
t.Errorf("expected temperature 0.8, got %v", req.Temperature)
}
if req.MaxTokens == nil || *req.MaxTokens != 512 {
t.Errorf("expected maxTokens 512, got %v", req.MaxTokens)
}
if req.TopP == nil || *req.TopP != 0.9 {
t.Errorf("expected topP 0.9, got %v", req.TopP)
}
if len(req.Stop) != 1 || req.Stop[0] != "END" {
t.Errorf("expected stop=[END], got %v", req.Stop)
}
if len(req.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
}
if req.Tools[0].Name != "greet" {
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
}
}
func TestBuildProviderRequest_EmptyConfig(t *testing.T) {
cfg := &requestConfig{}
msgs := []Message{UserMessage("hi")}
req := buildProviderRequest("model", msgs, cfg)
if req.Temperature != nil {
t.Errorf("expected nil temperature, got %v", req.Temperature)
}
if req.MaxTokens != nil {
t.Errorf("expected nil maxTokens, got %v", req.MaxTokens)
}
if req.TopP != nil {
t.Errorf("expected nil topP, got %v", req.TopP)
}
if len(req.Stop) != 0 {
t.Errorf("expected no stop sequences, got %v", req.Stop)
}
if len(req.Tools) != 0 {
t.Errorf("expected no tools, got %d", len(req.Tools))
}
}

338
v2/stream_test.go Normal file
View File

@@ -0,0 +1,338 @@
package llm
import (
"context"
"errors"
"io"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestStreamReader_TextEvents(t *testing.T) {
events := []provider.StreamEvent{
{Type: provider.StreamEventText, Text: "Hello"},
{Type: provider.StreamEventText, Text: " world"},
{Type: provider.StreamEventDone, Response: &provider.Response{
Text: "Hello world",
Usage: &provider.Usage{
InputTokens: 5,
OutputTokens: 2,
TotalTokens: 7,
},
}},
}
mp := newMockStreamProvider(events)
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
// Read text events
ev, err := reader.Next()
if err != nil {
t.Fatalf("unexpected error on first event: %v", err)
}
if ev.Type != StreamEventText || ev.Text != "Hello" {
t.Errorf("expected text event 'Hello', got type=%d text=%q", ev.Type, ev.Text)
}
ev, err = reader.Next()
if err != nil {
t.Fatalf("unexpected error on second event: %v", err)
}
if ev.Type != StreamEventText || ev.Text != " world" {
t.Errorf("expected text event ' world', got type=%d text=%q", ev.Type, ev.Text)
}
// Read done event
ev, err = reader.Next()
if err != nil {
t.Fatalf("unexpected error on done event: %v", err)
}
if ev.Type != StreamEventDone {
t.Errorf("expected done event, got type=%d", ev.Type)
}
if ev.Response == nil {
t.Fatal("expected response in done event")
}
if ev.Response.Text != "Hello world" {
t.Errorf("expected final text 'Hello world', got %q", ev.Response.Text)
}
// Subsequent reads should return EOF
_, err = reader.Next()
if !errors.Is(err, io.EOF) {
t.Errorf("expected io.EOF after done, got %v", err)
}
}
func TestStreamReader_ToolCallEvents(t *testing.T) {
events := []provider.StreamEvent{
{
Type: provider.StreamEventToolStart,
ToolIndex: 0,
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search"},
},
{
Type: provider.StreamEventToolDelta,
ToolIndex: 0,
ToolCall: &provider.ToolCall{Arguments: `{"query":`},
},
{
Type: provider.StreamEventToolDelta,
ToolIndex: 0,
ToolCall: &provider.ToolCall{Arguments: `"test"}`},
},
{
Type: provider.StreamEventToolEnd,
ToolIndex: 0,
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
},
{
Type: provider.StreamEventDone,
Response: &provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
},
},
},
}
mp := newMockStreamProvider(events)
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
// Read tool start
ev, err := reader.Next()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ev.Type != StreamEventToolStart {
t.Errorf("expected tool start, got type=%d", ev.Type)
}
if ev.ToolCall == nil || ev.ToolCall.Name != "search" {
t.Errorf("expected tool call 'search', got %+v", ev.ToolCall)
}
// Read tool deltas
ev, _ = reader.Next()
if ev.Type != StreamEventToolDelta {
t.Errorf("expected tool delta, got type=%d", ev.Type)
}
ev, _ = reader.Next()
if ev.Type != StreamEventToolDelta {
t.Errorf("expected tool delta, got type=%d", ev.Type)
}
// Read tool end
ev, _ = reader.Next()
if ev.Type != StreamEventToolEnd {
t.Errorf("expected tool end, got type=%d", ev.Type)
}
if ev.ToolCall == nil || ev.ToolCall.Arguments != `{"query":"test"}` {
t.Errorf("expected complete arguments, got %+v", ev.ToolCall)
}
// Read done
ev, _ = reader.Next()
if ev.Type != StreamEventDone {
t.Errorf("expected done, got type=%d", ev.Type)
}
if ev.Response == nil || len(ev.Response.ToolCalls) != 1 {
t.Error("expected response with 1 tool call")
}
}
func TestStreamReader_Error(t *testing.T) {
streamErr := errors.New("stream failed")
mp := &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, nil
},
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "partial"}
ch <- provider.StreamEvent{Type: provider.StreamEventError, Error: streamErr}
return nil
},
}
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
// Read partial text
ev, err := reader.Next()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ev.Text != "partial" {
t.Errorf("expected 'partial', got %q", ev.Text)
}
// Read error
_, err = reader.Next()
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, streamErr) {
t.Errorf("expected stream error, got %v", err)
}
}
func TestStreamReader_Close(t *testing.T) {
// Create a stream that sends one event then blocks until context is cancelled
mp := &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, nil
},
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "start"}
<-ctx.Done()
return ctx.Err()
},
}
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Read the first event
ev, err := reader.Next()
if err != nil {
t.Fatalf("unexpected error on first event: %v", err)
}
if ev.Text != "start" {
t.Errorf("expected 'start', got %q", ev.Text)
}
// Close should cancel context
if err := reader.Close(); err != nil {
t.Fatalf("close error: %v", err)
}
// After close, Next should eventually terminate with either EOF or context error.
// The exact behavior depends on goroutine scheduling: the channel may close (EOF)
// or the error event from the cancelled context may arrive first.
_, err = reader.Next()
if err == nil {
t.Error("expected error after close, got nil")
}
}
func TestStreamReader_Collect(t *testing.T) {
events := []provider.StreamEvent{
{Type: provider.StreamEventText, Text: "Hello"},
{Type: provider.StreamEventText, Text: " world"},
{Type: provider.StreamEventDone, Response: &provider.Response{
Text: "Hello world",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 2,
TotalTokens: 12,
},
}},
}
mp := newMockStreamProvider(events)
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
resp, err := reader.Collect()
if err != nil {
t.Fatalf("collect error: %v", err)
}
if resp.Text != "Hello world" {
t.Errorf("expected 'Hello world', got %q", resp.Text)
}
if resp.Usage == nil {
t.Fatal("expected usage")
}
if resp.Usage.InputTokens != 10 {
t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
}
}
func TestStreamReader_Text(t *testing.T) {
events := []provider.StreamEvent{
{Type: provider.StreamEventText, Text: "result"},
{Type: provider.StreamEventDone, Response: &provider.Response{Text: "result"}},
}
mp := newMockStreamProvider(events)
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
text, err := reader.Text()
if err != nil {
t.Fatalf("text error: %v", err)
}
if text != "result" {
t.Errorf("expected 'result', got %q", text)
}
}
func TestStreamReader_EmptyStream(t *testing.T) {
// Stream that completes without a done event (no response)
mp := newMockStreamProvider([]provider.StreamEvent{
{Type: provider.StreamEventText, Text: "hi"},
})
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer reader.Close()
_, err = reader.Collect()
if err == nil {
t.Fatal("expected error for stream without done event")
}
}
func TestStreamReader_StreamFuncError(t *testing.T) {
// Stream function returns error directly
mp := &mockProvider{
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, nil
},
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
return errors.New("stream init failed")
},
}
model := newMockModel(mp)
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error creating reader: %v", err)
}
defer reader.Close()
// The error should come through as an error event
_, err = reader.Collect()
if err == nil {
t.Fatal("expected error from stream function")
}
}