Files
go-llm/v2/middleware_test.go
Steve Dudenhoeffer 5b687839b2
All checks were successful
CI / Lint (pull_request) Successful in 10m18s
CI / Root Module (pull_request) Successful in 11m4s
CI / V2 Module (pull_request) Successful in 11m5s
feat: comprehensive token usage tracking for V2
Add provider-specific usage details, fix streaming usage, and return
usage from all high-level APIs (Chat.Send, Generate[T], Agent.Run).

Breaking changes:
- Chat.Send/SendMessage/SendWithImages now return (string, *Usage, error)
- Generate[T]/GenerateWith[T] now return (T, *Usage, error)
- Agent.Run/RunMessages now return (string, *Usage, error)

New features:
- Usage.Details map for provider-specific token breakdowns
  (reasoning, cached, audio, thoughts tokens)
- OpenAI streaming now captures usage via StreamOptions.IncludeUsage
- Google streaming now captures UsageMetadata from final chunk
- UsageTracker.Details() for accumulated detail totals
- ModelPricing and PricingRegistry for cost computation

Closes #2

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 04:33:18 +00:00

360 lines
9.8 KiB
Go

package llm
import (
"context"
"errors"
"log/slog"
"sync"
"sync/atomic"
"testing"
"time"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
// TestWithRetry_Success checks that a call which succeeds on its first
// attempt passes through WithRetry untouched: the response is returned and
// the mock provider records exactly one request.
func TestWithRetry_Success(t *testing.T) {
	backoff := func(int) time.Duration { return time.Millisecond }
	mock := newMockProvider(provider.Response{Text: "ok"})
	m := newMockModel(mock).WithMiddleware(WithRetry(3, backoff))

	got, err := m.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got.Text != "ok" {
		t.Errorf("expected 'ok', got %q", got.Text)
	}
	// A single recorded request means no retry was attempted.
	if n := len(mock.Requests); n != 1 {
		t.Errorf("expected 1 request (no retries needed), got %d", n)
	}
}
// TestWithRetry_EventualSuccess verifies that WithRetry keeps retrying after
// provider failures and returns the first successful response. The mock
// provider fails its first two calls and succeeds on the third, so exactly
// three provider calls are expected.
func TestWithRetry_EventualSuccess(t *testing.T) {
	// Typed atomic counter: every access goes through Add/Load, so the
	// value is never read non-atomically (including in the failure message).
	var calls atomic.Int32
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		if calls.Add(1) <= 2 {
			return provider.Response{}, errors.New("transient error")
		}
		return provider.Response{Text: "success"}, nil
	})
	model := newMockModel(mp).WithMiddleware(
		WithRetry(3, func(attempt int) time.Duration { return time.Millisecond }),
	)

	resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if resp.Text != "success" {
		t.Errorf("expected 'success', got %q", resp.Text)
	}
	// Load once so the compared value and the reported value are the same.
	if got := calls.Load(); got != 3 {
		t.Errorf("expected 3 calls, got %d", got)
	}
}
func TestWithRetry_AllFail(t *testing.T) {
providerErr := errors.New("persistent error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, providerErr
})
model := newMockModel(mp).WithMiddleware(
WithRetry(2, func(attempt int) time.Duration { return time.Millisecond }),
)
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, providerErr) {
t.Errorf("expected wrapped persistent error, got %v", err)
}
if len(mp.Requests) != 3 {
t.Errorf("expected 3 requests (1 initial + 2 retries), got %d", len(mp.Requests))
}
}
// TestWithRetry_ContextCancelled verifies that cancelling the context while
// a retrying call is in flight surfaces context.Canceled to the caller.
//
// The provider always fails and the backoff is 5s per attempt, far longer
// than the 50ms cancel delay, so the call can only report context.Canceled
// if the retry path observes ctx rather than sitting out the backoff.
func TestWithRetry_ContextCancelled(t *testing.T) {
	mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
		return provider.Response{}, errors.New("fail")
	})
	model := newMockModel(mp).WithMiddleware(
		WithRetry(10, func(attempt int) time.Duration { return 5 * time.Second }),
	)

	ctx, cancel := context.WithCancel(context.Background())
	// Always release the context, even if Complete returns before the timer fires.
	defer cancel()

	// Cancel after a short delay. AfterFunc (instead of a sleeping goroutine)
	// lets us stop the pending cancel on exit so nothing outlives the test.
	timer := time.AfterFunc(50*time.Millisecond, cancel)
	defer timer.Stop()

	_, err := model.Complete(ctx, []Message{UserMessage("test")})
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if !errors.Is(err, context.Canceled) {
		t.Errorf("expected context.Canceled, got %v", err)
	}
}
// TestWithTimeout checks that a provider which answers immediately is
// unaffected by a generous (5s) WithTimeout middleware.
func TestWithTimeout(t *testing.T) {
	mock := newMockProvider(provider.Response{Text: "fast"})
	m := newMockModel(mock).WithMiddleware(WithTimeout(5 * time.Second))

	got, err := m.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if text := got.Text; text != "fast" {
		t.Errorf("expected 'fast', got %q", text)
	}
}
func TestWithTimeout_Exceeded(t *testing.T) {
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
select {
case <-ctx.Done():
return provider.Response{}, ctx.Err()
case <-time.After(5 * time.Second):
return provider.Response{Text: "slow"}, nil
}
})
model := newMockModel(mp).WithMiddleware(WithTimeout(50 * time.Millisecond))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("expected DeadlineExceeded, got %v", err)
}
}
func TestWithUsageTracking(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "ok",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
// Make two requests
for i := 0; i < 2; i++ {
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error on call %d: %v", i, err)
}
}
input, output, requests := tracker.Summary()
if input != 20 {
t.Errorf("expected total input 20, got %d", input)
}
if output != 10 {
t.Errorf("expected total output 10, got %d", output)
}
if requests != 2 {
t.Errorf("expected 2 requests, got %d", requests)
}
}
func TestWithUsageTracking_NilUsage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "no usage"})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
input, output, requests := tracker.Summary()
if input != 0 || output != 0 {
t.Errorf("expected 0 tokens with nil usage, got input=%d output=%d", input, output)
}
// Add(nil) returns early without incrementing TotalRequests
if requests != 0 {
t.Errorf("expected 0 requests (nil usage skips Add), got %d", requests)
}
}
// TestUsageTracker_Concurrent calls UsageTracker.Add from 100 goroutines at
// once and checks that no update is lost: the summary must show 1000 input
// tokens, 500 output tokens, and 100 requests.
func TestUsageTracker_Concurrent(t *testing.T) {
	const workers = 100

	tracker := &UsageTracker{}
	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			// Each goroutine records its own Usage value.
			u := &Usage{
				InputTokens:  10,
				OutputTokens: 5,
				TotalTokens:  15,
			}
			tracker.Add(u)
		}()
	}
	wg.Wait()

	in, out, reqs := tracker.Summary()
	if in != 1000 {
		t.Errorf("expected total input 1000, got %d", in)
	}
	if out != 500 {
		t.Errorf("expected total output 500, got %d", out)
	}
	if reqs != workers {
		t.Errorf("expected 100 requests, got %d", reqs)
	}
}
// TestMiddleware_Chaining checks the order in which stacked middleware run.
// With WithMiddleware(mw1, mw2), mw1 is expected to wrap mw2: mw1 enters
// first and leaves last.
func TestMiddleware_Chaining(t *testing.T) {
	var trace []string

	// tracing builds a middleware that appends "<label>-before" ahead of
	// calling next and "<label>-after" once next has returned.
	tracing := func(label string) func(next CompletionFunc) CompletionFunc {
		return func(next CompletionFunc) CompletionFunc {
			return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
				trace = append(trace, label+"-before")
				resp, err := next(ctx, model, messages, cfg)
				trace = append(trace, label+"-after")
				return resp, err
			}
		}
	}

	mock := newMockProvider(provider.Response{Text: "ok"})
	m := newMockModel(mock).WithMiddleware(tracing("mw1"), tracing("mw2"))

	if _, err := m.Complete(context.Background(), []Message{UserMessage("test")}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	want := []string{"mw1-before", "mw2-before", "mw2-after", "mw1-after"}
	if len(trace) != len(want) {
		t.Fatalf("expected %d middleware calls, got %d: %v", len(want), len(trace), trace)
	}
	for i := range want {
		if trace[i] != want[i] {
			t.Errorf("order[%d]: expected %q, got %q", i, want[i], trace[i])
		}
	}
}
// TestWithLogging checks that the WithLogging middleware is transparent on
// the success path: the provider's response text reaches the caller
// unchanged. The process-wide default slog logger is used.
func TestWithLogging(t *testing.T) {
	mock := newMockProvider(provider.Response{Text: "logged"})
	m := newMockModel(mock).WithMiddleware(WithLogging(slog.Default()))

	got, err := m.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if text := got.Text; text != "logged" {
		t.Errorf("expected 'logged', got %q", text)
	}
}
func TestWithLogging_Error(t *testing.T) {
providerErr := errors.New("log this error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, providerErr
})
logger := slog.Default()
model := newMockModel(mp).WithMiddleware(WithLogging(logger))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if !errors.Is(err, providerErr) {
t.Errorf("expected provider error, got %v", err)
}
}
func TestUsageTracker_Details(t *testing.T) {
tracker := &UsageTracker{}
tracker.Add(&Usage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
Details: map[string]int{
"cached_input_tokens": 20,
"reasoning_tokens": 10,
},
})
tracker.Add(&Usage{
InputTokens: 80,
OutputTokens: 40,
TotalTokens: 120,
Details: map[string]int{
"cached_input_tokens": 15,
},
})
details := tracker.Details()
if details == nil {
t.Fatal("expected details, got nil")
}
if details["cached_input_tokens"] != 35 {
t.Errorf("expected cached_input_tokens=35, got %d", details["cached_input_tokens"])
}
if details["reasoning_tokens"] != 10 {
t.Errorf("expected reasoning_tokens=10, got %d", details["reasoning_tokens"])
}
// Verify returned map is a copy
details["cached_input_tokens"] = 999
fresh := tracker.Details()
if fresh["cached_input_tokens"] != 35 {
t.Error("Details() did not return a copy")
}
}
// TestUsageTracker_Details_Nil checks that Details() returns nil when the
// only usage recorded carried no Details map.
func TestUsageTracker_Details_Nil(t *testing.T) {
	tracker := &UsageTracker{}
	plain := &Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}
	tracker.Add(plain)

	if got := tracker.Details(); got != nil {
		t.Errorf("expected nil details for usage without details, got %v", got)
	}
}
func TestWithUsageTracking_WithDetails(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "ok",
Usage: &provider.Usage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
Details: map[string]int{
"cached_input_tokens": 30,
},
},
})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
details := tracker.Details()
if details["cached_input_tokens"] != 30 {
t.Errorf("expected cached_input_tokens=30, got %d", details["cached_input_tokens"])
}
}