Cover all core library logic (Client, Model, Chat, middleware, streaming, message conversion, request building) using a configurable mock provider that avoids real API calls. ~50 tests across 7 files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
339 lines
9.0 KiB
Go
339 lines
9.0 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"testing"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
|
)
|
|
|
|
func TestStreamReader_TextEvents(t *testing.T) {
|
|
events := []provider.StreamEvent{
|
|
{Type: provider.StreamEventText, Text: "Hello"},
|
|
{Type: provider.StreamEventText, Text: " world"},
|
|
{Type: provider.StreamEventDone, Response: &provider.Response{
|
|
Text: "Hello world",
|
|
Usage: &provider.Usage{
|
|
InputTokens: 5,
|
|
OutputTokens: 2,
|
|
TotalTokens: 7,
|
|
},
|
|
}},
|
|
}
|
|
mp := newMockStreamProvider(events)
|
|
model := newMockModel(mp)
|
|
|
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
// Read text events
|
|
ev, err := reader.Next()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error on first event: %v", err)
|
|
}
|
|
if ev.Type != StreamEventText || ev.Text != "Hello" {
|
|
t.Errorf("expected text event 'Hello', got type=%d text=%q", ev.Type, ev.Text)
|
|
}
|
|
|
|
ev, err = reader.Next()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error on second event: %v", err)
|
|
}
|
|
if ev.Type != StreamEventText || ev.Text != " world" {
|
|
t.Errorf("expected text event ' world', got type=%d text=%q", ev.Type, ev.Text)
|
|
}
|
|
|
|
// Read done event
|
|
ev, err = reader.Next()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error on done event: %v", err)
|
|
}
|
|
if ev.Type != StreamEventDone {
|
|
t.Errorf("expected done event, got type=%d", ev.Type)
|
|
}
|
|
if ev.Response == nil {
|
|
t.Fatal("expected response in done event")
|
|
}
|
|
if ev.Response.Text != "Hello world" {
|
|
t.Errorf("expected final text 'Hello world', got %q", ev.Response.Text)
|
|
}
|
|
|
|
// Subsequent reads should return EOF
|
|
_, err = reader.Next()
|
|
if !errors.Is(err, io.EOF) {
|
|
t.Errorf("expected io.EOF after done, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestStreamReader_ToolCallEvents(t *testing.T) {
|
|
events := []provider.StreamEvent{
|
|
{
|
|
Type: provider.StreamEventToolStart,
|
|
ToolIndex: 0,
|
|
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search"},
|
|
},
|
|
{
|
|
Type: provider.StreamEventToolDelta,
|
|
ToolIndex: 0,
|
|
ToolCall: &provider.ToolCall{Arguments: `{"query":`},
|
|
},
|
|
{
|
|
Type: provider.StreamEventToolDelta,
|
|
ToolIndex: 0,
|
|
ToolCall: &provider.ToolCall{Arguments: `"test"}`},
|
|
},
|
|
{
|
|
Type: provider.StreamEventToolEnd,
|
|
ToolIndex: 0,
|
|
ToolCall: &provider.ToolCall{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
|
|
},
|
|
{
|
|
Type: provider.StreamEventDone,
|
|
Response: &provider.Response{
|
|
ToolCalls: []provider.ToolCall{
|
|
{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
mp := newMockStreamProvider(events)
|
|
model := newMockModel(mp)
|
|
|
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
// Read tool start
|
|
ev, err := reader.Next()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if ev.Type != StreamEventToolStart {
|
|
t.Errorf("expected tool start, got type=%d", ev.Type)
|
|
}
|
|
if ev.ToolCall == nil || ev.ToolCall.Name != "search" {
|
|
t.Errorf("expected tool call 'search', got %+v", ev.ToolCall)
|
|
}
|
|
|
|
// Read tool deltas
|
|
ev, _ = reader.Next()
|
|
if ev.Type != StreamEventToolDelta {
|
|
t.Errorf("expected tool delta, got type=%d", ev.Type)
|
|
}
|
|
|
|
ev, _ = reader.Next()
|
|
if ev.Type != StreamEventToolDelta {
|
|
t.Errorf("expected tool delta, got type=%d", ev.Type)
|
|
}
|
|
|
|
// Read tool end
|
|
ev, _ = reader.Next()
|
|
if ev.Type != StreamEventToolEnd {
|
|
t.Errorf("expected tool end, got type=%d", ev.Type)
|
|
}
|
|
if ev.ToolCall == nil || ev.ToolCall.Arguments != `{"query":"test"}` {
|
|
t.Errorf("expected complete arguments, got %+v", ev.ToolCall)
|
|
}
|
|
|
|
// Read done
|
|
ev, _ = reader.Next()
|
|
if ev.Type != StreamEventDone {
|
|
t.Errorf("expected done, got type=%d", ev.Type)
|
|
}
|
|
if ev.Response == nil || len(ev.Response.ToolCalls) != 1 {
|
|
t.Error("expected response with 1 tool call")
|
|
}
|
|
}
|
|
|
|
func TestStreamReader_Error(t *testing.T) {
|
|
streamErr := errors.New("stream failed")
|
|
mp := &mockProvider{
|
|
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
return provider.Response{}, nil
|
|
},
|
|
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
|
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "partial"}
|
|
ch <- provider.StreamEvent{Type: provider.StreamEventError, Error: streamErr}
|
|
return nil
|
|
},
|
|
}
|
|
model := newMockModel(mp)
|
|
|
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
// Read partial text
|
|
ev, err := reader.Next()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if ev.Text != "partial" {
|
|
t.Errorf("expected 'partial', got %q", ev.Text)
|
|
}
|
|
|
|
// Read error
|
|
_, err = reader.Next()
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if !errors.Is(err, streamErr) {
|
|
t.Errorf("expected stream error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestStreamReader_Close(t *testing.T) {
|
|
// Create a stream that sends one event then blocks until context is cancelled
|
|
mp := &mockProvider{
|
|
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
return provider.Response{}, nil
|
|
},
|
|
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
|
ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "start"}
|
|
<-ctx.Done()
|
|
return ctx.Err()
|
|
},
|
|
}
|
|
model := newMockModel(mp)
|
|
|
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
// Read the first event
|
|
ev, err := reader.Next()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error on first event: %v", err)
|
|
}
|
|
if ev.Text != "start" {
|
|
t.Errorf("expected 'start', got %q", ev.Text)
|
|
}
|
|
|
|
// Close should cancel context
|
|
if err := reader.Close(); err != nil {
|
|
t.Fatalf("close error: %v", err)
|
|
}
|
|
|
|
// After close, Next should eventually terminate with either EOF or context error.
|
|
// The exact behavior depends on goroutine scheduling: the channel may close (EOF)
|
|
// or the error event from the cancelled context may arrive first.
|
|
_, err = reader.Next()
|
|
if err == nil {
|
|
t.Error("expected error after close, got nil")
|
|
}
|
|
}
|
|
|
|
func TestStreamReader_Collect(t *testing.T) {
|
|
events := []provider.StreamEvent{
|
|
{Type: provider.StreamEventText, Text: "Hello"},
|
|
{Type: provider.StreamEventText, Text: " world"},
|
|
{Type: provider.StreamEventDone, Response: &provider.Response{
|
|
Text: "Hello world",
|
|
Usage: &provider.Usage{
|
|
InputTokens: 10,
|
|
OutputTokens: 2,
|
|
TotalTokens: 12,
|
|
},
|
|
}},
|
|
}
|
|
mp := newMockStreamProvider(events)
|
|
model := newMockModel(mp)
|
|
|
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
resp, err := reader.Collect()
|
|
if err != nil {
|
|
t.Fatalf("collect error: %v", err)
|
|
}
|
|
if resp.Text != "Hello world" {
|
|
t.Errorf("expected 'Hello world', got %q", resp.Text)
|
|
}
|
|
if resp.Usage == nil {
|
|
t.Fatal("expected usage")
|
|
}
|
|
if resp.Usage.InputTokens != 10 {
|
|
t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
|
|
}
|
|
}
|
|
|
|
func TestStreamReader_Text(t *testing.T) {
|
|
events := []provider.StreamEvent{
|
|
{Type: provider.StreamEventText, Text: "result"},
|
|
{Type: provider.StreamEventDone, Response: &provider.Response{Text: "result"}},
|
|
}
|
|
mp := newMockStreamProvider(events)
|
|
model := newMockModel(mp)
|
|
|
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
text, err := reader.Text()
|
|
if err != nil {
|
|
t.Fatalf("text error: %v", err)
|
|
}
|
|
if text != "result" {
|
|
t.Errorf("expected 'result', got %q", text)
|
|
}
|
|
}
|
|
|
|
func TestStreamReader_EmptyStream(t *testing.T) {
|
|
// Stream that completes without a done event (no response)
|
|
mp := newMockStreamProvider([]provider.StreamEvent{
|
|
{Type: provider.StreamEventText, Text: "hi"},
|
|
})
|
|
model := newMockModel(mp)
|
|
|
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
_, err = reader.Collect()
|
|
if err == nil {
|
|
t.Fatal("expected error for stream without done event")
|
|
}
|
|
}
|
|
|
|
func TestStreamReader_StreamFuncError(t *testing.T) {
|
|
// Stream function returns error directly
|
|
mp := &mockProvider{
|
|
CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
return provider.Response{}, nil
|
|
},
|
|
StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
|
|
return errors.New("stream init failed")
|
|
},
|
|
}
|
|
model := newMockModel(mp)
|
|
|
|
reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error creating reader: %v", err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
// The error should come through as an error event
|
|
_, err = reader.Collect()
|
|
if err == nil {
|
|
t.Fatal("expected error from stream function")
|
|
}
|
|
}
|