Files
go-llm/v2/middleware.go
Steve Dudenhoeffer a4cb4baab5 Add go-llm v2: redesigned API for simpler LLM abstraction
v2 is a new Go module (v2/) with a dramatically simpler API:
- Unified Message type (no more Input marker interface)
- Define[T] for ergonomic tool creation with standard context.Context
- Chat session with automatic tool-call loop (agent loop)
- Streaming via pull-based StreamReader
- MCP one-call connect (MCPStdioServer, MCPHTTPServer, MCPSSEServer)
- Middleware support (logging, retry, timeout, usage tracking)
- Decoupled JSON Schema (map[string]any, no provider coupling)
- Sample tools: WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP
- Providers: OpenAI, Anthropic, Google (all with streaming)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 20:00:08 -05:00

118 lines
3.4 KiB
Go

package llm
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// CompletionFunc is the signature for the completion call chain. Providers
// implement the innermost CompletionFunc; middleware wrap it.
type CompletionFunc func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error)
// Middleware wraps a completion call. It receives the next handler in the chain
// and returns a new handler that can inspect/modify the request and response.
// Middleware compose: the first one applied becomes the outermost wrapper.
type Middleware func(next CompletionFunc) CompletionFunc
// WithLogging returns middleware that logs every request and its outcome
// (success or error, with elapsed time) through the given slog.Logger.
func WithLogging(logger *slog.Logger) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			logger.Info("llm request", "model", model, "message_count", len(messages))

			began := time.Now()
			resp, err := next(ctx, model, messages, cfg)
			took := time.Since(began)

			if err != nil {
				logger.Error("llm error", "model", model, "elapsed", took, "error", err)
				return resp, err
			}
			logger.Info("llm response",
				"model", model,
				"elapsed", took,
				"text_len", len(resp.Text),
				"tool_calls", len(resp.ToolCalls),
			)
			return resp, nil
		}
	}
}
// WithRetry returns middleware that retries failed requests with configurable
// backoff. maxRetries is the number of retries after the initial attempt, so
// up to maxRetries+1 calls are made; negative values are treated as 0 (a
// single attempt). backoff is called with the 1-based retry attempt number
// and returns how long to wait before that retry.
//
// Retrying stops early when ctx is cancelled or its deadline passes, both
// while waiting out a backoff and when an attempt fails because of the
// context itself — retrying a cancelled request cannot succeed.
func WithRetry(maxRetries int, backoff func(attempt int) time.Duration) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			// Clamp so a negative value cannot skip the loop entirely and
			// wrap a nil lastErr below ("%!w(<nil>)").
			if maxRetries < 0 {
				maxRetries = 0
			}
			var lastErr error
			for attempt := 0; attempt <= maxRetries; attempt++ {
				if attempt > 0 {
					select {
					case <-ctx.Done():
						return Response{}, ctx.Err()
					case <-time.After(backoff(attempt)):
					}
				}
				resp, err := next(ctx, model, messages, cfg)
				if err == nil {
					return resp, nil
				}
				lastErr = err
				// If the context is already done, the failure is (or will be
				// followed by) a context error; don't sleep and retry.
				if ctx.Err() != nil {
					return Response{}, ctx.Err()
				}
			}
			return Response{}, fmt.Errorf("after %d retries: %w", maxRetries, lastErr)
		}
	}
}
// WithTimeout returns middleware that enforces a per-request timeout of d by
// wrapping the incoming context with context.WithTimeout.
func WithTimeout(d time.Duration) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			tctx, cancel := context.WithTimeout(ctx, d)
			defer cancel()
			return next(tctx, model, messages, cfg)
		}
	}
}
// UsageTracker accumulates token usage statistics across calls.
//
// The zero value is ready to use. Add and Summary are safe for concurrent
// use; read the totals through Summary rather than the exported fields when
// other goroutines may still be recording usage.
type UsageTracker struct {
	mu            sync.Mutex
	TotalInput    int64
	TotalOutput   int64
	TotalRequests int64
}

// Add records usage from a single request. A nil Usage is ignored.
func (t *UsageTracker) Add(u *Usage) {
	if u == nil {
		return
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	t.TotalRequests++
	t.TotalInput += int64(u.InputTokens)
	t.TotalOutput += int64(u.OutputTokens)
}

// Summary returns the accumulated totals.
func (t *UsageTracker) Summary() (input, output, requests int64) {
	t.mu.Lock()
	defer t.mu.Unlock()
	input, output, requests = t.TotalInput, t.TotalOutput, t.TotalRequests
	return
}
// WithUsageTracking returns middleware that records the token usage of every
// successful response into tracker. Failed requests are not counted.
func WithUsageTracking(tracker *UsageTracker) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			resp, err := next(ctx, model, messages, cfg)
			if err != nil {
				return resp, err
			}
			tracker.Add(resp.Usage)
			return resp, nil
		}
	}
}