Add go-llm v2: redesigned API for simpler LLM abstraction
v2 is a new Go module (v2/) with a dramatically simpler API:

- Unified Message type (no more Input marker interface)
- Define[T] for ergonomic tool creation with standard context.Context
- Chat session with automatic tool-call loop (agent loop)
- Streaming via pull-based StreamReader
- MCP one-call connect (MCPStdioServer, MCPHTTPServer, MCPSSEServer)
- Middleware support (logging, retry, timeout, usage tracking)
- Decoupled JSON Schema (map[string]any, no provider coupling)
- Sample tools: WebSearch, Browser, Exec, ReadFile, WriteFile, HTTP
- Providers: OpenAI, Anthropic, Google (all with streaming)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
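For orientation, here is a minimal sketch of how the middleware layer listed above composes, based only on the Middleware and CompletionFunc types introduced in v2/middleware.go below; the chain helper itself is illustrative and not part of this commit:

// chain is an illustrative helper (not part of this commit). It folds a list
// of Middleware around a base CompletionFunc so that the first middleware
// passed ends up as the outermost wrapper.
func chain(base CompletionFunc, mws ...Middleware) CompletionFunc {
	for i := len(mws) - 1; i >= 0; i-- {
		base = mws[i](base)
	}
	return base
}

A concrete composition of the shipped middleware appears after the diff below.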
v2/middleware.go (new file, 117 lines added)
@@ -0,0 +1,117 @@
package llm

import (
	"context"
	"fmt"
	"log/slog"
	"sync"
	"time"
)

// CompletionFunc is the signature for the completion call chain.
type CompletionFunc func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error)

// Middleware wraps a completion call. It receives the next handler in the chain
// and returns a new handler that can inspect/modify the request and response.
type Middleware func(next CompletionFunc) CompletionFunc

// WithLogging returns middleware that logs requests and responses via slog.
func WithLogging(logger *slog.Logger) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			logger.Info("llm request",
				"model", model,
				"message_count", len(messages),
			)
			start := time.Now()
			resp, err := next(ctx, model, messages, cfg)
			elapsed := time.Since(start)
			if err != nil {
				logger.Error("llm error", "model", model, "elapsed", elapsed, "error", err)
			} else {
				logger.Info("llm response",
					"model", model,
					"elapsed", elapsed,
					"text_len", len(resp.Text),
					"tool_calls", len(resp.ToolCalls),
				)
			}
			return resp, err
		}
	}
}

// WithRetry returns middleware that retries failed requests with configurable backoff.
func WithRetry(maxRetries int, backoff func(attempt int) time.Duration) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			var lastErr error
			for attempt := 0; attempt <= maxRetries; attempt++ {
				if attempt > 0 {
					select {
					case <-ctx.Done():
						return Response{}, ctx.Err()
					case <-time.After(backoff(attempt)):
					}
				}
				resp, err := next(ctx, model, messages, cfg)
				if err == nil {
					return resp, nil
				}
				lastErr = err
			}
			return Response{}, fmt.Errorf("after %d retries: %w", maxRetries, lastErr)
		}
	}
}

// WithTimeout returns middleware that enforces a per-request timeout.
func WithTimeout(d time.Duration) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			ctx, cancel := context.WithTimeout(ctx, d)
			defer cancel()
			return next(ctx, model, messages, cfg)
		}
	}
}

// UsageTracker accumulates token usage statistics across calls.
type UsageTracker struct {
	mu            sync.Mutex
	TotalInput    int64
	TotalOutput   int64
	TotalRequests int64
}

// Add records usage from a single request.
func (ut *UsageTracker) Add(u *Usage) {
	if u == nil {
		return
	}
	ut.mu.Lock()
	defer ut.mu.Unlock()
	ut.TotalInput += int64(u.InputTokens)
	ut.TotalOutput += int64(u.OutputTokens)
	ut.TotalRequests++
}

// Summary returns the accumulated totals.
func (ut *UsageTracker) Summary() (input, output, requests int64) {
	ut.mu.Lock()
	defer ut.mu.Unlock()
	return ut.TotalInput, ut.TotalOutput, ut.TotalRequests
}

// WithUsageTracking returns middleware that accumulates token usage across calls.
func WithUsageTracking(tracker *UsageTracker) Middleware {
	return func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			resp, err := next(ctx, model, messages, cfg)
			if err == nil {
				tracker.Add(resp.Usage)
			}
			return resp, err
		}
	}
}
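As a usage note, the sketch below shows one way to exercise the middleware defined above from inside package llm (for example in a test file), since CompletionFunc takes the unexported *requestConfig. The stub base call, the backoff function, and the Usage field values are illustrative assumptions; in practice a provider supplies the base CompletionFunc.

package llm

import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"time"
)

// exampleMiddlewareChain is a sketch, not part of this commit. It wraps a stub
// completion call with the middleware from middleware.go and reads back the
// accumulated usage. The stub's Usage numbers are made up for illustration.
func exampleMiddlewareChain() error {
	tracker := &UsageTracker{}

	// Stub standing in for a real provider's completion call.
	var base CompletionFunc = func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
		return Response{Usage: &Usage{InputTokens: 3, OutputTokens: 1}}, nil
	}

	// Exponential backoff for WithRetry: 100ms, 200ms, 400ms.
	backoff := func(attempt int) time.Duration {
		return time.Duration(1<<(attempt-1)) * 100 * time.Millisecond
	}

	// Wrap innermost-first: usage tracking sits closest to the call, with
	// timeout, retry, and logging layered outside it.
	call := WithLogging(slog.New(slog.NewTextHandler(os.Stderr, nil)))(
		WithRetry(3, backoff)(
			WithTimeout(30*time.Second)(
				WithUsageTracking(tracker)(base),
			),
		),
	)

	if _, err := call(context.Background(), "example-model", nil, nil); err != nil {
		return err
	}

	in, out, n := tracker.Summary()
	fmt.Printf("usage: input=%d output=%d requests=%d\n", in, out, n)
	return nil
}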