Add comprehensive test suite for v2 module with mock provider
Cover all core library logic (Client, Model, Chat, middleware, streaming, message conversion, request building) using a configurable mock provider that avoids real API calls. ~50 tests across 7 files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
407
v2/chat_test.go
Normal file
407
v2/chat_test.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
func TestChat_Send(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "Hello there!"})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
text, err := chat.Send(context.Background(), "Hi")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if text != "Hello there!" {
|
||||
t.Errorf("expected 'Hello there!', got %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_SendMessage(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "reply"})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
msgs := chat.Messages()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("expected 2 messages (user + assistant), got %d", len(msgs))
|
||||
}
|
||||
if msgs[0].Role != RoleUser {
|
||||
t.Errorf("expected first message role=user, got %v", msgs[0].Role)
|
||||
}
|
||||
if msgs[0].Content.Text != "msg1" {
|
||||
t.Errorf("expected first message text='msg1', got %q", msgs[0].Content.Text)
|
||||
}
|
||||
if msgs[1].Role != RoleAssistant {
|
||||
t.Errorf("expected second message role=assistant, got %v", msgs[1].Role)
|
||||
}
|
||||
if msgs[1].Content.Text != "reply" {
|
||||
t.Errorf("expected second message text='reply', got %q", msgs[1].Content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_SetSystem(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
chat.SetSystem("You are a bot")
|
||||
msgs := chat.Messages()
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(msgs))
|
||||
}
|
||||
if msgs[0].Role != RoleSystem {
|
||||
t.Errorf("expected role=system, got %v", msgs[0].Role)
|
||||
}
|
||||
if msgs[0].Content.Text != "You are a bot" {
|
||||
t.Errorf("expected system text, got %q", msgs[0].Content.Text)
|
||||
}
|
||||
|
||||
// Replace system message
|
||||
chat.SetSystem("You are a helpful bot")
|
||||
msgs = chat.Messages()
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected 1 message after replace, got %d", len(msgs))
|
||||
}
|
||||
if msgs[0].Content.Text != "You are a helpful bot" {
|
||||
t.Errorf("expected replaced system text, got %q", msgs[0].Content.Text)
|
||||
}
|
||||
|
||||
// System message stays first even after adding other messages
|
||||
_, _ = chat.Send(context.Background(), "Hi")
|
||||
chat.SetSystem("New system")
|
||||
msgs = chat.Messages()
|
||||
if msgs[0].Role != RoleSystem {
|
||||
t.Errorf("expected system as first message, got %v", msgs[0].Role)
|
||||
}
|
||||
if msgs[0].Content.Text != "New system" {
|
||||
t.Errorf("expected 'New system', got %q", msgs[0].Content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_ToolCallLoop(t *testing.T) {
|
||||
var callCount int32
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
n := atomic.AddInt32(&callCount, 1)
|
||||
if n == 1 {
|
||||
// First call: request a tool
|
||||
return provider.Response{
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{ID: "tc1", Name: "greet", Arguments: "{}"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
// Second call: return text
|
||||
return provider.Response{Text: "done"}, nil
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||
return "hello!", nil
|
||||
})
|
||||
chat.SetTools(NewToolBox(tool))
|
||||
|
||||
text, err := chat.Send(context.Background(), "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if text != "done" {
|
||||
t.Errorf("expected 'done', got %q", text)
|
||||
}
|
||||
if atomic.LoadInt32(&callCount) != 2 {
|
||||
t.Errorf("expected 2 provider calls, got %d", callCount)
|
||||
}
|
||||
|
||||
// Check message history: user, assistant (tool call), tool result, assistant (text)
|
||||
msgs := chat.Messages()
|
||||
if len(msgs) != 4 {
|
||||
t.Fatalf("expected 4 messages, got %d", len(msgs))
|
||||
}
|
||||
if msgs[0].Role != RoleUser {
|
||||
t.Errorf("msg[0]: expected user, got %v", msgs[0].Role)
|
||||
}
|
||||
if msgs[1].Role != RoleAssistant {
|
||||
t.Errorf("msg[1]: expected assistant, got %v", msgs[1].Role)
|
||||
}
|
||||
if len(msgs[1].ToolCalls) != 1 {
|
||||
t.Errorf("msg[1]: expected 1 tool call, got %d", len(msgs[1].ToolCalls))
|
||||
}
|
||||
if msgs[2].Role != RoleTool {
|
||||
t.Errorf("msg[2]: expected tool, got %v", msgs[2].Role)
|
||||
}
|
||||
if msgs[2].Content.Text != "hello!" {
|
||||
t.Errorf("msg[2]: expected 'hello!', got %q", msgs[2].Content.Text)
|
||||
}
|
||||
if msgs[3].Role != RoleAssistant {
|
||||
t.Errorf("msg[3]: expected assistant, got %v", msgs[3].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_ToolCallLoop_NoTools(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{ID: "tc1", Name: "fake", Arguments: "{}"},
|
||||
},
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, err := chat.Send(context.Background(), "test")
|
||||
if !errors.Is(err, ErrNoToolsConfigured) {
|
||||
t.Errorf("expected ErrNoToolsConfigured, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_SendRaw(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{
|
||||
Text: "raw response",
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{ID: "tc1", Name: "tool1", Arguments: `{"x":1}`},
|
||||
},
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "raw response" {
|
||||
t.Errorf("expected 'raw response', got %q", resp.Text)
|
||||
}
|
||||
if !resp.HasToolCalls() {
|
||||
t.Error("expected HasToolCalls() to be true")
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "tool1" {
|
||||
t.Errorf("expected tool name 'tool1', got %q", resp.ToolCalls[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_SendRaw_ManualToolResults(t *testing.T) {
|
||||
var callCount int32
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
n := atomic.AddInt32(&callCount, 1)
|
||||
if n == 1 {
|
||||
return provider.Response{
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{ID: "tc1", Name: "tool1", Arguments: "{}"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return provider.Response{Text: "final"}, nil
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
// First call returns tool calls
|
||||
resp, err := chat.SendRaw(context.Background(), UserMessage("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !resp.HasToolCalls() {
|
||||
t.Fatal("expected tool calls")
|
||||
}
|
||||
|
||||
// Manually add tool result
|
||||
chat.AddToolResults(ToolResultMessage("tc1", "tool result"))
|
||||
|
||||
// Second call returns text
|
||||
resp, err = chat.SendRaw(context.Background(), UserMessage("continue"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "final" {
|
||||
t.Errorf("expected 'final', got %q", resp.Text)
|
||||
}
|
||||
|
||||
// Check the full history
|
||||
msgs := chat.Messages()
|
||||
// user, assistant(tool call), tool result, user, assistant(text)
|
||||
if len(msgs) != 5 {
|
||||
t.Fatalf("expected 5 messages, got %d", len(msgs))
|
||||
}
|
||||
if msgs[2].Role != RoleTool {
|
||||
t.Errorf("expected msg[2] role=tool, got %v", msgs[2].Role)
|
||||
}
|
||||
if msgs[2].ToolCallID != "tc1" {
|
||||
t.Errorf("expected msg[2] toolCallID=tc1, got %q", msgs[2].ToolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_Messages(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, _ = chat.Send(context.Background(), "test")
|
||||
|
||||
msgs := chat.Messages()
|
||||
// Verify it's a copy — modifying returned slice shouldn't affect chat
|
||||
msgs[0] = Message{}
|
||||
|
||||
original := chat.Messages()
|
||||
if original[0].Role != RoleUser {
|
||||
t.Error("Messages() did not return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_Reset(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, _ = chat.Send(context.Background(), "test")
|
||||
if len(chat.Messages()) == 0 {
|
||||
t.Fatal("expected messages before reset")
|
||||
}
|
||||
|
||||
chat.Reset()
|
||||
if len(chat.Messages()) != 0 {
|
||||
t.Errorf("expected 0 messages after reset, got %d", len(chat.Messages()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_Fork(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, _ = chat.Send(context.Background(), "msg1")
|
||||
|
||||
fork := chat.Fork()
|
||||
|
||||
// Fork should have same history
|
||||
if len(fork.Messages()) != len(chat.Messages()) {
|
||||
t.Fatalf("fork should have same message count: got %d vs %d", len(fork.Messages()), len(chat.Messages()))
|
||||
}
|
||||
|
||||
// Adding to fork should not affect original
|
||||
_, _ = fork.Send(context.Background(), "msg2")
|
||||
if len(fork.Messages()) == len(chat.Messages()) {
|
||||
t.Error("fork messages should be independent of original")
|
||||
}
|
||||
|
||||
// Adding to original should not affect fork
|
||||
originalLen := len(chat.Messages())
|
||||
_, _ = chat.Send(context.Background(), "msg3")
|
||||
if len(chat.Messages()) == originalLen {
|
||||
t.Error("original should have more messages after send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_SendWithImages(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "I see an image"})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
img := Image{URL: "https://example.com/image.png"}
|
||||
text, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if text != "I see an image" {
|
||||
t.Errorf("expected 'I see an image', got %q", text)
|
||||
}
|
||||
|
||||
// Verify the image was passed through to the provider
|
||||
req := mp.lastRequest()
|
||||
if len(req.Messages) == 0 {
|
||||
t.Fatal("expected messages in request")
|
||||
}
|
||||
lastUserMsg := req.Messages[0]
|
||||
if len(lastUserMsg.Images) != 1 {
|
||||
t.Fatalf("expected 1 image, got %d", len(lastUserMsg.Images))
|
||||
}
|
||||
if lastUserMsg.Images[0].URL != "https://example.com/image.png" {
|
||||
t.Errorf("expected image URL, got %q", lastUserMsg.Images[0].URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_MultipleToolCallRounds(t *testing.T) {
|
||||
var callCount int32
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
n := atomic.AddInt32(&callCount, 1)
|
||||
if n <= 3 {
|
||||
return provider.Response{
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{ID: "tc" + string(rune('0'+n)), Name: "counter", Arguments: "{}"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return provider.Response{Text: "all done"}, nil
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
var execCount int32
|
||||
tool := DefineSimple("counter", "Counts", func(ctx context.Context) (string, error) {
|
||||
atomic.AddInt32(&execCount, 1)
|
||||
return "counted", nil
|
||||
})
|
||||
chat.SetTools(NewToolBox(tool))
|
||||
|
||||
text, err := chat.Send(context.Background(), "count three times")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if text != "all done" {
|
||||
t.Errorf("expected 'all done', got %q", text)
|
||||
}
|
||||
if atomic.LoadInt32(&callCount) != 4 {
|
||||
t.Errorf("expected 4 provider calls, got %d", callCount)
|
||||
}
|
||||
if atomic.LoadInt32(&execCount) != 3 {
|
||||
t.Errorf("expected 3 tool executions, got %d", execCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_SendError(t *testing.T) {
|
||||
wantErr := errors.New("provider failed")
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, wantErr
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, err := chat.Send(context.Background(), "test")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Errorf("expected wrapped provider error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_WithRequestOptions(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))
|
||||
|
||||
_, err := chat.Send(context.Background(), "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
req := mp.lastRequest()
|
||||
if req.Temperature == nil || *req.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5, got %v", req.Temperature)
|
||||
}
|
||||
if req.MaxTokens == nil || *req.MaxTokens != 200 {
|
||||
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
|
||||
}
|
||||
}
|
||||
212
v2/message_test.go
Normal file
212
v2/message_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
func TestUserMessage(t *testing.T) {
|
||||
msg := UserMessage("hello")
|
||||
if msg.Role != RoleUser {
|
||||
t.Errorf("expected role=user, got %v", msg.Role)
|
||||
}
|
||||
if msg.Content.Text != "hello" {
|
||||
t.Errorf("expected text='hello', got %q", msg.Content.Text)
|
||||
}
|
||||
if len(msg.Content.Images) != 0 {
|
||||
t.Errorf("expected no images, got %d", len(msg.Content.Images))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserMessageWithImages(t *testing.T) {
|
||||
img1 := Image{URL: "https://example.com/1.png"}
|
||||
img2 := Image{Base64: "abc123", ContentType: "image/png"}
|
||||
|
||||
msg := UserMessageWithImages("describe", img1, img2)
|
||||
if msg.Role != RoleUser {
|
||||
t.Errorf("expected role=user, got %v", msg.Role)
|
||||
}
|
||||
if msg.Content.Text != "describe" {
|
||||
t.Errorf("expected text='describe', got %q", msg.Content.Text)
|
||||
}
|
||||
if len(msg.Content.Images) != 2 {
|
||||
t.Fatalf("expected 2 images, got %d", len(msg.Content.Images))
|
||||
}
|
||||
if msg.Content.Images[0].URL != "https://example.com/1.png" {
|
||||
t.Errorf("expected image[0] URL, got %q", msg.Content.Images[0].URL)
|
||||
}
|
||||
if msg.Content.Images[1].Base64 != "abc123" {
|
||||
t.Errorf("expected image[1] base64='abc123', got %q", msg.Content.Images[1].Base64)
|
||||
}
|
||||
if msg.Content.Images[1].ContentType != "image/png" {
|
||||
t.Errorf("expected image[1] contentType='image/png', got %q", msg.Content.Images[1].ContentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemMessage(t *testing.T) {
|
||||
msg := SystemMessage("Be helpful")
|
||||
if msg.Role != RoleSystem {
|
||||
t.Errorf("expected role=system, got %v", msg.Role)
|
||||
}
|
||||
if msg.Content.Text != "Be helpful" {
|
||||
t.Errorf("expected text='Be helpful', got %q", msg.Content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssistantMessage(t *testing.T) {
|
||||
msg := AssistantMessage("Sure thing")
|
||||
if msg.Role != RoleAssistant {
|
||||
t.Errorf("expected role=assistant, got %v", msg.Role)
|
||||
}
|
||||
if msg.Content.Text != "Sure thing" {
|
||||
t.Errorf("expected text='Sure thing', got %q", msg.Content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultMessage(t *testing.T) {
|
||||
msg := ToolResultMessage("tc-123", "result data")
|
||||
if msg.Role != RoleTool {
|
||||
t.Errorf("expected role=tool, got %v", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "tc-123" {
|
||||
t.Errorf("expected toolCallID='tc-123', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content.Text != "result data" {
|
||||
t.Errorf("expected text='result data', got %q", msg.Content.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessages(t *testing.T) {
|
||||
msgs := []Message{
|
||||
SystemMessage("system prompt"),
|
||||
UserMessageWithImages("look at this", Image{URL: "https://example.com/img.png"}),
|
||||
{
|
||||
Role: RoleAssistant,
|
||||
Content: Content{Text: "I'll use a tool"},
|
||||
ToolCalls: []ToolCall{
|
||||
{ID: "tc1", Name: "search", Arguments: `{"q":"test"}`},
|
||||
},
|
||||
},
|
||||
ToolResultMessage("tc1", "found it"),
|
||||
}
|
||||
|
||||
converted := convertMessages(msgs)
|
||||
|
||||
if len(converted) != 4 {
|
||||
t.Fatalf("expected 4 converted messages, got %d", len(converted))
|
||||
}
|
||||
|
||||
// System message
|
||||
if converted[0].Role != "system" {
|
||||
t.Errorf("msg[0]: expected role='system', got %q", converted[0].Role)
|
||||
}
|
||||
if converted[0].Content != "system prompt" {
|
||||
t.Errorf("msg[0]: expected content='system prompt', got %q", converted[0].Content)
|
||||
}
|
||||
|
||||
// User message with images
|
||||
if converted[1].Role != "user" {
|
||||
t.Errorf("msg[1]: expected role='user', got %q", converted[1].Role)
|
||||
}
|
||||
if len(converted[1].Images) != 1 {
|
||||
t.Fatalf("msg[1]: expected 1 image, got %d", len(converted[1].Images))
|
||||
}
|
||||
if converted[1].Images[0].URL != "https://example.com/img.png" {
|
||||
t.Errorf("msg[1]: expected image URL, got %q", converted[1].Images[0].URL)
|
||||
}
|
||||
|
||||
// Assistant message with tool calls
|
||||
if converted[2].Role != "assistant" {
|
||||
t.Errorf("msg[2]: expected role='assistant', got %q", converted[2].Role)
|
||||
}
|
||||
if len(converted[2].ToolCalls) != 1 {
|
||||
t.Fatalf("msg[2]: expected 1 tool call, got %d", len(converted[2].ToolCalls))
|
||||
}
|
||||
if converted[2].ToolCalls[0].ID != "tc1" {
|
||||
t.Errorf("msg[2]: expected tool call ID='tc1', got %q", converted[2].ToolCalls[0].ID)
|
||||
}
|
||||
if converted[2].ToolCalls[0].Name != "search" {
|
||||
t.Errorf("msg[2]: expected tool call name='search', got %q", converted[2].ToolCalls[0].Name)
|
||||
}
|
||||
if converted[2].ToolCalls[0].Arguments != `{"q":"test"}` {
|
||||
t.Errorf("msg[2]: expected tool call arguments, got %q", converted[2].ToolCalls[0].Arguments)
|
||||
}
|
||||
|
||||
// Tool result message
|
||||
if converted[3].Role != "tool" {
|
||||
t.Errorf("msg[3]: expected role='tool', got %q", converted[3].Role)
|
||||
}
|
||||
if converted[3].ToolCallID != "tc1" {
|
||||
t.Errorf("msg[3]: expected toolCallID='tc1', got %q", converted[3].ToolCallID)
|
||||
}
|
||||
if converted[3].Content != "found it" {
|
||||
t.Errorf("msg[3]: expected content='found it', got %q", converted[3].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProviderResponse(t *testing.T) {
|
||||
t.Run("text only", func(t *testing.T) {
|
||||
resp := convertProviderResponse(provider.Response{
|
||||
Text: "hello",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
})
|
||||
if resp.Text != "hello" {
|
||||
t.Errorf("expected text='hello', got %q", resp.Text)
|
||||
}
|
||||
if resp.HasToolCalls() {
|
||||
t.Error("expected no tool calls")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("expected usage")
|
||||
}
|
||||
if resp.Usage.InputTokens != 10 {
|
||||
t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
|
||||
}
|
||||
|
||||
msg := resp.Message()
|
||||
if msg.Role != RoleAssistant {
|
||||
t.Errorf("expected role=assistant, got %v", msg.Role)
|
||||
}
|
||||
if msg.Content.Text != "hello" {
|
||||
t.Errorf("expected message text='hello', got %q", msg.Content.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with tool calls", func(t *testing.T) {
|
||||
resp := convertProviderResponse(provider.Response{
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{ID: "tc1", Name: "search", Arguments: `{"q":"go"}`},
|
||||
{ID: "tc2", Name: "calc", Arguments: `{"a":1}`},
|
||||
},
|
||||
})
|
||||
if !resp.HasToolCalls() {
|
||||
t.Fatal("expected tool calls")
|
||||
}
|
||||
if len(resp.ToolCalls) != 2 {
|
||||
t.Fatalf("expected 2 tool calls, got %d", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].ID != "tc1" || resp.ToolCalls[0].Name != "search" {
|
||||
t.Errorf("unexpected tool call[0]: %+v", resp.ToolCalls[0])
|
||||
}
|
||||
if resp.ToolCalls[1].ID != "tc2" || resp.ToolCalls[1].Name != "calc" {
|
||||
t.Errorf("unexpected tool call[1]: %+v", resp.ToolCalls[1])
|
||||
}
|
||||
|
||||
msg := resp.Message()
|
||||
if len(msg.ToolCalls) != 2 {
|
||||
t.Errorf("expected 2 tool calls in message, got %d", len(msg.ToolCalls))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil usage", func(t *testing.T) {
|
||||
resp := convertProviderResponse(provider.Response{Text: "ok"})
|
||||
if resp.Usage != nil {
|
||||
t.Errorf("expected nil usage, got %+v", resp.Usage)
|
||||
}
|
||||
})
|
||||
}
|
||||
282
v2/middleware_test.go
Normal file
282
v2/middleware_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
func TestWithRetry_Success(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp).WithMiddleware(
|
||||
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
|
||||
)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "ok" {
|
||||
t.Errorf("expected 'ok', got %q", resp.Text)
|
||||
}
|
||||
if len(mp.Requests) != 1 {
|
||||
t.Errorf("expected 1 request (no retries needed), got %d", len(mp.Requests))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithRetry_EventualSuccess(t *testing.T) {
|
||||
var callCount int32
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
n := atomic.AddInt32(&callCount, 1)
|
||||
if n <= 2 {
|
||||
return provider.Response{}, errors.New("transient error")
|
||||
}
|
||||
return provider.Response{Text: "success"}, nil
|
||||
})
|
||||
model := newMockModel(mp).WithMiddleware(
|
||||
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
|
||||
)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "success" {
|
||||
t.Errorf("expected 'success', got %q", resp.Text)
|
||||
}
|
||||
if atomic.LoadInt32(&callCount) != 3 {
|
||||
t.Errorf("expected 3 calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithRetry_AllFail(t *testing.T) {
|
||||
providerErr := errors.New("persistent error")
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, providerErr
|
||||
})
|
||||
model := newMockModel(mp).WithMiddleware(
|
||||
WithRetry(2, func(attempt int) time.Duration { return time.Millisecond }),
|
||||
)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, providerErr) {
|
||||
t.Errorf("expected wrapped persistent error, got %v", err)
|
||||
}
|
||||
if len(mp.Requests) != 3 {
|
||||
t.Errorf("expected 3 requests (1 initial + 2 retries), got %d", len(mp.Requests))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithRetry_ContextCancelled(t *testing.T) {
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, errors.New("fail")
|
||||
})
|
||||
model := newMockModel(mp).WithMiddleware(
|
||||
WithRetry(10, func(attempt int) time.Duration { return 5 * time.Second }),
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Cancel after a short delay
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
_, err := model.Complete(ctx, []Message{UserMessage("test")})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "fast"})
|
||||
model := newMockModel(mp).WithMiddleware(WithTimeout(5 * time.Second))
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "fast" {
|
||||
t.Errorf("expected 'fast', got %q", resp.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTimeout_Exceeded(t *testing.T) {
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return provider.Response{}, ctx.Err()
|
||||
case <-time.After(5 * time.Second):
|
||||
return provider.Response{Text: "slow"}, nil
|
||||
}
|
||||
})
|
||||
model := newMockModel(mp).WithMiddleware(WithTimeout(50 * time.Millisecond))
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Errorf("expected DeadlineExceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithUsageTracking(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{
|
||||
Text: "ok",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
})
|
||||
tracker := &UsageTracker{}
|
||||
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
|
||||
|
||||
// Make two requests
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on call %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
input, output, requests := tracker.Summary()
|
||||
if input != 20 {
|
||||
t.Errorf("expected total input 20, got %d", input)
|
||||
}
|
||||
if output != 10 {
|
||||
t.Errorf("expected total output 10, got %d", output)
|
||||
}
|
||||
if requests != 2 {
|
||||
t.Errorf("expected 2 requests, got %d", requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithUsageTracking_NilUsage(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "no usage"})
|
||||
tracker := &UsageTracker{}
|
||||
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
input, output, requests := tracker.Summary()
|
||||
if input != 0 || output != 0 {
|
||||
t.Errorf("expected 0 tokens with nil usage, got input=%d output=%d", input, output)
|
||||
}
|
||||
// Add(nil) returns early without incrementing TotalRequests
|
||||
if requests != 0 {
|
||||
t.Errorf("expected 0 requests (nil usage skips Add), got %d", requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTracker_Concurrent(t *testing.T) {
|
||||
tracker := &UsageTracker{}
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.Add(&Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
input, output, requests := tracker.Summary()
|
||||
if input != 1000 {
|
||||
t.Errorf("expected total input 1000, got %d", input)
|
||||
}
|
||||
if output != 500 {
|
||||
t.Errorf("expected total output 500, got %d", output)
|
||||
}
|
||||
if requests != 100 {
|
||||
t.Errorf("expected 100 requests, got %d", requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Chaining(t *testing.T) {
|
||||
var order []string
|
||||
mw1 := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
order = append(order, "mw1-before")
|
||||
resp, err := next(ctx, model, messages, cfg)
|
||||
order = append(order, "mw1-after")
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
mw2 := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
order = append(order, "mw2-before")
|
||||
resp, err := next(ctx, model, messages, cfg)
|
||||
order = append(order, "mw2-after")
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp).WithMiddleware(mw1, mw2)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("expected %d middleware calls, got %d: %v", len(expected), len(order), order)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if order[i] != v {
|
||||
t.Errorf("order[%d]: expected %q, got %q", i, v, order[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogging(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "logged"})
|
||||
logger := slog.Default()
|
||||
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "logged" {
|
||||
t.Errorf("expected 'logged', got %q", resp.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogging_Error(t *testing.T) {
|
||||
providerErr := errors.New("log this error")
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, providerErr
|
||||
})
|
||||
logger := slog.Default()
|
||||
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if !errors.Is(err, providerErr) {
|
||||
t.Errorf("expected provider error, got %v", err)
|
||||
}
|
||||
}
|
||||
87
v2/mock_provider_test.go
Normal file
87
v2/mock_provider_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
// mockProvider is a configurable mock implementation of provider.Provider for testing.
|
||||
type mockProvider struct {
|
||||
CompleteFunc func(ctx context.Context, req provider.Request) (provider.Response, error)
|
||||
StreamFunc func(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error
|
||||
|
||||
// mu guards Requests
|
||||
mu sync.Mutex
|
||||
Requests []provider.Request
|
||||
}
|
||||
|
||||
func (m *mockProvider) Complete(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
m.mu.Lock()
|
||||
m.Requests = append(m.Requests, req)
|
||||
m.mu.Unlock()
|
||||
return m.CompleteFunc(ctx, req)
|
||||
}
|
||||
|
||||
func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events chan<- provider.StreamEvent) error {
|
||||
m.mu.Lock()
|
||||
m.Requests = append(m.Requests, req)
|
||||
m.mu.Unlock()
|
||||
if m.StreamFunc != nil {
|
||||
return m.StreamFunc(ctx, req, events)
|
||||
}
|
||||
close(events)
|
||||
return nil
|
||||
}
|
||||
|
||||
// lastRequest returns the most recent request recorded by the mock.
|
||||
func (m *mockProvider) lastRequest() provider.Request {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if len(m.Requests) == 0 {
|
||||
return provider.Request{}
|
||||
}
|
||||
return m.Requests[len(m.Requests)-1]
|
||||
}
|
||||
|
||||
// newMockProvider creates a mock that always returns the given response.
|
||||
func newMockProvider(resp provider.Response) *mockProvider {
|
||||
return &mockProvider{
|
||||
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newMockProviderFunc creates a mock with a custom Complete function.
|
||||
func newMockProviderFunc(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *mockProvider {
|
||||
return &mockProvider{CompleteFunc: fn}
|
||||
}
|
||||
|
||||
// newMockStreamProvider creates a mock that streams the given events.
|
||||
func newMockStreamProvider(events []provider.StreamEvent) *mockProvider {
|
||||
return &mockProvider{
|
||||
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, nil
|
||||
},
|
||||
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
||||
for _, ev := range events {
|
||||
select {
|
||||
case ch <- ev:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newMockModel creates a *Model backed by the given mock provider.
|
||||
func newMockModel(p *mockProvider) *Model {
|
||||
return &Model{
|
||||
provider: p,
|
||||
model: "mock-model",
|
||||
}
|
||||
}
|
||||
215
v2/model_test.go
Normal file
215
v2/model_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
func TestModel_Complete(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{
|
||||
Text: "Hello!",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "Hello!" {
|
||||
t.Errorf("expected text 'Hello!', got %q", resp.Text)
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("expected usage, got nil")
|
||||
}
|
||||
if resp.Usage.InputTokens != 10 {
|
||||
t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens)
|
||||
}
|
||||
if resp.Usage.OutputTokens != 5 {
|
||||
t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens)
|
||||
}
|
||||
if resp.Usage.TotalTokens != 15 {
|
||||
t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_WithOptions(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp)
|
||||
|
||||
temp := 0.7
|
||||
maxTok := 100
|
||||
topP := 0.9
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")},
|
||||
WithTemperature(temp),
|
||||
WithMaxTokens(maxTok),
|
||||
WithTopP(topP),
|
||||
WithStop("STOP", "END"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
req := mp.lastRequest()
|
||||
if req.Temperature == nil || *req.Temperature != temp {
|
||||
t.Errorf("expected temperature %v, got %v", temp, req.Temperature)
|
||||
}
|
||||
if req.MaxTokens == nil || *req.MaxTokens != maxTok {
|
||||
t.Errorf("expected maxTokens %v, got %v", maxTok, req.MaxTokens)
|
||||
}
|
||||
if req.TopP == nil || *req.TopP != topP {
|
||||
t.Errorf("expected topP %v, got %v", topP, req.TopP)
|
||||
}
|
||||
if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" {
|
||||
t.Errorf("expected stop [STOP END], got %v", req.Stop)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_Error(t *testing.T) {
|
||||
wantErr := errors.New("provider error")
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, wantErr
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Errorf("expected error %v, got %v", wantErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_WithTools(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "done"})
|
||||
model := newMockModel(mp)
|
||||
|
||||
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||
return "hello", nil
|
||||
})
|
||||
tb := NewToolBox(tool)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")}, WithTools(tb))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
req := mp.lastRequest()
|
||||
if len(req.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
||||
}
|
||||
if req.Tools[0].Name != "greet" {
|
||||
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
|
||||
}
|
||||
if req.Tools[0].Description != "Says hello" {
|
||||
t.Errorf("expected tool description 'Says hello', got %q", req.Tools[0].Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Model(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "hi"})
|
||||
client := newClient(mp)
|
||||
model := client.Model("test-model")
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "hi" {
|
||||
t.Errorf("expected 'hi', got %q", resp.Text)
|
||||
}
|
||||
|
||||
req := mp.lastRequest()
|
||||
if req.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_WithMiddleware(t *testing.T) {
|
||||
var called bool
|
||||
mw := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
called = true
|
||||
return next(ctx, model, messages, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
client := newClient(mp).WithMiddleware(mw)
|
||||
model := client.Model("test-model")
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("middleware was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_WithMiddleware(t *testing.T) {
|
||||
var order []string
|
||||
mw1 := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
order = append(order, "mw1")
|
||||
return next(ctx, model, messages, cfg)
|
||||
}
|
||||
}
|
||||
mw2 := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
order = append(order, "mw2")
|
||||
return next(ctx, model, messages, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp).WithMiddleware(mw1).WithMiddleware(mw2)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(order) != 2 || order[0] != "mw1" || order[1] != "mw2" {
|
||||
t.Errorf("expected middleware order [mw1 mw2], got %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_NoUsage(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "no usage"})
|
||||
model := newMockModel(mp)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Errorf("expected nil usage, got %+v", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_ResponseMessage(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "response text"})
|
||||
model := newMockModel(mp)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
msg := resp.Message()
|
||||
if msg.Role != RoleAssistant {
|
||||
t.Errorf("expected role assistant, got %v", msg.Role)
|
||||
}
|
||||
if msg.Content.Text != "response text" {
|
||||
t.Errorf("expected text 'response text', got %q", msg.Content.Text)
|
||||
}
|
||||
}
|
||||
137
v2/request_test.go
Normal file
137
v2/request_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithTemperature(t *testing.T) {
|
||||
cfg := &requestConfig{}
|
||||
WithTemperature(0.7)(cfg)
|
||||
if cfg.temperature == nil || *cfg.temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", cfg.temperature)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithMaxTokens(t *testing.T) {
|
||||
cfg := &requestConfig{}
|
||||
WithMaxTokens(256)(cfg)
|
||||
if cfg.maxTokens == nil || *cfg.maxTokens != 256 {
|
||||
t.Errorf("expected maxTokens 256, got %v", cfg.maxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTopP(t *testing.T) {
|
||||
cfg := &requestConfig{}
|
||||
WithTopP(0.95)(cfg)
|
||||
if cfg.topP == nil || *cfg.topP != 0.95 {
|
||||
t.Errorf("expected topP 0.95, got %v", cfg.topP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithStop(t *testing.T) {
|
||||
cfg := &requestConfig{}
|
||||
WithStop("END", "STOP", "###")(cfg)
|
||||
if len(cfg.stop) != 3 {
|
||||
t.Fatalf("expected 3 stop sequences, got %d", len(cfg.stop))
|
||||
}
|
||||
if cfg.stop[0] != "END" || cfg.stop[1] != "STOP" || cfg.stop[2] != "###" {
|
||||
t.Errorf("unexpected stop sequences: %v", cfg.stop)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTools(t *testing.T) {
|
||||
tool := DefineSimple("test", "A test tool", func(ctx context.Context) (string, error) {
|
||||
return "ok", nil
|
||||
})
|
||||
tb := NewToolBox(tool)
|
||||
|
||||
cfg := &requestConfig{}
|
||||
WithTools(tb)(cfg)
|
||||
if cfg.tools == nil {
|
||||
t.Fatal("expected tools to be set")
|
||||
}
|
||||
if len(cfg.tools.AllTools()) != 1 {
|
||||
t.Errorf("expected 1 tool, got %d", len(cfg.tools.AllTools()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProviderRequest(t *testing.T) {
|
||||
tool := DefineSimple("greet", "Greets", func(ctx context.Context) (string, error) {
|
||||
return "hi", nil
|
||||
})
|
||||
tb := NewToolBox(tool)
|
||||
|
||||
temp := 0.8
|
||||
maxTok := 512
|
||||
topP := 0.9
|
||||
|
||||
cfg := &requestConfig{
|
||||
tools: tb,
|
||||
temperature: &temp,
|
||||
maxTokens: &maxTok,
|
||||
topP: &topP,
|
||||
stop: []string{"END"},
|
||||
}
|
||||
|
||||
msgs := []Message{
|
||||
SystemMessage("be nice"),
|
||||
UserMessage("hello"),
|
||||
}
|
||||
|
||||
req := buildProviderRequest("test-model", msgs, cfg)
|
||||
|
||||
if req.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", req.Model)
|
||||
}
|
||||
if len(req.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(req.Messages))
|
||||
}
|
||||
if req.Messages[0].Role != "system" {
|
||||
t.Errorf("expected first message role='system', got %q", req.Messages[0].Role)
|
||||
}
|
||||
if req.Messages[1].Role != "user" {
|
||||
t.Errorf("expected second message role='user', got %q", req.Messages[1].Role)
|
||||
}
|
||||
if req.Temperature == nil || *req.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", req.Temperature)
|
||||
}
|
||||
if req.MaxTokens == nil || *req.MaxTokens != 512 {
|
||||
t.Errorf("expected maxTokens 512, got %v", req.MaxTokens)
|
||||
}
|
||||
if req.TopP == nil || *req.TopP != 0.9 {
|
||||
t.Errorf("expected topP 0.9, got %v", req.TopP)
|
||||
}
|
||||
if len(req.Stop) != 1 || req.Stop[0] != "END" {
|
||||
t.Errorf("expected stop=[END], got %v", req.Stop)
|
||||
}
|
||||
if len(req.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
||||
}
|
||||
if req.Tools[0].Name != "greet" {
|
||||
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProviderRequest_EmptyConfig(t *testing.T) {
|
||||
cfg := &requestConfig{}
|
||||
msgs := []Message{UserMessage("hi")}
|
||||
|
||||
req := buildProviderRequest("model", msgs, cfg)
|
||||
|
||||
if req.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", req.Temperature)
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
t.Errorf("expected nil maxTokens, got %v", req.MaxTokens)
|
||||
}
|
||||
if req.TopP != nil {
|
||||
t.Errorf("expected nil topP, got %v", req.TopP)
|
||||
}
|
||||
if len(req.Stop) != 0 {
|
||||
t.Errorf("expected no stop sequences, got %v", req.Stop)
|
||||
}
|
||||
if len(req.Tools) != 0 {
|
||||
t.Errorf("expected no tools, got %d", len(req.Tools))
|
||||
}
|
||||
}
|
||||
338
v2/stream_test.go
Normal file
338
v2/stream_test.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
func TestStreamReader_TextEvents(t *testing.T) {
|
||||
events := []provider.StreamEvent{
|
||||
{Type: provider.StreamEventText, Text: "Hello"},
|
||||
{Type: provider.StreamEventText, Text: " world"},
|
||||
{Type: provider.StreamEventDone, Response: &provider.Response{
|
||||
Text: "Hello world",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 5,
|
||||
OutputTokens: 2,
|
||||
TotalTokens: 7,
|
||||
},
|
||||
}},
|
||||
}
|
||||
mp := newMockStreamProvider(events)
|
||||
model := newMockModel(mp)
|
||||
|
||||
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Read text events
|
||||
ev, err := reader.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on first event: %v", err)
|
||||
}
|
||||
if ev.Type != StreamEventText || ev.Text != "Hello" {
|
||||
t.Errorf("expected text event 'Hello', got type=%d text=%q", ev.Type, ev.Text)
|
||||
}
|
||||
|
||||
ev, err = reader.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on second event: %v", err)
|
||||
}
|
||||
if ev.Type != StreamEventText || ev.Text != " world" {
|
||||
t.Errorf("expected text event ' world', got type=%d text=%q", ev.Type, ev.Text)
|
||||
}
|
||||
|
||||
// Read done event
|
||||
ev, err = reader.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on done event: %v", err)
|
||||
}
|
||||
if ev.Type != StreamEventDone {
|
||||
t.Errorf("expected done event, got type=%d", ev.Type)
|
||||
}
|
||||
if ev.Response == nil {
|
||||
t.Fatal("expected response in done event")
|
||||
}
|
||||
if ev.Response.Text != "Hello world" {
|
||||
t.Errorf("expected final text 'Hello world', got %q", ev.Response.Text)
|
||||
}
|
||||
|
||||
// Subsequent reads should return EOF
|
||||
_, err = reader.Next()
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Errorf("expected io.EOF after done, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamReader_ToolCallEvents(t *testing.T) {
|
||||
events := []provider.StreamEvent{
|
||||
{
|
||||
Type: provider.StreamEventToolStart,
|
||||
ToolIndex: 0,
|
||||
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search"},
|
||||
},
|
||||
{
|
||||
Type: provider.StreamEventToolDelta,
|
||||
ToolIndex: 0,
|
||||
ToolCall: &provider.ToolCall{Arguments: `{"query":`},
|
||||
},
|
||||
{
|
||||
Type: provider.StreamEventToolDelta,
|
||||
ToolIndex: 0,
|
||||
ToolCall: &provider.ToolCall{Arguments: `"test"}`},
|
||||
},
|
||||
{
|
||||
Type: provider.StreamEventToolEnd,
|
||||
ToolIndex: 0,
|
||||
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
|
||||
},
|
||||
{
|
||||
Type: provider.StreamEventDone,
|
||||
Response: &provider.Response{
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mp := newMockStreamProvider(events)
|
||||
model := newMockModel(mp)
|
||||
|
||||
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Read tool start
|
||||
ev, err := reader.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ev.Type != StreamEventToolStart {
|
||||
t.Errorf("expected tool start, got type=%d", ev.Type)
|
||||
}
|
||||
if ev.ToolCall == nil || ev.ToolCall.Name != "search" {
|
||||
t.Errorf("expected tool call 'search', got %+v", ev.ToolCall)
|
||||
}
|
||||
|
||||
// Read tool deltas
|
||||
ev, _ = reader.Next()
|
||||
if ev.Type != StreamEventToolDelta {
|
||||
t.Errorf("expected tool delta, got type=%d", ev.Type)
|
||||
}
|
||||
|
||||
ev, _ = reader.Next()
|
||||
if ev.Type != StreamEventToolDelta {
|
||||
t.Errorf("expected tool delta, got type=%d", ev.Type)
|
||||
}
|
||||
|
||||
// Read tool end
|
||||
ev, _ = reader.Next()
|
||||
if ev.Type != StreamEventToolEnd {
|
||||
t.Errorf("expected tool end, got type=%d", ev.Type)
|
||||
}
|
||||
if ev.ToolCall == nil || ev.ToolCall.Arguments != `{"query":"test"}` {
|
||||
t.Errorf("expected complete arguments, got %+v", ev.ToolCall)
|
||||
}
|
||||
|
||||
// Read done
|
||||
ev, _ = reader.Next()
|
||||
if ev.Type != StreamEventDone {
|
||||
t.Errorf("expected done, got type=%d", ev.Type)
|
||||
}
|
||||
if ev.Response == nil || len(ev.Response.ToolCalls) != 1 {
|
||||
t.Error("expected response with 1 tool call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamReader_Error(t *testing.T) {
|
||||
streamErr := errors.New("stream failed")
|
||||
mp := &mockProvider{
|
||||
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, nil
|
||||
},
|
||||
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
||||
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "partial"}
|
||||
ch <- provider.StreamEvent{Type: provider.StreamEventError, Error: streamErr}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
model := newMockModel(mp)
|
||||
|
||||
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Read partial text
|
||||
ev, err := reader.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ev.Text != "partial" {
|
||||
t.Errorf("expected 'partial', got %q", ev.Text)
|
||||
}
|
||||
|
||||
// Read error
|
||||
_, err = reader.Next()
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, streamErr) {
|
||||
t.Errorf("expected stream error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamReader_Close(t *testing.T) {
|
||||
// Create a stream that sends one event then blocks until context is cancelled
|
||||
mp := &mockProvider{
|
||||
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, nil
|
||||
},
|
||||
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
||||
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "start"}
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
}
|
||||
model := newMockModel(mp)
|
||||
|
||||
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Read the first event
|
||||
ev, err := reader.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on first event: %v", err)
|
||||
}
|
||||
if ev.Text != "start" {
|
||||
t.Errorf("expected 'start', got %q", ev.Text)
|
||||
}
|
||||
|
||||
// Close should cancel context
|
||||
if err := reader.Close(); err != nil {
|
||||
t.Fatalf("close error: %v", err)
|
||||
}
|
||||
|
||||
// After close, Next should eventually terminate with either EOF or context error.
|
||||
// The exact behavior depends on goroutine scheduling: the channel may close (EOF)
|
||||
// or the error event from the cancelled context may arrive first.
|
||||
_, err = reader.Next()
|
||||
if err == nil {
|
||||
t.Error("expected error after close, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamReader_Collect(t *testing.T) {
|
||||
events := []provider.StreamEvent{
|
||||
{Type: provider.StreamEventText, Text: "Hello"},
|
||||
{Type: provider.StreamEventText, Text: " world"},
|
||||
{Type: provider.StreamEventDone, Response: &provider.Response{
|
||||
Text: "Hello world",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 2,
|
||||
TotalTokens: 12,
|
||||
},
|
||||
}},
|
||||
}
|
||||
mp := newMockStreamProvider(events)
|
||||
model := newMockModel(mp)
|
||||
|
||||
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
resp, err := reader.Collect()
|
||||
if err != nil {
|
||||
t.Fatalf("collect error: %v", err)
|
||||
}
|
||||
if resp.Text != "Hello world" {
|
||||
t.Errorf("expected 'Hello world', got %q", resp.Text)
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("expected usage")
|
||||
}
|
||||
if resp.Usage.InputTokens != 10 {
|
||||
t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamReader_Text(t *testing.T) {
|
||||
events := []provider.StreamEvent{
|
||||
{Type: provider.StreamEventText, Text: "result"},
|
||||
{Type: provider.StreamEventDone, Response: &provider.Response{Text: "result"}},
|
||||
}
|
||||
mp := newMockStreamProvider(events)
|
||||
model := newMockModel(mp)
|
||||
|
||||
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
text, err := reader.Text()
|
||||
if err != nil {
|
||||
t.Fatalf("text error: %v", err)
|
||||
}
|
||||
if text != "result" {
|
||||
t.Errorf("expected 'result', got %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamReader_EmptyStream(t *testing.T) {
|
||||
// Stream that completes without a done event (no response)
|
||||
mp := newMockStreamProvider([]provider.StreamEvent{
|
||||
{Type: provider.StreamEventText, Text: "hi"},
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
_, err = reader.Collect()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for stream without done event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamReader_StreamFuncError(t *testing.T) {
|
||||
// Stream function returns error directly
|
||||
mp := &mockProvider{
|
||||
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, nil
|
||||
},
|
||||
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
||||
return errors.New("stream init failed")
|
||||
},
|
||||
}
|
||||
model := newMockModel(mp)
|
||||
|
||||
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error creating reader: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// The error should come through as an error event
|
||||
_, err = reader.Collect()
|
||||
if err == nil {
|
||||
t.Fatal("expected error from stream function")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user