package llm

import (
	"context"
	"errors"
	"io"
	"testing"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// TestStreamReader_TextEvents verifies that text events from the provider are
// surfaced in order, that the terminal done event carries the final response,
// and that reads after the done event return io.EOF.
func TestStreamReader_TextEvents(t *testing.T) {
	events := []provider.StreamEvent{
		{Type: provider.StreamEventText, Text: "Hello"},
		{Type: provider.StreamEventText, Text: " world"},
		{Type: provider.StreamEventDone, Response: &provider.Response{
			Text: "Hello world",
			Usage: &provider.Usage{
				InputTokens:  5,
				OutputTokens: 2,
				TotalTokens:  7,
			},
		}},
	}

	mp := newMockStreamProvider(events)
	model := newMockModel(mp)

	reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer reader.Close()

	// Read text events
	ev, err := reader.Next()
	if err != nil {
		t.Fatalf("unexpected error on first event: %v", err)
	}
	if ev.Type != StreamEventText || ev.Text != "Hello" {
		t.Errorf("expected text event 'Hello', got type=%d text=%q", ev.Type, ev.Text)
	}

	ev, err = reader.Next()
	if err != nil {
		t.Fatalf("unexpected error on second event: %v", err)
	}
	if ev.Type != StreamEventText || ev.Text != " world" {
		t.Errorf("expected text event ' world', got type=%d text=%q", ev.Type, ev.Text)
	}

	// Read done event
	ev, err = reader.Next()
	if err != nil {
		t.Fatalf("unexpected error on done event: %v", err)
	}
	if ev.Type != StreamEventDone {
		t.Errorf("expected done event, got type=%d", ev.Type)
	}
	if ev.Response == nil {
		t.Fatal("expected response in done event")
	}
	if ev.Response.Text != "Hello world" {
		t.Errorf("expected final text 'Hello world', got %q", ev.Response.Text)
	}

	// Subsequent reads should return EOF
	_, err = reader.Next()
	if !errors.Is(err, io.EOF) {
		t.Errorf("expected io.EOF after done, got %v", err)
	}
}

// TestStreamReader_ToolCallEvents verifies that tool-call start, delta and end
// events are surfaced in order, that the end event carries the fully
// assembled arguments, and that the done event's response lists the tool call.
func TestStreamReader_ToolCallEvents(t *testing.T) {
	events := []provider.StreamEvent{
		{
			Type:      provider.StreamEventToolStart,
			ToolIndex: 0,
			ToolCall:  &provider.ToolCall{ID: "tc1", Name: "search"},
		},
		{
			Type:      provider.StreamEventToolDelta,
			ToolIndex: 0,
			ToolCall:  &provider.ToolCall{Arguments: `{"query":`},
		},
		{
			Type:      provider.StreamEventToolDelta,
			ToolIndex: 0,
			ToolCall:  &provider.ToolCall{Arguments: `"test"}`},
		},
		{
			Type:      provider.StreamEventToolEnd,
			ToolIndex: 0,
			ToolCall:  &provider.ToolCall{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
		},
		{
			Type: provider.StreamEventDone,
			Response: &provider.Response{
				ToolCalls: []provider.ToolCall{
					{ID: "tc1", Name: "search", Arguments: `{"query":"test"}`},
				},
			},
		},
	}

	mp := newMockStreamProvider(events)
	model := newMockModel(mp)

	reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer reader.Close()

	// Read tool start
	ev, err := reader.Next()
	if err != nil {
		t.Fatalf("unexpected error on tool start: %v", err)
	}
	if ev.Type != StreamEventToolStart {
		t.Errorf("expected tool start, got type=%d", ev.Type)
	}
	if ev.ToolCall == nil || ev.ToolCall.Name != "search" {
		t.Errorf("expected tool call 'search', got %+v", ev.ToolCall)
	}

	// Read tool deltas. Each read error is checked before the event is
	// inspected so a failed read is reported as itself rather than as a
	// misleading zero-value type mismatch.
	for i := 0; i < 2; i++ {
		ev, err = reader.Next()
		if err != nil {
			t.Fatalf("unexpected error on tool delta %d: %v", i, err)
		}
		if ev.Type != StreamEventToolDelta {
			t.Errorf("expected tool delta, got type=%d", ev.Type)
		}
	}

	// Read tool end
	ev, err = reader.Next()
	if err != nil {
		t.Fatalf("unexpected error on tool end: %v", err)
	}
	if ev.Type != StreamEventToolEnd {
		t.Errorf("expected tool end, got type=%d", ev.Type)
	}
	if ev.ToolCall == nil || ev.ToolCall.Arguments != `{"query":"test"}` {
		t.Errorf("expected complete arguments, got %+v", ev.ToolCall)
	}

	// Read done
	ev, err = reader.Next()
	if err != nil {
		t.Fatalf("unexpected error on done event: %v", err)
	}
	if ev.Type != StreamEventDone {
		t.Errorf("expected done, got type=%d", ev.Type)
	}
	if ev.Response == nil || len(ev.Response.ToolCalls) != 1 {
		t.Error("expected response with 1 tool call")
	}

	// Subsequent reads should return EOF, matching the text-event path.
	if _, err := reader.Next(); !errors.Is(err, io.EOF) {
		t.Errorf("expected io.EOF after done, got %v", err)
	}
}

// TestStreamReader_Error verifies that events sent before a provider error
// event are still delivered, and that the error event is then returned from
// Next as an error matching the original error via errors.Is.
func TestStreamReader_Error(t *testing.T) {
	streamErr := errors.New("stream failed")
	mp := &mockProvider{
		CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
			return provider.Response{}, nil
		},
		StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
			ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "partial"}
			ch <- provider.StreamEvent{Type: provider.StreamEventError, Error: streamErr}
			return nil
		},
	}
	model := newMockModel(mp)

	reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer reader.Close()

	// Read partial text
	ev, err := reader.Next()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if ev.Text != "partial" {
		t.Errorf("expected 'partial', got %q", ev.Text)
	}

	// Read error
	_, err = reader.Next()
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if !errors.Is(err, streamErr) {
		t.Errorf("expected stream error, got %v", err)
	}
}

// TestStreamReader_Close verifies that closing the reader unblocks a provider
// that is waiting on ctx.Done(), and that Next returns a non-nil error
// afterwards.
func TestStreamReader_Close(t *testing.T) {
	// Create a stream that sends one event then blocks until context is cancelled
	mp := &mockProvider{
		CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
			return provider.Response{}, nil
		},
		StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
			ch <- provider.StreamEvent{Type: provider.StreamEventText, Text: "start"}
			<-ctx.Done()
			return ctx.Err()
		},
	}
	model := newMockModel(mp)

	reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Read the first event
	ev, err := reader.Next()
	if err != nil {
		t.Fatalf("unexpected error on first event: %v", err)
	}
	if ev.Text != "start" {
		t.Errorf("expected 'start', got %q", ev.Text)
	}

	// Close should cancel context
	if err := reader.Close(); err != nil {
		t.Fatalf("close error: %v", err)
	}

	// After close, Next should eventually terminate with either EOF or context error.
	// The exact behavior depends on goroutine scheduling: the channel may close (EOF)
	// or the error event from the cancelled context may arrive first.
	_, err = reader.Next()
	if err == nil {
		t.Error("expected error after close, got nil")
	}
}

// TestStreamReader_Collect verifies that Collect drains the stream and returns
// the response carried by the done event, including its usage.
func TestStreamReader_Collect(t *testing.T) {
	events := []provider.StreamEvent{
		{Type: provider.StreamEventText, Text: "Hello"},
		{Type: provider.StreamEventText, Text: " world"},
		{Type: provider.StreamEventDone, Response: &provider.Response{
			Text: "Hello world",
			Usage: &provider.Usage{
				InputTokens:  10,
				OutputTokens: 2,
				TotalTokens:  12,
			},
		}},
	}

	mp := newMockStreamProvider(events)
	model := newMockModel(mp)

	reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer reader.Close()

	resp, err := reader.Collect()
	if err != nil {
		t.Fatalf("collect error: %v", err)
	}

	if resp.Text != "Hello world" {
		t.Errorf("expected 'Hello world', got %q", resp.Text)
	}
	if resp.Usage == nil {
		t.Fatal("expected usage")
	}
	if resp.Usage.InputTokens != 10 {
		t.Errorf("expected 10 input tokens, got %d", resp.Usage.InputTokens)
	}
}

// TestStreamReader_Text verifies that Text returns the final response text of
// a stream that ends with a done event.
func TestStreamReader_Text(t *testing.T) {
	events := []provider.StreamEvent{
		{Type: provider.StreamEventText, Text: "result"},
		{Type: provider.StreamEventDone, Response: &provider.Response{Text: "result"}},
	}

	mp := newMockStreamProvider(events)
	model := newMockModel(mp)

	reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer reader.Close()

	text, err := reader.Text()
	if err != nil {
		t.Fatalf("text error: %v", err)
	}
	if text != "result" {
		t.Errorf("expected 'result', got %q", text)
	}
}

// TestStreamReader_EmptyStream verifies that Collect reports an error when the
// stream ends without a done event (despite the name, the stream here does
// carry one text event; what is missing is the final response).
func TestStreamReader_EmptyStream(t *testing.T) {
	// Stream that completes without a done event (no response)
	mp := newMockStreamProvider([]provider.StreamEvent{
		{Type: provider.StreamEventText, Text: "hi"},
	})
	model := newMockModel(mp)

	reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer reader.Close()

	_, err = reader.Collect()
	if err == nil {
		t.Fatal("expected error for stream without done event")
	}
}

// TestStreamReader_StreamFuncError verifies that an error returned directly by
// the provider's stream function does not fail Stream itself but is surfaced
// later when the stream is consumed with Collect.
func TestStreamReader_StreamFuncError(t *testing.T) {
	// Stream function returns error directly
	mp := &mockProvider{
		CompleteFunc: func(ctx context.Context, req provider.Request) (provider.Response, error) {
			return provider.Response{}, nil
		},
		StreamFunc: func(ctx context.Context, req provider.Request, ch chan<- provider.StreamEvent) error {
			return errors.New("stream init failed")
		},
	}
	model := newMockModel(mp)

	reader, err := model.Stream(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error creating reader: %v", err)
	}
	defer reader.Close()

	// The error should come through as an error event
	_, err = reader.Collect()
	if err == nil {
		t.Fatal("expected error from stream function")
	}
}