Files
go-llm/v2/middleware_test.go
Steve Dudenhoeffer 6a7eeef619
All checks were successful
CI / Lint (push) Successful in 9m36s
CI / V2 Module (push) Successful in 11m33s
CI / Root Module (push) Successful in 11m35s
Add comprehensive test suite for v2 module with mock provider
Cover all core library logic (Client, Model, Chat, middleware, streaming,
message conversion, request building) using a configurable mock provider
that avoids real API calls. ~50 tests across 7 files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 22:00:49 -05:00

283 lines
8.0 KiB
Go

package llm
import (
"context"
"errors"
"log/slog"
"sync"
"sync/atomic"
"testing"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// TestWithRetry_Success checks that WithRetry passes a successful response
// straight through and issues exactly one provider request when the first
// attempt already succeeds.
func TestWithRetry_Success(t *testing.T) {
	backoff := func(int) time.Duration { return time.Millisecond }
	mock := newMockProvider(provider.Response{Text: "ok"})
	m := newMockModel(mock).WithMiddleware(WithRetry(3, backoff))

	got, err := m.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if want := "ok"; got.Text != want {
		t.Errorf("expected 'ok', got %q", got.Text)
	}
	// No failure occurred, so no retry should have been attempted.
	if n := len(mock.Requests); n != 1 {
		t.Errorf("expected 1 request (no retries needed), got %d", n)
	}
}
// TestWithRetry_EventualSuccess checks that WithRetry keeps retrying through
// transient failures. The provider fails on its first two calls and succeeds
// on the third, which fits within a budget of 3 retries.
func TestWithRetry_EventualSuccess(t *testing.T) {
	// Typed atomic counter: every read and write goes through the atomic API,
	// including the one in the failure message below.
	var calls atomic.Int32
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		// Fail the first two attempts, succeed from the third one on.
		if calls.Add(1) <= 2 {
			return provider.Response{}, errors.New("transient error")
		}
		return provider.Response{Text: "success"}, nil
	})
	model := newMockModel(mp).WithMiddleware(
		WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
	)
	resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.Text != "success" {
		t.Errorf("expected 'success', got %q", resp.Text)
	}
	// Load once so the comparison and the reported value are the same
	// snapshot, taken through the atomic accessor.
	if n := calls.Load(); n != 3 {
		t.Errorf("expected 3 calls, got %d", n)
	}
}
func TestWithRetry_AllFail(t *testing.T) {
providerErr := errors.New("persistent error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, providerErr
})
model := newMockModel(mp).WithMiddleware(
WithRetry(2, func(attempt int) time.Duration { return time.Millisecond }),
)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, providerErr) {
t.Errorf("expected wrapped persistent error, got %v", err)
}
if len(mp.Requests) != 3 {
t.Errorf("expected 3 requests (1 initial + 2 retries), got %d", len(mp.Requests))
}
}
// TestWithRetry_ContextCancelled checks that cancelling the context while
// WithRetry waits between attempts aborts the retry loop promptly and
// surfaces context.Canceled. The loop must not sleep out the 5s backoff first.
func TestWithRetry_ContextCancelled(t *testing.T) {
	const backoff = 5 * time.Second
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, errors.New("fail")
	})
	model := newMockModel(mp).WithMiddleware(
		WithRetry(10, func(attempt int) time.Duration { return backoff }),
	)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// Cancel after a short delay, well inside the first backoff wait. The
	// timer is stopped on exit so nothing outlives the test.
	timer := time.AfterFunc(50*time.Millisecond, cancel)
	defer timer.Stop()

	start := time.Now()
	_, err := model.Complete(ctx, []Message{UserMessage("test")})
	elapsed := time.Since(start)
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if !errors.Is(err, context.Canceled) {
		t.Errorf("expected context.Canceled, got %v", err)
	}
	// Without this check, a retry loop that ignores ctx while sleeping and
	// only notices cancellation afterwards would still pass, just slowly.
	if elapsed >= backoff {
		t.Errorf("expected retry to stop promptly on cancellation, took %v", elapsed)
	}
}
func TestWithTimeout(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "fast"})
model := newMockModel(mp).WithMiddleware(WithTimeout(5 * time.Second))
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "fast" {
t.Errorf("expected 'fast', got %q", resp.Text)
}
}
func TestWithTimeout_Exceeded(t *testing.T) {
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
select {
case <-ctx.Done():
return provider.Response{}, ctx.Err()
case <-time.After(5 * time.Second):
return provider.Response{Text: "slow"}, nil
}
})
model := newMockModel(mp).WithMiddleware(WithTimeout(50 * time.Millisecond))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("expected DeadlineExceeded, got %v", err)
}
}
func TestWithUsageTracking(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "ok",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
// Make two requests
for i := 0; i < 2; i++ {
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error on call %d: %v", i, err)
}
}
input, output, requests := tracker.Summary()
if input != 20 {
t.Errorf("expected total input 20, got %d", input)
}
if output != 10 {
t.Errorf("expected total output 10, got %d", output)
}
if requests != 2 {
t.Errorf("expected 2 requests, got %d", requests)
}
}
func TestWithUsageTracking_NilUsage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "no usage"})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
input, output, requests := tracker.Summary()
if input != 0 || output != 0 {
t.Errorf("expected 0 tokens with nil usage, got input=%d output=%d", input, output)
}
// Add(nil) returns early without incrementing TotalRequests
if requests != 0 {
t.Errorf("expected 0 requests (nil usage skips Add), got %d", requests)
}
}
// TestUsageTracker_Concurrent calls UsageTracker.Add from many goroutines at
// once and checks that no update is lost. Run with -race to also catch
// unsynchronized access.
func TestUsageTracker_Concurrent(t *testing.T) {
	const workers = 100
	tracker := &UsageTracker{}

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			tracker.Add(&Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15})
		}()
	}
	wg.Wait()

	// 100 workers × (10 input, 5 output) tokens.
	in, out, reqs := tracker.Summary()
	if in != 1000 {
		t.Errorf("expected total input 1000, got %d", in)
	}
	if out != 500 {
		t.Errorf("expected total output 500, got %d", out)
	}
	if reqs != 100 {
		t.Errorf("expected 100 requests, got %d", reqs)
	}
}
// TestMiddleware_Chaining checks the nesting order of middleware passed to
// WithMiddleware. The first middleware is the outermost: it runs first on the
// way in and last on the way out.
func TestMiddleware_Chaining(t *testing.T) {
	var trace []string

	// tracing builds a middleware that records its labels around the
	// downstream call.
	tracing := func(before, after string) func(CompletionFunc) CompletionFunc {
		return func(next CompletionFunc) CompletionFunc {
			return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
				trace = append(trace, before)
				resp, err := next(ctx, model, messages, cfg)
				trace = append(trace, after)
				return resp, err
			}
		}
	}
	outer := tracing("mw1-before", "mw1-after")
	inner := tracing("mw2-before", "mw2-after")

	m := newMockModel(newMockProvider(provider.Response{Text: "ok"})).WithMiddleware(outer, inner)
	if _, err := m.Complete(context.Background(), []Message{UserMessage("test")}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	want := []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"}
	if len(trace) != len(want) {
		t.Fatalf("expected %d middleware calls, got %d: %v", len(want), len(trace), trace)
	}
	for idx := range want {
		if trace[idx] != want[idx] {
			t.Errorf("order[%d]: expected %q, got %q", idx, want[idx], trace[idx])
		}
	}
}
func TestWithLogging(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "logged"})
logger := slog.Default()
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "logged" {
t.Errorf("expected 'logged', got %q", resp.Text)
}
}
func TestWithLogging_Error(t *testing.T) {
providerErr := errors.New("log this error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, providerErr
})
logger := slog.Default()
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if !errors.Is(err, providerErr) {
t.Errorf("expected provider error, got %v", err)
}
}