Add comprehensive test suite for v2 module with mock provider
Cover all core library logic (Client, Model, Chat, middleware, streaming, message conversion, request building) using a configurable mock provider that avoids real API calls. ~50 tests across 7 files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
282
v2/middleware_test.go
Normal file
282
v2/middleware_test.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
func TestWithRetry_Success(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp).WithMiddleware(
|
||||
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
|
||||
)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "ok" {
|
||||
t.Errorf("expected 'ok', got %q", resp.Text)
|
||||
}
|
||||
if len(mp.Requests) != 1 {
|
||||
t.Errorf("expected 1 request (no retries needed), got %d", len(mp.Requests))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithRetry_EventualSuccess(t *testing.T) {
|
||||
var callCount int32
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
n := atomic.AddInt32(&callCount, 1)
|
||||
if n <= 2 {
|
||||
return provider.Response{}, errors.New("transient error")
|
||||
}
|
||||
return provider.Response{Text: "success"}, nil
|
||||
})
|
||||
model := newMockModel(mp).WithMiddleware(
|
||||
WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
|
||||
)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "success" {
|
||||
t.Errorf("expected 'success', got %q", resp.Text)
|
||||
}
|
||||
if atomic.LoadInt32(&callCount) != 3 {
|
||||
t.Errorf("expected 3 calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithRetry_AllFail(t *testing.T) {
|
||||
providerErr := errors.New("persistent error")
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, providerErr
|
||||
})
|
||||
model := newMockModel(mp).WithMiddleware(
|
||||
WithRetry(2, func(attempt int) time.Duration { return time.Millisecond }),
|
||||
)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, providerErr) {
|
||||
t.Errorf("expected wrapped persistent error, got %v", err)
|
||||
}
|
||||
if len(mp.Requests) != 3 {
|
||||
t.Errorf("expected 3 requests (1 initial + 2 retries), got %d", len(mp.Requests))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithRetry_ContextCancelled(t *testing.T) {
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, errors.New("fail")
|
||||
})
|
||||
model := newMockModel(mp).WithMiddleware(
|
||||
WithRetry(10, func(attempt int) time.Duration { return 5 * time.Second }),
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Cancel after a short delay
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
_, err := model.Complete(ctx, []Message{UserMessage("test")})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "fast"})
|
||||
model := newMockModel(mp).WithMiddleware(WithTimeout(5 * time.Second))
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "fast" {
|
||||
t.Errorf("expected 'fast', got %q", resp.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTimeout_Exceeded(t *testing.T) {
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return provider.Response{}, ctx.Err()
|
||||
case <-time.After(5 * time.Second):
|
||||
return provider.Response{Text: "slow"}, nil
|
||||
}
|
||||
})
|
||||
model := newMockModel(mp).WithMiddleware(WithTimeout(50 * time.Millisecond))
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Errorf("expected DeadlineExceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithUsageTracking(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{
|
||||
Text: "ok",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
})
|
||||
tracker := &UsageTracker{}
|
||||
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
|
||||
|
||||
// Make two requests
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on call %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
input, output, requests := tracker.Summary()
|
||||
if input != 20 {
|
||||
t.Errorf("expected total input 20, got %d", input)
|
||||
}
|
||||
if output != 10 {
|
||||
t.Errorf("expected total output 10, got %d", output)
|
||||
}
|
||||
if requests != 2 {
|
||||
t.Errorf("expected 2 requests, got %d", requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithUsageTracking_NilUsage(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "no usage"})
|
||||
tracker := &UsageTracker{}
|
||||
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
input, output, requests := tracker.Summary()
|
||||
if input != 0 || output != 0 {
|
||||
t.Errorf("expected 0 tokens with nil usage, got input=%d output=%d", input, output)
|
||||
}
|
||||
// Add(nil) returns early without incrementing TotalRequests
|
||||
if requests != 0 {
|
||||
t.Errorf("expected 0 requests (nil usage skips Add), got %d", requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTracker_Concurrent(t *testing.T) {
|
||||
tracker := &UsageTracker{}
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.Add(&Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
input, output, requests := tracker.Summary()
|
||||
if input != 1000 {
|
||||
t.Errorf("expected total input 1000, got %d", input)
|
||||
}
|
||||
if output != 500 {
|
||||
t.Errorf("expected total output 500, got %d", output)
|
||||
}
|
||||
if requests != 100 {
|
||||
t.Errorf("expected 100 requests, got %d", requests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_Chaining(t *testing.T) {
|
||||
var order []string
|
||||
mw1 := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
order = append(order, "mw1-before")
|
||||
resp, err := next(ctx, model, messages, cfg)
|
||||
order = append(order, "mw1-after")
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
mw2 := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
order = append(order, "mw2-before")
|
||||
resp, err := next(ctx, model, messages, cfg)
|
||||
order = append(order, "mw2-after")
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp).WithMiddleware(mw1, mw2)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"}
|
||||
if len(order) != len(expected) {
|
||||
t.Fatalf("expected %d middleware calls, got %d: %v", len(expected), len(order), order)
|
||||
}
|
||||
for i, v := range expected {
|
||||
if order[i] != v {
|
||||
t.Errorf("order[%d]: expected %q, got %q", i, v, order[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogging(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "logged"})
|
||||
logger := slog.Default()
|
||||
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "logged" {
|
||||
t.Errorf("expected 'logged', got %q", resp.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithLogging_Error(t *testing.T) {
|
||||
providerErr := errors.New("log this error")
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, providerErr
|
||||
})
|
||||
logger := slog.Default()
|
||||
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if !errors.Is(err, providerErr) {
|
||||
t.Errorf("expected provider error, got %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user