From 6a7eeef6196a205def7f133f6b0dccc55e207a2c Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Sat, 7 Feb 2026 22:00:49 -0500 Subject: [PATCH] Add comprehensive test suite for v2 module with mock provider Cover all core library logic (Client, Model, Chat, middleware, streaming, message conversion, request building) using a configurable mock provider that avoids real API calls. ~50 tests across 7 files. Co-Authored-By: Claude Opus 4.6 --- v2/chat_test.go | 407 +++++++++++++++++++++++++++++++++++++++ v2/message_test.go | 212 ++++++++++++++++++++ v2/middleware_test.go | 282 +++++++++++++++++++++++++++ v2/mock_provider_test.go | 87 +++++++++ v2/model_test.go | 215 +++++++++++++++++++++ v2/request_test.go | 137 +++++++++++++ v2/stream_test.go | 338 ++++++++++++++++++++++++++++++++ 7 files changed, 1678 insertions(+) create mode 100644 v2/chat_test.go create mode 100644 v2/message_test.go create mode 100644 v2/middleware_test.go create mode 100644 v2/mock_provider_test.go create mode 100644 v2/model_test.go create mode 100644 v2/request_test.go create mode 100644 v2/stream_test.go diff --git a/v2/chat_test.go b/v2/chat_test.go new file mode 100644 index 0000000..be1864f --- /dev/null +++ b/v2/chat_test.go @@ -0,0 +1,407 @@ +package llm + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestChat_Send(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "Hello there!"}) + model := newMockModel(mp) + chat := NewChat(model) + + text, err := chat.Send(context.Background(), "Hi") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if text != "Hello there!" { + t.Errorf("expected 'Hello there!', got %q", text) + } +} + +func TestChat_SendMessage(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "reply"}) + model := newMockModel(mp) + chat := NewChat(model) + + _, err := chat.SendMessage(context.Background(), UserMessage("msg1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msgs := chat.Messages() + if len(msgs) != 2 { + t.Fatalf("expected 2 messages (user + assistant), got %d", len(msgs)) + } + if msgs[0].Role != RoleUser { + t.Errorf("expected first message role=user, got %v", msgs[0].Role) + } + if msgs[0].Content.Text != "msg1" { + t.Errorf("expected first message text='msg1', got %q", msgs[0].Content.Text) + } + if msgs[1].Role != RoleAssistant { + t.Errorf("expected second message role=assistant, got %v", msgs[1].Role) + } + if msgs[1].Content.Text != "reply" { + t.Errorf("expected second message text='reply', got %q", msgs[1].Content.Text) + } +} + +func TestChat_SetSystem(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "ok"}) + model := newMockModel(mp) + chat := NewChat(model) + + chat.SetSystem("You are a bot") + msgs := chat.Messages() + if len(msgs) != 1 { + t.Fatalf("expected 1 message, got %d", len(msgs)) + } + if msgs[0].Role != RoleSystem { + t.Errorf("expected role=system, got %v", msgs[0].Role) + } + if msgs[0].Content.Text != "You are a bot" { + t.Errorf("expected system text, got %q", msgs[0].Content.Text) + } + + // Replace system message + chat.SetSystem("You are a helpful bot") + msgs = chat.Messages() + if len(msgs) != 1 { + t.Fatalf("expected 1 message after replace, got %d", len(msgs)) + } + if msgs[0].Content.Text != "You are a helpful bot" { + t.Errorf("expected replaced system text, got %q", msgs[0].Content.Text) + } + + // System message stays first even after adding other messages + _, _ = chat.Send(context.Background(), "Hi") + chat.SetSystem("New system") + msgs = chat.Messages() + if msgs[0].Role != RoleSystem { + t.Errorf("expected system as first message, got %v", msgs[0].Role) + } + if msgs[0].Content.Text != "New system" { + t.Errorf("expected 'New system', got %q", msgs[0].Content.Text) + } +} + +func TestChat_ToolCallLoop(t *testing.T) { + var callCount int32 + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + n := atomic.AddInt32(&callCount, 1) + if n == 1 { + // First call: request a tool + return provider.Response{ + ToolCalls: []provider.ToolCall{ + {ID: "tc1", Name: "greet", Arguments: "{}"}, + }, + }, nil + } + // Second call: return text + return provider.Response{Text: "done"}, nil + }) + model := newMockModel(mp) + chat := NewChat(model) + + tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) { + return "hello!", nil + }) + chat.SetTools(NewToolBox(tool)) + + text, err := chat.Send(context.Background(), "test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if text != "done" { + t.Errorf("expected 'done', got %q", text) + } + if atomic.LoadInt32(&callCount) != 2 { + t.Errorf("expected 2 provider calls, got %d", callCount) + } + + // Check message history: user, assistant (tool call), tool result, assistant (text) + msgs := chat.Messages() + if len(msgs) != 4 { + t.Fatalf("expected 4 messages, got %d", len(msgs)) + } + if msgs[0].Role != RoleUser { + t.Errorf("msg[0]: expected user, got %v", msgs[0].Role) + } + if msgs[1].Role != RoleAssistant { + t.Errorf("msg[1]: expected assistant, got %v", msgs[1].Role) + } + if len(msgs[1].ToolCalls) != 1 { + t.Errorf("msg[1]: expected 1 tool call, got %d", len(msgs[1].ToolCalls)) + } + if msgs[2].Role != RoleTool { + t.Errorf("msg[2]: expected tool, got %v", msgs[2].Role) + } + if msgs[2].Content.Text != "hello!" { + t.Errorf("msg[2]: expected 'hello!', got %q", msgs[2].Content.Text) + } + if msgs[3].Role != RoleAssistant { + t.Errorf("msg[3]: expected assistant, got %v", msgs[3].Role) + } +} + +func TestChat_ToolCallLoop_NoTools(t *testing.T) { + mp := newMockProvider(provider.Response{ + ToolCalls: []provider.ToolCall{ + {ID: "tc1", Name: "fake", Arguments: "{}"}, + }, + }) + model := newMockModel(mp) + chat := NewChat(model) + + _, err := chat.Send(context.Background(), "test") + if !errors.Is(err, ErrNoToolsConfigured) { + t.Errorf("expected ErrNoToolsConfigured, got %v", err) + } +} + +func TestChat_SendRaw(t *testing.T) { + mp := newMockProvider(provider.Response{ + Text: "raw response", + ToolCalls: []provider.ToolCall{ + {ID: "tc1", Name: "tool1", Arguments: `{"x":1}`}, + }, + }) + model := newMockModel(mp) + chat := NewChat(model) + + resp, err := chat.SendRaw(context.Background(), UserMessage("test")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "raw response" { + t.Errorf("expected 'raw response', got %q", resp.Text) + } + if !resp.HasToolCalls() { + t.Error("expected HasToolCalls() to be true") + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].Name != "tool1" { + t.Errorf("expected tool name 'tool1', got %q", resp.ToolCalls[0].Name) + } +} + +func TestChat_SendRaw_ManualToolResults(t *testing.T) { + var callCount int32 + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + n := atomic.AddInt32(&callCount, 1) + if n == 1 { + return provider.Response{ + ToolCalls: []provider.ToolCall{ + {ID: "tc1", Name: "tool1", Arguments: "{}"}, + }, + }, nil + } + return provider.Response{Text: "final"}, nil + }) + model := newMockModel(mp) + chat := NewChat(model) + + // First call returns tool calls + resp, err := chat.SendRaw(context.Background(), UserMessage("test")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !resp.HasToolCalls() { + t.Fatal("expected tool calls") + } + + // Manually add tool result + chat.AddToolResults(ToolResultMessage("tc1", "tool result")) + + // Second call returns text + resp, err = chat.SendRaw(context.Background(), UserMessage("continue")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "final" { + t.Errorf("expected 'final', got %q", resp.Text) + } + + // Check the full history + msgs := chat.Messages() + // user, assistant(tool call), tool result, user, assistant(text) + if len(msgs) != 5 { + t.Fatalf("expected 5 messages, got %d", len(msgs)) + } + if msgs[2].Role != RoleTool { + t.Errorf("expected msg[2] role=tool, got %v", msgs[2].Role) + } + if msgs[2].ToolCallID != "tc1" { + t.Errorf("expected msg[2] toolCallID=tc1, got %q", msgs[2].ToolCallID) + } +} + +func TestChat_Messages(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "ok"}) + model := newMockModel(mp) + chat := NewChat(model) + + _, _ = chat.Send(context.Background(), "test") + + msgs := chat.Messages() + // Verify it's a copy — modifying returned slice shouldn't affect chat + msgs[0] = Message{} + + original := chat.Messages() + if original[0].Role != RoleUser { + t.Error("Messages() did not return a copy") + } +} + +func TestChat_Reset(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "ok"}) + model := newMockModel(mp) + chat := NewChat(model) + + _, _ = chat.Send(context.Background(), "test") + if len(chat.Messages()) == 0 { + t.Fatal("expected messages before reset") + } + + chat.Reset() + if len(chat.Messages()) != 0 { + t.Errorf("expected 0 messages after reset, got %d", len(chat.Messages())) + } +} + +func TestChat_Fork(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "ok"}) + model := newMockModel(mp) + chat := NewChat(model) + + _, _ = chat.Send(context.Background(), "msg1") + + fork := chat.Fork() + + // Fork should have same history + if len(fork.Messages()) != len(chat.Messages()) { + t.Fatalf("fork should have same message count: got %d vs %d", len(fork.Messages()), len(chat.Messages())) + } + + // Adding to fork should not affect original + _, _ = fork.Send(context.Background(), "msg2") + if len(fork.Messages()) == len(chat.Messages()) { + t.Error("fork messages should be independent of original") + } + + // Adding to original should not affect fork + originalLen := len(chat.Messages()) + _, _ = chat.Send(context.Background(), "msg3") + if len(chat.Messages()) == originalLen { + t.Error("original should have more messages after send") + } +} + +func TestChat_SendWithImages(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "I see an image"}) + model := newMockModel(mp) + chat := NewChat(model) + + img := Image{URL: "https://example.com/image.png"} + text, err := chat.SendWithImages(context.Background(), "What's in this image?", img) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if text != "I see an image" { + t.Errorf("expected 'I see an image', got %q", text) + } + + // Verify the image was passed through to the provider + req := mp.lastRequest() + if len(req.Messages) == 0 { + t.Fatal("expected messages in request") + } + lastUserMsg := req.Messages[0] + if len(lastUserMsg.Images) != 1 { + t.Fatalf("expected 1 image, got %d", len(lastUserMsg.Images)) + } + if lastUserMsg.Images[0].URL != "https://example.com/image.png" { + t.Errorf("expected image URL, got %q", lastUserMsg.Images[0].URL) + } +} + +func TestChat_MultipleToolCallRounds(t *testing.T) { + var callCount int32 + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + n := atomic.AddInt32(&callCount, 1) + if n <= 3 { + return provider.Response{ + ToolCalls: []provider.ToolCall{ + {ID: "tc" + string(rune('0'+n)), Name: "counter", Arguments: "{}"}, + }, + }, nil + } + return provider.Response{Text: "all done"}, nil + }) + model := newMockModel(mp) + chat := NewChat(model) + + var execCount int32 + tool := DefineSimple("counter", "Counts", func(ctx context.Context) (string, error) { + atomic.AddInt32(&execCount, 1) + return "counted", nil + }) + chat.SetTools(NewToolBox(tool)) + + text, err := chat.Send(context.Background(), "count three times") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if text != "all done" { + t.Errorf("expected 'all done', got %q", text) + } + if atomic.LoadInt32(&callCount) != 4 { + t.Errorf("expected 4 provider calls, got %d", callCount) + } + if atomic.LoadInt32(&execCount) != 3 { + t.Errorf("expected 3 tool executions, got %d", execCount) + } +} + +func TestChat_SendError(t *testing.T) { + wantErr := errors.New("provider failed") + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, wantErr + }) + model := newMockModel(mp) + chat := NewChat(model) + + _, err := chat.Send(context.Background(), "test") + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, wantErr) { + t.Errorf("expected wrapped provider error, got %v", err) + } +} + +func TestChat_WithRequestOptions(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "ok"}) + model := newMockModel(mp) + chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200)) + + _, err := chat.Send(context.Background(), "test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + req := mp.lastRequest() + if req.Temperature == nil || *req.Temperature != 0.5 { + t.Errorf("expected temperature 0.5, got %v", req.Temperature) + } + if req.MaxTokens == nil || *req.MaxTokens != 200 { + t.Errorf("expected maxTokens 200, got %v", req.MaxTokens) + } +} diff --git a/v2/message_test.go b/v2/message_test.go new file mode 100644 index 0000000..54737c7 --- /dev/null +++ b/v2/message_test.go @@ -0,0 +1,212 @@ +package llm + +import ( + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestUserMessage(t *testing.T) { + msg := UserMessage("hello") + if msg.Role != RoleUser { + t.Errorf("expected role=user, got %v", msg.Role) + } + if msg.Content.Text != "hello" { + t.Errorf("expected text='hello', got %q", msg.Content.Text) + } + if len(msg.Content.Images) != 0 { + t.Errorf("expected no images, got %d", len(msg.Content.Images)) + } +} + +func TestUserMessageWithImages(t *testing.T) { + img1 := Image{URL: "https://example.com/1.png"} + img2 := Image{Base64: "abc123", ContentType: "image/png"} + + msg := UserMessageWithImages("describe", img1, img2) + if msg.Role != RoleUser { + t.Errorf("expected role=user, got %v", msg.Role) + } + if msg.Content.Text != "describe" { + t.Errorf("expected text='describe', got %q", msg.Content.Text) + } + if len(msg.Content.Images) != 2 { + t.Fatalf("expected 2 images, got %d", len(msg.Content.Images)) + } + if msg.Content.Images[0].URL != "https://example.com/1.png" { + t.Errorf("expected image[0] URL, got %q", msg.Content.Images[0].URL) + } + if msg.Content.Images[1].Base64 != "abc123" { + t.Errorf("expected image[1] base64='abc123', got %q", msg.Content.Images[1].Base64) + } + if msg.Content.Images[1].ContentType != "image/png" { + t.Errorf("expected image[1] contentType='image/png', got %q", msg.Content.Images[1].ContentType) + } +} + +func TestSystemMessage(t *testing.T) { + msg := SystemMessage("Be helpful") + if msg.Role != RoleSystem { + t.Errorf("expected role=system, got %v", msg.Role) + } + if msg.Content.Text != "Be helpful" { + t.Errorf("expected text='Be helpful', got %q", msg.Content.Text) + } +} + +func TestAssistantMessage(t *testing.T) { + msg := AssistantMessage("Sure thing") + if msg.Role != RoleAssistant { + t.Errorf("expected role=assistant, got %v", msg.Role) + } + if msg.Content.Text != "Sure thing" { + t.Errorf("expected text='Sure thing', got %q", msg.Content.Text) + } +} + +func TestToolResultMessage(t *testing.T) { + msg := ToolResultMessage("tc-123", "result data") + if msg.Role != RoleTool { + t.Errorf("expected role=tool, got %v", msg.Role) + } + if msg.ToolCallID != "tc-123" { + t.Errorf("expected toolCallID='tc-123', got %q", msg.ToolCallID) + } + if msg.Content.Text != "result data" { + t.Errorf("expected text='result data', got %q", msg.Content.Text) + } +} + +func TestConvertMessages(t *testing.T) { + msgs := []Message{ + SystemMessage("system prompt"), + UserMessageWithImages("look at this", Image{URL: "https://example.com/img.png"}), + { + Role: RoleAssistant, + Content: Content{Text: "I'll use a tool"}, + ToolCalls: []ToolCall{ + {ID: "tc1", Name: "search", Arguments: `{"q":"test"}`}, + }, + }, + ToolResultMessage("tc1", "found it"), + } + + converted := convertMessages(msgs) + + if len(converted) != 4 { + t.Fatalf("expected 4 converted messages, got %d", len(converted)) + } + + // System message + if converted[0].Role != "system" { + t.Errorf("msg[0]: expected role='system', got %q", converted[0].Role) + } + if converted[0].Content != "system prompt" { + t.Errorf("msg[0]: expected content='system prompt', got %q", converted[0].Content) + } + + // User message with images + if converted[1].Role != "user" { + t.Errorf("msg[1]: expected role='user', got %q", converted[1].Role) + } + if len(converted[1].Images) != 1 { + t.Fatalf("msg[1]: expected 1 image, got %d", len(converted[1].Images)) + } + if converted[1].Images[0].URL != "https://example.com/img.png" { + t.Errorf("msg[1]: expected image URL, got %q", converted[1].Images[0].URL) + } + + // Assistant message with tool calls + if converted[2].Role != "assistant" { + t.Errorf("msg[2]: expected role='assistant', got %q", converted[2].Role) + } + if len(converted[2].ToolCalls) != 1 { + t.Fatalf("msg[2]: expected 1 tool call, got %d", len(converted[2].ToolCalls)) + } + if converted[2].ToolCalls[0].ID != "tc1" { + t.Errorf("msg[2]: expected tool call ID='tc1', got %q", converted[2].ToolCalls[0].ID) + } + if converted[2].ToolCalls[0].Name != "search" { + t.Errorf("msg[2]: expected tool call name='search', got %q", converted[2].ToolCalls[0].Name) + } + if converted[2].ToolCalls[0].Arguments != `{"q":"test"}` { + t.Errorf("msg[2]: expected tool call arguments, got %q", converted[2].ToolCalls[0].Arguments) + } + + // Tool result message + if converted[3].Role != "tool" { + t.Errorf("msg[3]: expected role='tool', got %q", converted[3].Role) + } + if converted[3].ToolCallID != "tc1" { + t.Errorf("msg[3]: expected toolCallID='tc1', got %q", converted[3].ToolCallID) + } + if converted[3].Content != "found it" { + t.Errorf("msg[3]: expected content='found it', got %q", converted[3].Content) + } +} + +func TestConvertProviderResponse(t *testing.T) { + t.Run("text only", func(t *testing.T) { + resp := convertProviderResponse(provider.Response{ + Text: "hello", + Usage: &provider.Usage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + }) + if resp.Text != "hello" { + t.Errorf("expected text='hello', got %q", resp.Text) + } + if resp.HasToolCalls() { + t.Error("expected no tool calls") + } + if resp.Usage == nil { + t.Fatal("expected usage") + } + if resp.Usage.InputTokens != 10 { + t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens) + } + + msg := resp.Message() + if msg.Role != RoleAssistant { + t.Errorf("expected role=assistant, got %v", msg.Role) + } + if msg.Content.Text != "hello" { + t.Errorf("expected message text='hello', got %q", msg.Content.Text) + } + }) + + t.Run("with tool calls", func(t *testing.T) { + resp := convertProviderResponse(provider.Response{ + ToolCalls: []provider.ToolCall{ + {ID: "tc1", Name: "search", Arguments: `{"q":"go"}`}, + {ID: "tc2", Name: "calc", Arguments: `{"a":1}`}, + }, + }) + if !resp.HasToolCalls() { + t.Fatal("expected tool calls") + } + if len(resp.ToolCalls) != 2 { + t.Fatalf("expected 2 tool calls, got %d", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].ID != "tc1" || resp.ToolCalls[0].Name != "search" { + t.Errorf("unexpected tool call[0]: %+v", resp.ToolCalls[0]) + } + if resp.ToolCalls[1].ID != "tc2" || resp.ToolCalls[1].Name != "calc" { + t.Errorf("unexpected tool call[1]: %+v", resp.ToolCalls[1]) + } + + msg := resp.Message() + if len(msg.ToolCalls) != 2 { + t.Errorf("expected 2 tool calls in message, got %d", len(msg.ToolCalls)) + } + }) + + t.Run("nil usage", func(t *testing.T) { + resp := convertProviderResponse(provider.Response{Text: "ok"}) + if resp.Usage != nil { + t.Errorf("expected nil usage, got %+v", resp.Usage) + } + }) +} diff --git a/v2/middleware_test.go b/v2/middleware_test.go new file mode 100644 index 0000000..652606c --- /dev/null +++ b/v2/middleware_test.go @@ -0,0 +1,282 @@ +package llm + +import ( + "context" + "errors" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestWithRetry_Success(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "ok"}) + model := newMockModel(mp).WithMiddleware( + WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }), + ) + + resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "ok" { + t.Errorf("expected 'ok', got %q", resp.Text) + } + if len(mp.Requests) != 1 { + t.Errorf("expected 1 request (no retries needed), got %d", len(mp.Requests)) + } +} + +func TestWithRetry_EventualSuccess(t *testing.T) { + var callCount int32 + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + n := atomic.AddInt32(&callCount, 1) + if n <= 2 { + return provider.Response{}, errors.New("transient error") + } + return provider.Response{Text: "success"}, nil + }) + model := newMockModel(mp).WithMiddleware( + WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }), + ) + + resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "success" { + t.Errorf("expected 'success', got %q", resp.Text) + } + if atomic.LoadInt32(&callCount) != 3 { + t.Errorf("expected 3 calls, got %d", callCount) + } +} + +func TestWithRetry_AllFail(t *testing.T) { + providerErr := errors.New("persistent error") + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, providerErr + }) + model := newMockModel(mp).WithMiddleware( + WithRetry(2, func(attempt int) time.Duration { return time.Millisecond }), + ) + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, providerErr) { + t.Errorf("expected wrapped persistent error, got %v", err) + } + if len(mp.Requests) != 3 { + t.Errorf("expected 3 requests (1 initial + 2 retries), got %d", len(mp.Requests)) + } +} + +func TestWithRetry_ContextCancelled(t *testing.T) { + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, errors.New("fail") + }) + model := newMockModel(mp).WithMiddleware( + WithRetry(10, func(attempt int) time.Duration { return 5 * time.Second }), + ) + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel after a short delay + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + _, err := model.Complete(ctx, []Message{UserMessage("test")}) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, context.Canceled) { + t.Errorf("expected context.Canceled, got %v", err) + } +} + +func TestWithTimeout(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "fast"}) + model := newMockModel(mp).WithMiddleware(WithTimeout(5 * time.Second)) + + resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "fast" { + t.Errorf("expected 'fast', got %q", resp.Text) + } +} + +func TestWithTimeout_Exceeded(t *testing.T) { + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + select { + case <-ctx.Done(): + return provider.Response{}, ctx.Err() + case <-time.After(5 * time.Second): + return provider.Response{Text: "slow"}, nil + } + }) + model := newMockModel(mp).WithMiddleware(WithTimeout(50 * time.Millisecond)) + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected DeadlineExceeded, got %v", err) + } +} + +func TestWithUsageTracking(t *testing.T) { + mp := newMockProvider(provider.Response{ + Text: "ok", + Usage: &provider.Usage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + }) + tracker := &UsageTracker{} + model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker)) + + // Make two requests + for i := 0; i < 2; i++ { + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error on call %d: %v", i, err) + } + } + + input, output, requests := tracker.Summary() + if input != 20 { + t.Errorf("expected total input 20, got %d", input) + } + if output != 10 { + t.Errorf("expected total output 10, got %d", output) + } + if requests != 2 { + t.Errorf("expected 2 requests, got %d", requests) + } +} + +func TestWithUsageTracking_NilUsage(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "no usage"}) + tracker := &UsageTracker{} + model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker)) + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + input, output, requests := tracker.Summary() + if input != 0 || output != 0 { + t.Errorf("expected 0 tokens with nil usage, got input=%d output=%d", input, output) + } + // Add(nil) returns early without incrementing TotalRequests + if requests != 0 { + t.Errorf("expected 0 requests (nil usage skips Add), got %d", requests) + } +} + +func TestUsageTracker_Concurrent(t *testing.T) { + tracker := &UsageTracker{} + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + tracker.Add(&Usage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }) + }() + } + + wg.Wait() + + input, output, requests := tracker.Summary() + if input != 1000 { + t.Errorf("expected total input 1000, got %d", input) + } + if output != 500 { + t.Errorf("expected total output 500, got %d", output) + } + if requests != 100 { + t.Errorf("expected 100 requests, got %d", requests) + } +} + +func TestMiddleware_Chaining(t *testing.T) { + var order []string + mw1 := func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + order = append(order, "mw1-before") + resp, err := next(ctx, model, messages, cfg) + order = append(order, "mw1-after") + return resp, err + } + } + mw2 := func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + order = append(order, "mw2-before") + resp, err := next(ctx, model, messages, cfg) + order = append(order, "mw2-after") + return resp, err + } + } + + mp := newMockProvider(provider.Response{Text: "ok"}) + model := newMockModel(mp).WithMiddleware(mw1, mw2) + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"} + if len(order) != len(expected) { + t.Fatalf("expected %d middleware calls, got %d: %v", len(expected), len(order), order) + } + for i, v := range expected { + if order[i] != v { + t.Errorf("order[%d]: expected %q, got %q", i, v, order[i]) + } + } +} + +func TestWithLogging(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "logged"}) + logger := slog.Default() + model := newMockModel(mp).WithMiddleware(WithLogging(logger)) + + resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "logged" { + t.Errorf("expected 'logged', got %q", resp.Text) + } +} + +func TestWithLogging_Error(t *testing.T) { + providerErr := errors.New("log this error") + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, providerErr + }) + logger := slog.Default() + model := newMockModel(mp).WithMiddleware(WithLogging(logger)) + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if !errors.Is(err, providerErr) { + t.Errorf("expected provider error, got %v", err) + } +} diff --git a/v2/mock_provider_test.go b/v2/mock_provider_test.go new file mode 100644 index 0000000..972f01b --- /dev/null +++ b/v2/mock_provider_test.go @@ -0,0 +1,87 @@ +package llm + +import ( + "context" + "sync" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +// mockProvider is a configurable mock implementation of provider.Provider for testing. +type mockProvider struct { + CompleteFunc func(ctx context.Context, req provider.Request) (provider.Response, error) + StreamFunc func(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error + + // mu guards Requests + mu sync.Mutex + Requests []provider.Request +} + +func (m *mockProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) { + m.mu.Lock() + m.Requests = append(m.Requests, req) + m.mu.Unlock() + return m.CompleteFunc(ctx, req) +} + +func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error { + m.mu.Lock() + m.Requests = append(m.Requests, req) + m.mu.Unlock() + if m.StreamFunc != nil { + return m.StreamFunc(ctx, req, events) + } + close(events) + return nil +} + +// lastRequest returns the most recent request recorded by the mock. +func (m *mockProvider) lastRequest() provider.Request { + m.mu.Lock() + defer m.mu.Unlock() + if len(m.Requests) == 0 { + return provider.Request{} + } + return m.Requests[len(m.Requests)-1] +} + +// newMockProvider creates a mock that always returns the given response. +func newMockProvider(resp provider.Response) *mockProvider { + return &mockProvider{ + CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) { + return resp, nil + }, + } +} + +// newMockProviderFunc creates a mock with a custom Complete function. +func newMockProviderFunc(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *mockProvider { + return &mockProvider{CompleteFunc: fn} +} + +// newMockStreamProvider creates a mock that streams the given events. +func newMockStreamProvider(events []provider.StreamEvent) *mockProvider { + return &mockProvider{ + CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, nil + }, + StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error { + for _, ev := range events { + select { + case ch <- ev: + case <-ctx.Done(): + return ctx.Err() + } + } + return nil + }, + } +} + +// newMockModel creates a *Model backed by the given mock provider. +func newMockModel(p *mockProvider) *Model { + return &Model{ + provider: p, + model: "mock-model", + } +} diff --git a/v2/model_test.go b/v2/model_test.go new file mode 100644 index 0000000..a078846 --- /dev/null +++ b/v2/model_test.go @@ -0,0 +1,215 @@ +package llm + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestModel_Complete(t *testing.T) { + mp := newMockProvider(provider.Response{ + Text: "Hello!", + Usage: &provider.Usage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + }) + model := newMockModel(mp) + + resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "Hello!" { + t.Errorf("expected text 'Hello!', got %q", resp.Text) + } + if resp.Usage == nil { + t.Fatal("expected usage, got nil") + } + if resp.Usage.InputTokens != 10 { + t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens) + } + if resp.Usage.OutputTokens != 5 { + t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens) + } + if resp.Usage.TotalTokens != 15 { + t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens) + } +} + +func TestModel_Complete_WithOptions(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "ok"}) + model := newMockModel(mp) + + temp := 0.7 + maxTok := 100 + topP := 0.9 + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}, + WithTemperature(temp), + WithMaxTokens(maxTok), + WithTopP(topP), + WithStop("STOP", "END"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + req := mp.lastRequest() + if req.Temperature == nil || *req.Temperature != temp { + t.Errorf("expected temperature %v, got %v", temp, req.Temperature) + } + if req.MaxTokens == nil || *req.MaxTokens != maxTok { + t.Errorf("expected maxTokens %v, got %v", maxTok, req.MaxTokens) + } + if req.TopP == nil || *req.TopP != topP { + t.Errorf("expected topP %v, got %v", topP, req.TopP) + } + if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" { + t.Errorf("expected stop [STOP END], got %v", req.Stop) + } +} + +func TestModel_Complete_Error(t *testing.T) { + wantErr := errors.New("provider error") + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, wantErr + }) + model := newMockModel(mp) + + _, err := model.Complete(context.Background(), []Message{UserMessage("Hi")}) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, wantErr) { + t.Errorf("expected error %v, got %v", wantErr, err) + } +} + +func TestModel_Complete_WithTools(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "done"}) + model := newMockModel(mp) + + tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) { + return "hello", nil + }) + tb := NewToolBox(tool) + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}, WithTools(tb)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + req := mp.lastRequest() + if len(req.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(req.Tools)) + } + if req.Tools[0].Name != "greet" { + t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name) + } + if req.Tools[0].Description != "Says hello" { + t.Errorf("expected tool description 'Says hello', got %q", req.Tools[0].Description) + } +} + +func TestClient_Model(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "hi"}) + client := newClient(mp) + model := client.Model("test-model") + + resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Text != "hi" { + t.Errorf("expected 'hi', got %q", resp.Text) + } + + req := mp.lastRequest() + if req.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", req.Model) + } +} + +func TestClient_WithMiddleware(t *testing.T) { + var called bool + mw := func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + called = true + return next(ctx, model, messages, cfg) + } + } + + mp := newMockProvider(provider.Response{Text: "ok"}) + client := newClient(mp).WithMiddleware(mw) + model := client.Model("test-model") + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("middleware was not called") + } +} + +func TestModel_WithMiddleware(t *testing.T) { + var order []string + mw1 := func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + order = append(order, "mw1") + return next(ctx, model, messages, cfg) + } + } + mw2 := func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + order = append(order, "mw2") + return next(ctx, model, messages, cfg) + } + } + + mp := newMockProvider(provider.Response{Text: "ok"}) + model := newMockModel(mp).WithMiddleware(mw1).WithMiddleware(mw2) + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(order) != 2 || order[0] != "mw1" || order[1] != "mw2" { + t.Errorf("expected middleware order [mw1 mw2], got %v", order) + } +} + +func TestModel_Complete_NoUsage(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "no usage"}) + model := newMockModel(mp) + + resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Usage != nil { + t.Errorf("expected nil usage, got %+v", resp.Usage) + } +} + +func TestModel_Complete_ResponseMessage(t *testing.T) { + mp := newMockProvider(provider.Response{Text: "response text"}) + model := newMockModel(mp) + + resp, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := resp.Message() + if msg.Role != RoleAssistant { + t.Errorf("expected role assistant, got %v", msg.Role) + } + if msg.Content.Text != "response text" { + t.Errorf("expected text 'response text', got %q", msg.Content.Text) + } +} diff --git a/v2/request_test.go b/v2/request_test.go new file mode 100644 index 0000000..3bba7d5 --- /dev/null +++ b/v2/request_test.go @@ -0,0 +1,137 @@ +package llm + +import ( + "context" + "testing" +) + +func TestWithTemperature(t *testing.T) { + cfg := &requestConfig{} + WithTemperature(0.7)(cfg) + if cfg.temperature == nil || *cfg.temperature != 0.7 { + t.Errorf("expected temperature 0.7, got %v", cfg.temperature) + } +} + +func TestWithMaxTokens(t *testing.T) { + cfg := &requestConfig{} + WithMaxTokens(256)(cfg) + if cfg.maxTokens == nil || *cfg.maxTokens != 256 { + t.Errorf("expected maxTokens 256, got %v", cfg.maxTokens) + } +} + +func TestWithTopP(t *testing.T) { + cfg := &requestConfig{} + WithTopP(0.95)(cfg) + if cfg.topP == nil || *cfg.topP != 0.95 { + t.Errorf("expected topP 0.95, got %v", cfg.topP) + } +} + +func TestWithStop(t *testing.T) { + cfg := &requestConfig{} + WithStop("END", "STOP", "###")(cfg) + if len(cfg.stop) != 3 { + t.Fatalf("expected 3 stop sequences, got %d", len(cfg.stop)) + } + if cfg.stop[0] != "END" || cfg.stop[1] != "STOP" || cfg.stop[2] != "###" { + t.Errorf("unexpected stop sequences: %v", cfg.stop) + } +} + +func TestWithTools(t *testing.T) { + tool := DefineSimple("test", "A test tool", func(ctx context.Context) (string, error) { + return "ok", nil + }) + tb := NewToolBox(tool) + + cfg := &requestConfig{} + WithTools(tb)(cfg) + if cfg.tools == nil { + t.Fatal("expected tools to be set") + } + if len(cfg.tools.AllTools()) != 1 { + t.Errorf("expected 1 tool, got %d", len(cfg.tools.AllTools())) + } +} + +func TestBuildProviderRequest(t *testing.T) { + tool := DefineSimple("greet", "Greets", func(ctx context.Context) (string, error) { + return "hi", nil + }) + tb := NewToolBox(tool) + + temp := 0.8 + maxTok := 512 + topP := 0.9 + + cfg := &requestConfig{ + tools: tb, + temperature: &temp, + maxTokens: &maxTok, + topP: &topP, + stop: []string{"END"}, + } + + msgs := []Message{ + SystemMessage("be nice"), + UserMessage("hello"), + } + + req := buildProviderRequest("test-model", msgs, cfg) + + if req.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", req.Model) + } + if len(req.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(req.Messages)) + } + if req.Messages[0].Role != "system" { + t.Errorf("expected first message role='system', got %q", req.Messages[0].Role) + } + if req.Messages[1].Role != "user" { + t.Errorf("expected second message role='user', got %q", req.Messages[1].Role) + } + if req.Temperature == nil || *req.Temperature != 0.8 { + t.Errorf("expected temperature 0.8, got %v", req.Temperature) + } + if req.MaxTokens == nil || *req.MaxTokens != 512 { + t.Errorf("expected maxTokens 512, got %v", req.MaxTokens) + } + if req.TopP == nil || *req.TopP != 0.9 { + t.Errorf("expected topP 0.9, got %v", req.TopP) + } + if len(req.Stop) != 1 || req.Stop[0] != "END" { + t.Errorf("expected stop=[END], got %v", req.Stop) + } + if len(req.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(req.Tools)) + } + if req.Tools[0].Name != "greet" { + t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name) + } +} + +func TestBuildProviderRequest_EmptyConfig(t *testing.T) { + cfg := &requestConfig{} + msgs := []Message{UserMessage("hi")} + + req := buildProviderRequest("model", msgs, cfg) + + if req.Temperature != nil { + t.Errorf("expected nil temperature, got %v", req.Temperature) + } + if req.MaxTokens != nil { + t.Errorf("expected nil maxTokens, got %v", req.MaxTokens) + } + if req.TopP != nil { + t.Errorf("expected nil topP, got %v", req.TopP) + } + if len(req.Stop) != 0 { + t.Errorf("expected no stop sequences, got %v", req.Stop) + } + if len(req.Tools) != 0 { + t.Errorf("expected no tools, got %d", len(req.Tools)) + } +} diff --git a/v2/stream_test.go b/v2/stream_test.go new file mode 100644 index 0000000..57d5ce0 --- /dev/null +++ b/v2/stream_test.go @@ -0,0 +1,338 @@ +package llm + +import ( + "context" + "errors" + "io" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +func TestStreamReader_TextEvents(t *testing.T) { + events := []provider.StreamEvent{ + {Type: provider.StreamEventText, Text: "Hello"}, + {Type: provider.StreamEventText, Text: " world"}, + {Type: provider.StreamEventDone, Response: &provider.Response{ + Text: "Hello world", + Usage: &provider.Usage{ + InputTokens: 5, + OutputTokens: 2, + TotalTokens: 7, + }, + }}, + } + mp := newMockStreamProvider(events) + model := newMockModel(mp) + + reader, err := model.Stream(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer reader.Close() + + // Read text events + ev, err := reader.Next() + if err != nil { + t.Fatalf("unexpected error on first event: %v", err) + } + if ev.Type != StreamEventText || ev.Text != "Hello" { + t.Errorf("expected text event 'Hello', got type=%d text=%q", ev.Type, ev.Text) + } + + ev, err = reader.Next() + if err != nil { + t.Fatalf("unexpected error on second event: %v", err) + } + if ev.Type != StreamEventText || ev.Text != " world" { + t.Errorf("expected text event ' world', got type=%d text=%q", ev.Type, ev.Text) + } + + // Read done event + ev, err = reader.Next() + if err != nil { + t.Fatalf("unexpected error on done event: %v", err) + } + if ev.Type != StreamEventDone { + t.Errorf("expected done event, got type=%d", ev.Type) + } + if ev.Response == nil { + t.Fatal("expected response in done event") + } + if ev.Response.Text != "Hello world" { + t.Errorf("expected final text 'Hello world', got %q", ev.Response.Text) + } + + // Subsequent reads should return EOF + _, err = reader.Next() + if !errors.Is(err, io.EOF) { + t.Errorf("expected io.EOF after done, got %v", err) + } +} + +func TestStreamReader_ToolCallEvents(t *testing.T) { + events := []provider.StreamEvent{ + { + Type: provider.StreamEventToolStart, + ToolIndex: 0, + ToolCall: &provider.ToolCall{ID: "tc1", Name: "search"}, + }, + { + Type: provider.StreamEventToolDelta, + ToolIndex: 0, + ToolCall: &provider.ToolCall{Arguments: `{"query":`}, + }, + { + Type: provider.StreamEventToolDelta, + ToolIndex: 0, + ToolCall: &provider.ToolCall{Arguments: `"test"}`}, + }, + { + Type: provider.StreamEventToolEnd, + ToolIndex: 0, + ToolCall: &provider.ToolCall{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`}, + }, + { + Type: provider.StreamEventDone, + Response: &provider.Response{ + ToolCalls: []provider.ToolCall{ + {ID: "tc1", Name: "search", Arguments: `{"query":"test"}`}, + }, + }, + }, + } + mp := newMockStreamProvider(events) + model := newMockModel(mp) + + reader, err := model.Stream(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer reader.Close() + + // Read tool start + ev, err := reader.Next() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ev.Type != StreamEventToolStart { + t.Errorf("expected tool start, got type=%d", ev.Type) + } + if ev.ToolCall == nil || ev.ToolCall.Name != "search" { + t.Errorf("expected tool call 'search', got %+v", ev.ToolCall) + } + + // Read tool deltas + ev, _ = reader.Next() + if ev.Type != StreamEventToolDelta { + t.Errorf("expected tool delta, got type=%d", ev.Type) + } + + ev, _ = reader.Next() + if ev.Type != StreamEventToolDelta { + t.Errorf("expected tool delta, got type=%d", ev.Type) + } + + // Read tool end + ev, _ = reader.Next() + if ev.Type != StreamEventToolEnd { + t.Errorf("expected tool end, got type=%d", ev.Type) + } + if ev.ToolCall == nil || ev.ToolCall.Arguments != `{"query":"test"}` { + t.Errorf("expected complete arguments, got %+v", ev.ToolCall) + } + + // Read done + ev, _ = reader.Next() + if ev.Type != StreamEventDone { + t.Errorf("expected done, got type=%d", ev.Type) + } + if ev.Response == nil || len(ev.Response.ToolCalls) != 1 { + t.Error("expected response with 1 tool call") + } +} + +func TestStreamReader_Error(t *testing.T) { + streamErr := errors.New("stream failed") + mp := &mockProvider{ + CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, nil + }, + StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error { + ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "partial"} + ch <- provider.StreamEvent{Type: provider.StreamEventError, Error: streamErr} + return nil + }, + } + model := newMockModel(mp) + + reader, err := model.Stream(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer reader.Close() + + // Read partial text + ev, err := reader.Next() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ev.Text != "partial" { + t.Errorf("expected 'partial', got %q", ev.Text) + } + + // Read error + _, err = reader.Next() + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, streamErr) { + t.Errorf("expected stream error, got %v", err) + } +} + +func TestStreamReader_Close(t *testing.T) { + // Create a stream that sends one event then blocks until context is cancelled + mp := &mockProvider{ + CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, nil + }, + StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error { + ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "start"} + <-ctx.Done() + return ctx.Err() + }, + } + model := newMockModel(mp) + + reader, err := model.Stream(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Read the first event + ev, err := reader.Next() + if err != nil { + t.Fatalf("unexpected error on first event: %v", err) + } + if ev.Text != "start" { + t.Errorf("expected 'start', got %q", ev.Text) + } + + // Close should cancel context + if err := reader.Close(); err != nil { + t.Fatalf("close error: %v", err) + } + + // After close, Next should eventually terminate with either EOF or context error. + // The exact behavior depends on goroutine scheduling: the channel may close (EOF) + // or the error event from the cancelled context may arrive first. + _, err = reader.Next() + if err == nil { + t.Error("expected error after close, got nil") + } +} + +func TestStreamReader_Collect(t *testing.T) { + events := []provider.StreamEvent{ + {Type: provider.StreamEventText, Text: "Hello"}, + {Type: provider.StreamEventText, Text: " world"}, + {Type: provider.StreamEventDone, Response: &provider.Response{ + Text: "Hello world", + Usage: &provider.Usage{ + InputTokens: 10, + OutputTokens: 2, + TotalTokens: 12, + }, + }}, + } + mp := newMockStreamProvider(events) + model := newMockModel(mp) + + reader, err := model.Stream(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer reader.Close() + + resp, err := reader.Collect() + if err != nil { + t.Fatalf("collect error: %v", err) + } + if resp.Text != "Hello world" { + t.Errorf("expected 'Hello world', got %q", resp.Text) + } + if resp.Usage == nil { + t.Fatal("expected usage") + } + if resp.Usage.InputTokens != 10 { + t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens) + } +} + +func TestStreamReader_Text(t *testing.T) { + events := []provider.StreamEvent{ + {Type: provider.StreamEventText, Text: "result"}, + {Type: provider.StreamEventDone, Response: &provider.Response{Text: "result"}}, + } + mp := newMockStreamProvider(events) + model := newMockModel(mp) + + reader, err := model.Stream(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer reader.Close() + + text, err := reader.Text() + if err != nil { + t.Fatalf("text error: %v", err) + } + if text != "result" { + t.Errorf("expected 'result', got %q", text) + } +} + +func TestStreamReader_EmptyStream(t *testing.T) { + // Stream that completes without a done event (no response) + mp := newMockStreamProvider([]provider.StreamEvent{ + {Type: provider.StreamEventText, Text: "hi"}, + }) + model := newMockModel(mp) + + reader, err := model.Stream(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer reader.Close() + + _, err = reader.Collect() + if err == nil { + t.Fatal("expected error for stream without done event") + } +} + +func TestStreamReader_StreamFuncError(t *testing.T) { + // Stream function returns error directly + mp := &mockProvider{ + CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{}, nil + }, + StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error { + return errors.New("stream init failed") + }, + } + model := newMockModel(mp) + + reader, err := model.Stream(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error creating reader: %v", err) + } + defer reader.Close() + + // The error should come through as an error event + _, err = reader.Collect() + if err == nil { + t.Fatal("expected error from stream function") + } +}