package llm

import (
	"context"
	"errors"
	"log/slog"
	"slices"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)

// TestWithRetry_Success verifies that a request which succeeds on the first
// attempt is returned unchanged and is not retried.
func TestWithRetry_Success(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "ok"})
	model := newMockModel(mp).WithMiddleware(
		WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
	)

	resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.Text != "ok" {
		t.Errorf("expected 'ok', got %q", resp.Text)
	}
	if len(mp.Requests) != 1 {
		t.Errorf("expected 1 request (no retries needed), got %d", len(mp.Requests))
	}
}

// TestWithRetry_EventualSuccess verifies that transient failures are retried
// until the provider succeeds, as long as the retry budget is not exhausted.
func TestWithRetry_EventualSuccess(t *testing.T) {
	var calls atomic.Int32
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		// Fail the first two attempts, succeed on the third.
		if calls.Add(1) <= 2 {
			return provider.Response{}, errors.New("transient error")
		}
		return provider.Response{Text: "success"}, nil
	})
	model := newMockModel(mp).WithMiddleware(
		WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
	)

	resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.Text != "success" {
		t.Errorf("expected 'success', got %q", resp.Text)
	}
	// Load the counter once so the comparison and the failure message agree
	// and the counter is never read non-atomically.
	if got := calls.Load(); got != 3 {
		t.Errorf("expected 3 calls, got %d", got)
	}
}

// TestWithRetry_AllFail verifies that once the retry budget is exhausted the
// provider's error is returned (wrapped or as-is, checked via errors.Is) and
// that exactly one initial attempt plus the configured retries were made.
func TestWithRetry_AllFail(t *testing.T) {
	providerErr := errors.New("persistent error")
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, providerErr
	})
	model := newMockModel(mp).WithMiddleware(
		WithRetry(2, func(attempt int) time.Duration { return time.Millisecond }),
	)

	_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if !errors.Is(err, providerErr) {
		t.Errorf("expected wrapped persistent error, got %v", err)
	}
	if len(mp.Requests) != 3 {
		t.Errorf("expected 3 requests (1 initial + 2 retries), got %d", len(mp.Requests))
	}
}

// TestWithRetry_ContextCancelled verifies that cancelling the caller's
// context stops the retry loop and surfaces context.Canceled.
func TestWithRetry_ContextCancelled(t *testing.T) {
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, errors.New("fail")
	})
	// The 5s backoff is far longer than the cancellation delay below, so the
	// cancellation is expected to arrive while WithRetry is between attempts.
	model := newMockModel(mp).WithMiddleware(
		WithRetry(10, func(attempt int) time.Duration { return 5 * time.Second }),
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// Cancel after a short delay. AfterFunc replaces a hand-rolled sleeping
	// goroutine, and Stop guarantees the callback cannot outlive the test.
	timer := time.AfterFunc(50*time.Millisecond, cancel)
	defer timer.Stop()

	_, err := model.Complete(ctx, []Message{UserMessage("test")})
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if !errors.Is(err, context.Canceled) {
		t.Errorf("expected context.Canceled, got %v", err)
	}
}

// TestWithTimeout verifies that a provider finishing well within the timeout
// returns its response normally.
func TestWithTimeout(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "fast"})
	model := newMockModel(mp).WithMiddleware(WithTimeout(5 * time.Second))

	resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.Text != "fast" {
		t.Errorf("expected 'fast', got %q", resp.Text)
	}
}

// TestWithTimeout_Exceeded verifies that a provider slower than the configured
// timeout observes a done context and the caller receives an error matching
// context.DeadlineExceeded.
func TestWithTimeout_Exceeded(t *testing.T) {
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		select {
		case <-ctx.Done():
			return provider.Response{}, ctx.Err()
		case <-time.After(5 * time.Second):
			return provider.Response{Text: "slow"}, nil
		}
	})
	model := newMockModel(mp).WithMiddleware(WithTimeout(50 * time.Millisecond))

	_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Errorf("expected DeadlineExceeded, got %v", err)
	}
}

// TestWithUsageTracking verifies that token usage reported by the provider is
// accumulated across requests in the supplied tracker.
func TestWithUsageTracking(t *testing.T) {
	mp := newMockProvider(provider.Response{
		Text: "ok",
		Usage: &provider.Usage{
			InputTokens:  10,
			OutputTokens: 5,
			TotalTokens:  15,
		},
	})
	tracker := &UsageTracker{}
	model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))

	// Make two requests so the totals must be summed.
	for i := 0; i < 2; i++ {
		_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
		if err != nil {
			t.Fatalf("unexpected error on call %d: %v", i, err)
		}
	}

	input, output, requests := tracker.Summary()
	if input != 20 {
		t.Errorf("expected total input 20, got %d", input)
	}
	if output != 10 {
		t.Errorf("expected total output 10, got %d", output)
	}
	if requests != 2 {
		t.Errorf("expected 2 requests, got %d", requests)
	}
}

// TestWithUsageTracking_NilUsage verifies that a response without usage data
// leaves the tracker untouched.
func TestWithUsageTracking_NilUsage(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "no usage"})
	tracker := &UsageTracker{}
	model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))

	_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	input, output, requests := tracker.Summary()
	if input != 0 || output != 0 {
		t.Errorf("expected 0 tokens with nil usage, got input=%d output=%d", input, output)
	}
	// Add(nil) returns early without incrementing TotalRequests
	if requests != 0 {
		t.Errorf("expected 0 requests (nil usage skips Add), got %d", requests)
	}
}

// TestUsageTracker_Concurrent verifies that concurrent calls to Add are all
// accounted for; run with -race to also check for data races.
func TestUsageTracker_Concurrent(t *testing.T) {
	tracker := &UsageTracker{}
	var wg sync.WaitGroup

	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			tracker.Add(&Usage{
				InputTokens:  10,
				OutputTokens: 5,
				TotalTokens:  15,
			})
		}()
	}
	wg.Wait()

	input, output, requests := tracker.Summary()
	if input != 1000 {
		t.Errorf("expected total input 1000, got %d", input)
	}
	if output != 500 {
		t.Errorf("expected total output 500, got %d", output)
	}
	if requests != 100 {
		t.Errorf("expected 100 requests, got %d", requests)
	}
}

// TestMiddleware_Chaining verifies the wrapping order of middleware passed to
// WithMiddleware: the first middleware is outermost, so it runs first on the
// way in and last on the way out.
func TestMiddleware_Chaining(t *testing.T) {
	var order []string
	// trace builds a middleware that records when it is entered and exited.
	trace := func(name string) func(CompletionFunc) CompletionFunc {
		return func(next CompletionFunc) CompletionFunc {
			return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
				order = append(order, name+"-before")
				resp, err := next(ctx, model, messages, cfg)
				order = append(order, name+"-after")
				return resp, err
			}
		}
	}

	mp := newMockProvider(provider.Response{Text: "ok"})
	model := newMockModel(mp).WithMiddleware(trace("mw1"), trace("mw2"))

	_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expected := []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"}
	if !slices.Equal(order, expected) {
		t.Errorf("unexpected middleware order: expected %v, got %v", expected, order)
	}
}

// TestWithLogging verifies that the logging middleware passes a successful
// response through unchanged.
func TestWithLogging(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "logged"})
	logger := slog.Default()
	model := newMockModel(mp).WithMiddleware(WithLogging(logger))

	resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.Text != "logged" {
		t.Errorf("expected 'logged', got %q", resp.Text)
	}
}

// TestWithLogging_Error verifies that the logging middleware propagates the
// provider's error (matched via errors.Is) rather than swallowing it.
func TestWithLogging_Error(t *testing.T) {
	providerErr := errors.New("log this error")
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, providerErr
	})
	logger := slog.Default()
	model := newMockModel(mp).WithMiddleware(WithLogging(logger))

	_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if !errors.Is(err, providerErr) {
		t.Errorf("expected provider error, got %v", err)
	}
}