package llm import ( "context" "fmt" "log/slog" "sync" "time" ) // CompletionFunc is the signature for the completion call chain. type CompletionFunc func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) // Middleware wraps a completion call. It receives the next handler in the chain // and returns a new handler that can inspect/modify the request and response. type Middleware func(next CompletionFunc) CompletionFunc // WithLogging returns middleware that logs requests and responses via slog. func WithLogging(logger *slog.Logger) Middleware { return func(next CompletionFunc) CompletionFunc { return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { logger.Info("llm request", "model", model, "message_count", len(messages), ) start := time.Now() resp, err := next(ctx, model, messages, cfg) elapsed := time.Since(start) if err != nil { logger.Error("llm error", "model", model, "elapsed", elapsed, "error", err) } else { logger.Info("llm response", "model", model, "elapsed", elapsed, "text_len", len(resp.Text), "tool_calls", len(resp.ToolCalls), ) } return resp, err } } } // WithRetry returns middleware that retries failed requests with configurable backoff. func WithRetry(maxRetries int, backoff func(attempt int) time.Duration) Middleware { return func(next CompletionFunc) CompletionFunc { return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { if attempt > 0 { select { case <-ctx.Done(): return Response{}, ctx.Err() case <-time.After(backoff(attempt)): } } resp, err := next(ctx, model, messages, cfg) if err == nil { return resp, nil } lastErr = err } return Response{}, fmt.Errorf("after %d retries: %w", maxRetries, lastErr) } } } // WithTimeout returns middleware that enforces a per-request timeout. func WithTimeout(d time.Duration) Middleware { return func(next CompletionFunc) CompletionFunc { return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { ctx, cancel := context.WithTimeout(ctx, d) defer cancel() return next(ctx, model, messages, cfg) } } } // UsageTracker accumulates token usage statistics across calls. type UsageTracker struct { mu sync.Mutex TotalInput int64 TotalOutput int64 TotalRequests int64 } // Add records usage from a single request. func (ut *UsageTracker) Add(u *Usage) { if u == nil { return } ut.mu.Lock() defer ut.mu.Unlock() ut.TotalInput += int64(u.InputTokens) ut.TotalOutput += int64(u.OutputTokens) ut.TotalRequests++ } // Summary returns the accumulated totals. func (ut *UsageTracker) Summary() (input, output, requests int64) { ut.mu.Lock() defer ut.mu.Unlock() return ut.TotalInput, ut.TotalOutput, ut.TotalRequests } // WithUsageTracking returns middleware that accumulates token usage across calls. func WithUsageTracking(tracker *UsageTracker) Middleware { return func(next CompletionFunc) CompletionFunc { return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { resp, err := next(ctx, model, messages, cfg) if err == nil { tracker.Add(resp.Usage) } return resp, err } } }