feat: comprehensive token usage tracking for V2
All checks were successful
CI / Lint (pull_request) Successful in 10m18s
CI / Root Module (pull_request) Successful in 11m4s
CI / V2 Module (pull_request) Successful in 11m5s

Add provider-specific usage details, fix streaming usage, and return
usage from all high-level APIs (Chat.Send, Generate[T], Agent.Run).

Breaking changes:
- Chat.Send/SendMessage/SendWithImages now return (string, *Usage, error)
- Generate[T]/GenerateWith[T] now return (T, *Usage, error)
- Agent.Run/RunMessages now return (string, *Usage, error)

New features:
- Usage.Details map for provider-specific token breakdowns
  (reasoning, cached, audio, thoughts tokens)
- OpenAI streaming now captures usage via StreamOptions.IncludeUsage
- Google streaming now captures UsageMetadata from final chunk
- UsageTracker.Details() for accumulated detail totals
- ModelPricing and PricingRegistry for cost computation

Closes #2

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-02 04:33:18 +00:00
parent 7e1705c385
commit 5b687839b2
17 changed files with 684 additions and 61 deletions

View File

@@ -18,7 +18,7 @@
// coder.AsTool("code", "Write and run code"), // coder.AsTool("code", "Write and run code"),
// )), // )),
// ) // )
// result, err := orchestrator.Run(ctx, "Build a fibonacci function in Go") // result, _, err := orchestrator.Run(ctx, "Build a fibonacci function in Go")
package agent package agent
import ( import (
@@ -64,13 +64,15 @@ func New(model *llm.Model, system string, opts ...Option) *Agent {
// Run executes the agent with a user prompt. Each call is a fresh conversation. // Run executes the agent with a user prompt. Each call is a fresh conversation.
// The agent loops tool calls automatically until it produces a text response. // The agent loops tool calls automatically until it produces a text response.
func (a *Agent) Run(ctx context.Context, prompt string) (string, error) { // Returns the text response, accumulated token usage, and any error.
func (a *Agent) Run(ctx context.Context, prompt string) (string, *llm.Usage, error) {
return a.RunMessages(ctx, []llm.Message{llm.UserMessage(prompt)}) return a.RunMessages(ctx, []llm.Message{llm.UserMessage(prompt)})
} }
// RunMessages executes the agent with full message control. // RunMessages executes the agent with full message control.
// Each call is a fresh conversation. The agent loops tool calls automatically. // Each call is a fresh conversation. The agent loops tool calls automatically.
func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, error) { // Returns the text response, accumulated token usage, and any error.
func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, *llm.Usage, error) {
chat := llm.NewChat(a.model, a.reqOpts...) chat := llm.NewChat(a.model, a.reqOpts...)
if a.system != "" { if a.system != "" {
chat.SetSystem(a.system) chat.SetSystem(a.system)
@@ -107,7 +109,8 @@ type delegateParams struct {
func (a *Agent) AsTool(name, description string) llm.Tool { func (a *Agent) AsTool(name, description string) llm.Tool {
return llm.Define[delegateParams](name, description, return llm.Define[delegateParams](name, description,
func(ctx context.Context, p delegateParams) (string, error) { func(ctx context.Context, p delegateParams) (string, error) {
return a.Run(ctx, p.Input) text, _, err := a.Run(ctx, p.Input)
return text, err
}, },
) )
} }

View File

@@ -29,15 +29,6 @@ func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events
return nil return nil
} }
func (m *mockProvider) lastRequest() provider.Request {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.requests) == 0 {
return provider.Request{}
}
return m.requests[len(m.requests)-1]
}
func newMockModel(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *llm.Model { func newMockModel(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *llm.Model {
mp := &mockProvider{completeFunc: fn} mp := &mockProvider{completeFunc: fn}
return llm.NewClient(mp).Model("mock-model") return llm.NewClient(mp).Model("mock-model")
@@ -53,7 +44,7 @@ func TestAgent_Run(t *testing.T) {
model := newSimpleMockModel("Hello from agent!") model := newSimpleMockModel("Hello from agent!")
a := New(model, "You are a helpful assistant.") a := New(model, "You are a helpful assistant.")
result, err := a.Run(context.Background(), "Say hello") result, _, err := a.Run(context.Background(), "Say hello")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -83,7 +74,7 @@ func TestAgent_Run_WithTools(t *testing.T) {
}) })
a := New(model, "You are helpful.", WithTools(llm.NewToolBox(tool))) a := New(model, "You are helpful.", WithTools(llm.NewToolBox(tool)))
result, err := a.Run(context.Background(), "Use the greet tool") result, _, err := a.Run(context.Background(), "Use the greet tool")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -147,7 +138,7 @@ func TestAgent_AsTool_ParentChild(t *testing.T) {
)), )),
) )
result, err := parent.Run(context.Background(), "Tell me about Go generics") result, _, err := parent.Run(context.Background(), "Tell me about Go generics")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -169,7 +160,7 @@ func TestAgent_RunMessages(t *testing.T) {
llm.UserMessage("Follow up"), llm.UserMessage("Follow up"),
} }
result, err := a.RunMessages(context.Background(), messages) result, _, err := a.RunMessages(context.Background(), messages)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -187,7 +178,7 @@ func TestAgent_ContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately cancel() // Cancel immediately
_, err := a.Run(ctx, "This should fail") _, _, err := a.Run(ctx, "This should fail")
if err == nil { if err == nil {
t.Fatal("expected error from cancelled context") t.Fatal("expected error from cancelled context")
} }
@@ -204,7 +195,7 @@ func TestAgent_WithRequestOptions(t *testing.T) {
WithRequestOptions(llm.WithTemperature(0.3), llm.WithMaxTokens(100)), WithRequestOptions(llm.WithTemperature(0.3), llm.WithMaxTokens(100)),
) )
_, err := a.Run(context.Background(), "test") _, _, err := a.Run(context.Background(), "test")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -224,7 +215,7 @@ func TestAgent_Run_Error(t *testing.T) {
}) })
a := New(model, "You are helpful.") a := New(model, "You are helpful.")
_, err := a.Run(context.Background(), "test") _, _, err := a.Run(context.Background(), "test")
if err == nil { if err == nil {
t.Fatal("expected error, got nil") t.Fatal("expected error, got nil")
} }
@@ -234,7 +225,7 @@ func TestAgent_EmptySystem(t *testing.T) {
model := newSimpleMockModel("no system prompt") model := newSimpleMockModel("no system prompt")
a := New(model, "") // Empty system prompt a := New(model, "") // Empty system prompt
result, err := a.Run(context.Background(), "test") result, _, err := a.Run(context.Background(), "test")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -242,3 +233,34 @@ func TestAgent_EmptySystem(t *testing.T) {
t.Errorf("unexpected result: %q", result) t.Errorf("unexpected result: %q", result)
} }
} }
func TestAgent_Run_ReturnsUsage(t *testing.T) {
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{
Text: "result",
Usage: &provider.Usage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
},
}, nil
})
a := New(model, "You are helpful.")
result, usage, err := a.Run(context.Background(), "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "result" {
t.Errorf("expected 'result', got %q", result)
}
if usage == nil {
t.Fatal("expected usage, got nil")
}
if usage.InputTokens != 100 {
t.Errorf("expected input 100, got %d", usage.InputTokens)
}
if usage.OutputTokens != 50 {
t.Errorf("expected output 50, got %d", usage.OutputTokens)
}
}

View File

@@ -25,7 +25,7 @@ func Example_researcher() {
agent.WithRequestOptions(llm.WithTemperature(0.3)), agent.WithRequestOptions(llm.WithTemperature(0.3)),
) )
result, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?") result, _, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?")
if err != nil { if err != nil {
fmt.Println("Error:", err) fmt.Println("Error:", err)
return return
@@ -50,7 +50,7 @@ func Example_coder() {
)), )),
) )
result, err := coder.Run(context.Background(), result, _, err := coder.Run(context.Background(),
"Create a Go program that prints the first 10 Fibonacci numbers. Save it and run it.") "Create a Go program that prints the first 10 Fibonacci numbers. Save it and run it.")
if err != nil { if err != nil {
fmt.Println("Error:", err) fmt.Println("Error:", err)
@@ -97,7 +97,7 @@ func Example_orchestrator() {
)), )),
) )
result, err := orchestrator.Run(context.Background(), result, _, err := orchestrator.Run(context.Background(),
"Research how to implement a binary search tree in Go, then create one with insert and search operations.") "Research how to implement a binary search tree in Go, then create one with insert and search operations.")
if err != nil { if err != nil {
fmt.Println("Error:", err) fmt.Println("Error:", err)

View File

@@ -270,6 +270,16 @@ func (p *Provider) convertResponse(resp anth.MessagesResponse) provider.Response
OutputTokens: resp.Usage.OutputTokens, OutputTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
} }
details := map[string]int{}
if resp.Usage.CacheCreationInputTokens > 0 {
details[provider.UsageDetailCacheCreationTokens] = resp.Usage.CacheCreationInputTokens
}
if resp.Usage.CacheReadInputTokens > 0 {
details[provider.UsageDetailCachedInputTokens] = resp.Usage.CacheReadInputTokens
}
if len(details) > 0 {
res.Usage.Details = details
}
return res return res
} }

View File

@@ -38,44 +38,50 @@ func (c *Chat) SetTools(tb *ToolBox) {
c.tools = tb c.tools = tb
} }
// Send sends a user message and returns the assistant's text response. // Send sends a user message and returns the assistant's text response along with
// accumulated token usage from all iterations of the tool-call loop.
// If the model calls tools, they are executed automatically and the loop // If the model calls tools, they are executed automatically and the loop
// continues until the model produces a text response (the "agent loop"). // continues until the model produces a text response (the "agent loop").
func (c *Chat) Send(ctx context.Context, text string) (string, error) { func (c *Chat) Send(ctx context.Context, text string) (string, *Usage, error) {
return c.SendMessage(ctx, UserMessage(text)) return c.SendMessage(ctx, UserMessage(text))
} }
// SendWithImages sends a user message with images attached. // SendWithImages sends a user message with images attached.
func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, error) { func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, *Usage, error) {
return c.SendMessage(ctx, UserMessageWithImages(text, images...)) return c.SendMessage(ctx, UserMessageWithImages(text, images...))
} }
// SendMessage sends an arbitrary message and returns the final text response. // SendMessage sends an arbitrary message and returns the final text response along with
// accumulated token usage from all iterations of the tool-call loop.
// Handles the full tool-call loop automatically. // Handles the full tool-call loop automatically.
func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, error) { func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, *Usage, error) {
c.messages = append(c.messages, msg) c.messages = append(c.messages, msg)
opts := c.buildOpts() opts := c.buildOpts()
var totalUsage *Usage
for { for {
resp, err := c.model.Complete(ctx, c.messages, opts...) resp, err := c.model.Complete(ctx, c.messages, opts...)
if err != nil { if err != nil {
return "", fmt.Errorf("completion failed: %w", err) return "", totalUsage, fmt.Errorf("completion failed: %w", err)
} }
totalUsage = addUsage(totalUsage, resp.Usage)
c.messages = append(c.messages, resp.Message()) c.messages = append(c.messages, resp.Message())
if !resp.HasToolCalls() { if !resp.HasToolCalls() {
return resp.Text, nil return resp.Text, totalUsage, nil
} }
if c.tools == nil { if c.tools == nil {
return "", ErrNoToolsConfigured return "", totalUsage, ErrNoToolsConfigured
} }
toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls) toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls)
if err != nil { if err != nil {
return "", fmt.Errorf("tool execution failed: %w", err) return "", totalUsage, fmt.Errorf("tool execution failed: %w", err)
} }
c.messages = append(c.messages, toolResults...) c.messages = append(c.messages, toolResults...)

View File

@@ -14,7 +14,7 @@ func TestChat_Send(t *testing.T) {
model := newMockModel(mp) model := newMockModel(mp)
chat := NewChat(model) chat := NewChat(model)
text, err := chat.Send(context.Background(), "Hi") text, _, err := chat.Send(context.Background(), "Hi")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -28,7 +28,7 @@ func TestChat_SendMessage(t *testing.T) {
model := newMockModel(mp) model := newMockModel(mp)
chat := NewChat(model) chat := NewChat(model)
_, err := chat.SendMessage(context.Background(), UserMessage("msg1")) _, _, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -79,7 +79,7 @@ func TestChat_SetSystem(t *testing.T) {
} }
// System message stays first even after adding other messages // System message stays first even after adding other messages
_, _ = chat.Send(context.Background(), "Hi") _, _, _ = chat.Send(context.Background(), "Hi")
chat.SetSystem("New system") chat.SetSystem("New system")
msgs = chat.Messages() msgs = chat.Messages()
if msgs[0].Role != RoleSystem { if msgs[0].Role != RoleSystem {
@@ -113,7 +113,7 @@ func TestChat_ToolCallLoop(t *testing.T) {
}) })
chat.SetTools(NewToolBox(tool)) chat.SetTools(NewToolBox(tool))
text, err := chat.Send(context.Background(), "test") text, _, err := chat.Send(context.Background(), "test")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -158,7 +158,7 @@ func TestChat_ToolCallLoop_NoTools(t *testing.T) {
model := newMockModel(mp) model := newMockModel(mp)
chat := NewChat(model) chat := NewChat(model)
_, err := chat.Send(context.Background(), "test") _, _, err := chat.Send(context.Background(), "test")
if !errors.Is(err, ErrNoToolsConfigured) { if !errors.Is(err, ErrNoToolsConfigured) {
t.Errorf("expected ErrNoToolsConfigured, got %v", err) t.Errorf("expected ErrNoToolsConfigured, got %v", err)
} }
@@ -248,7 +248,7 @@ func TestChat_Messages(t *testing.T) {
model := newMockModel(mp) model := newMockModel(mp)
chat := NewChat(model) chat := NewChat(model)
_, _ = chat.Send(context.Background(), "test") _, _, _ = chat.Send(context.Background(), "test")
msgs := chat.Messages() msgs := chat.Messages()
// Verify it's a copy — modifying returned slice shouldn't affect chat // Verify it's a copy — modifying returned slice shouldn't affect chat
@@ -265,7 +265,7 @@ func TestChat_Reset(t *testing.T) {
model := newMockModel(mp) model := newMockModel(mp)
chat := NewChat(model) chat := NewChat(model)
_, _ = chat.Send(context.Background(), "test") _, _, _ = chat.Send(context.Background(), "test")
if len(chat.Messages()) == 0 { if len(chat.Messages()) == 0 {
t.Fatal("expected messages before reset") t.Fatal("expected messages before reset")
} }
@@ -281,7 +281,7 @@ func TestChat_Fork(t *testing.T) {
model := newMockModel(mp) model := newMockModel(mp)
chat := NewChat(model) chat := NewChat(model)
_, _ = chat.Send(context.Background(), "msg1") _, _, _ = chat.Send(context.Background(), "msg1")
fork := chat.Fork() fork := chat.Fork()
@@ -291,14 +291,14 @@ func TestChat_Fork(t *testing.T) {
} }
// Adding to fork should not affect original // Adding to fork should not affect original
_, _ = fork.Send(context.Background(), "msg2") _, _, _ = fork.Send(context.Background(), "msg2")
if len(fork.Messages()) == len(chat.Messages()) { if len(fork.Messages()) == len(chat.Messages()) {
t.Error("fork messages should be independent of original") t.Error("fork messages should be independent of original")
} }
// Adding to original should not affect fork // Adding to original should not affect fork
originalLen := len(chat.Messages()) originalLen := len(chat.Messages())
_, _ = chat.Send(context.Background(), "msg3") _, _, _ = chat.Send(context.Background(), "msg3")
if len(chat.Messages()) == originalLen { if len(chat.Messages()) == originalLen {
t.Error("original should have more messages after send") t.Error("original should have more messages after send")
} }
@@ -310,7 +310,7 @@ func TestChat_SendWithImages(t *testing.T) {
chat := NewChat(model) chat := NewChat(model)
img := Image{URL: "https://example.com/image.png"} img := Image{URL: "https://example.com/image.png"}
text, err := chat.SendWithImages(context.Background(), "What's in this image?", img) text, _, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -355,7 +355,7 @@ func TestChat_MultipleToolCallRounds(t *testing.T) {
}) })
chat.SetTools(NewToolBox(tool)) chat.SetTools(NewToolBox(tool))
text, err := chat.Send(context.Background(), "count three times") text, _, err := chat.Send(context.Background(), "count three times")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -378,7 +378,7 @@ func TestChat_SendError(t *testing.T) {
model := newMockModel(mp) model := newMockModel(mp)
chat := NewChat(model) chat := NewChat(model)
_, err := chat.Send(context.Background(), "test") _, _, err := chat.Send(context.Background(), "test")
if err == nil { if err == nil {
t.Fatal("expected error, got nil") t.Fatal("expected error, got nil")
} }
@@ -392,7 +392,7 @@ func TestChat_WithRequestOptions(t *testing.T) {
model := newMockModel(mp) model := newMockModel(mp)
chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200)) chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))
_, err := chat.Send(context.Background(), "test") _, _, err := chat.Send(context.Background(), "test")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -405,3 +405,107 @@ func TestChat_WithRequestOptions(t *testing.T) {
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens) t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
} }
} }
func TestChat_Send_UsageAccumulation(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n == 1 {
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "greet", Arguments: "{}"},
},
Usage: &provider.Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
}, nil
}
return provider.Response{
Text: "done",
Usage: &provider.Usage{InputTokens: 20, OutputTokens: 8, TotalTokens: 28},
}, nil
})
model := newMockModel(mp)
chat := NewChat(model)
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
return "hello!", nil
})
chat.SetTools(NewToolBox(tool))
text, usage, err := chat.Send(context.Background(), "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if text != "done" {
t.Errorf("expected 'done', got %q", text)
}
if usage == nil {
t.Fatal("expected usage, got nil")
}
if usage.InputTokens != 30 {
t.Errorf("expected accumulated input 30, got %d", usage.InputTokens)
}
if usage.OutputTokens != 13 {
t.Errorf("expected accumulated output 13, got %d", usage.OutputTokens)
}
if usage.TotalTokens != 43 {
t.Errorf("expected accumulated total 43, got %d", usage.TotalTokens)
}
}
func TestChat_Send_UsageWithDetails(t *testing.T) {
var callCount int32
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
n := atomic.AddInt32(&callCount, 1)
if n == 1 {
return provider.Response{
ToolCalls: []provider.ToolCall{
{ID: "tc1", Name: "greet", Arguments: "{}"},
},
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
Details: map[string]int{
"cached_input_tokens": 3,
},
},
}, nil
}
return provider.Response{
Text: "done",
Usage: &provider.Usage{
InputTokens: 20,
OutputTokens: 8,
TotalTokens: 28,
Details: map[string]int{
"cached_input_tokens": 7,
"reasoning_tokens": 2,
},
},
}, nil
})
model := newMockModel(mp)
chat := NewChat(model)
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
return "hello!", nil
})
chat.SetTools(NewToolBox(tool))
_, usage, err := chat.Send(context.Background(), "test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if usage == nil {
t.Fatal("expected usage, got nil")
}
if usage.Details == nil {
t.Fatal("expected usage details, got nil")
}
if usage.Details["cached_input_tokens"] != 10 {
t.Errorf("expected cached_input_tokens=10, got %d", usage.Details["cached_input_tokens"])
}
if usage.Details["reasoning_tokens"] != 2 {
t.Errorf("expected reasoning_tokens=2, got %d", usage.Details["reasoning_tokens"])
}
}

View File

@@ -13,14 +13,16 @@ const structuredOutputToolName = "structured_output"
// Generate sends a single user prompt to the model and parses the response into T. // Generate sends a single user prompt to the model and parses the response into T.
// T must be a struct. The model is forced to return structured output matching T's schema // T must be a struct. The model is forced to return structured output matching T's schema
// by using a hidden tool call internally. // by using a hidden tool call internally.
func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, error) { // Returns the parsed value, token usage, and any error.
func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, *Usage, error) {
return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...) return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...)
} }
// GenerateWith sends the given messages to the model and parses the response into T. // GenerateWith sends the given messages to the model and parses the response into T.
// T must be a struct. The model is forced to return structured output matching T's schema // T must be a struct. The model is forced to return structured output matching T's schema
// by using a hidden tool call internally. // by using a hidden tool call internally.
func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, error) { // Returns the parsed value, token usage, and any error.
func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, *Usage, error) {
var zero T var zero T
s := schema.FromStruct(zero) s := schema.FromStruct(zero)
@@ -36,7 +38,7 @@ func GenerateWith[T any](ctx context.Context, model *Model, messages []Message,
resp, err := model.Complete(ctx, messages, opts...) resp, err := model.Complete(ctx, messages, opts...)
if err != nil { if err != nil {
return zero, err return zero, nil, err
} }
// Find the structured_output tool call in the response. // Find the structured_output tool call in the response.
@@ -44,11 +46,11 @@ func GenerateWith[T any](ctx context.Context, model *Model, messages []Message,
if tc.Name == structuredOutputToolName { if tc.Name == structuredOutputToolName {
var result T var result T
if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil { if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil {
return zero, fmt.Errorf("failed to parse structured output: %w", err) return zero, resp.Usage, fmt.Errorf("failed to parse structured output: %w", err)
} }
return result, nil return result, resp.Usage, nil
} }
} }
return zero, ErrNoStructuredOutput return zero, resp.Usage, ErrNoStructuredOutput
} }

View File

@@ -25,7 +25,7 @@ func TestGenerate(t *testing.T) {
}) })
model := newMockModel(mp) model := newMockModel(mp)
result, err := Generate[testPerson](context.Background(), model, "Tell me about Alice") result, _, err := Generate[testPerson](context.Background(), model, "Tell me about Alice")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -63,7 +63,7 @@ func TestGenerateWith(t *testing.T) {
UserMessage("Tell me about Bob"), UserMessage("Tell me about Bob"),
} }
result, err := GenerateWith[testPerson](context.Background(), model, messages) result, _, err := GenerateWith[testPerson](context.Background(), model, messages)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -90,7 +90,7 @@ func TestGenerate_NoToolCall(t *testing.T) {
}) })
model := newMockModel(mp) model := newMockModel(mp)
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone") _, _, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
if err == nil { if err == nil {
t.Fatal("expected error, got nil") t.Fatal("expected error, got nil")
} }
@@ -111,7 +111,7 @@ func TestGenerate_InvalidJSON(t *testing.T) {
}) })
model := newMockModel(mp) model := newMockModel(mp)
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone") _, _, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
if err == nil { if err == nil {
t.Fatal("expected error, got nil") t.Fatal("expected error, got nil")
} }
@@ -143,7 +143,7 @@ func TestGenerate_NestedStruct(t *testing.T) {
}) })
model := newMockModel(mp) model := newMockModel(mp)
result, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol") result, _, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -170,7 +170,7 @@ func TestGenerate_WithOptions(t *testing.T) {
}) })
model := newMockModel(mp) model := newMockModel(mp)
_, err := Generate[testPerson](context.Background(), model, "Tell me about Dave", _, _, err := Generate[testPerson](context.Background(), model, "Tell me about Dave",
WithTemperature(0.5), WithTemperature(0.5),
WithMaxTokens(200), WithMaxTokens(200),
) )
@@ -207,7 +207,7 @@ func TestGenerate_WithMiddleware(t *testing.T) {
}) })
model := newMockModel(mp).WithMiddleware(mw) model := newMockModel(mp).WithMiddleware(mw)
result, err := Generate[testPerson](context.Background(), model, "Tell me about Eve") result, _, err := Generate[testPerson](context.Background(), model, "Tell me about Eve")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -231,7 +231,7 @@ func TestGenerate_WrongToolName(t *testing.T) {
}) })
model := newMockModel(mp) model := newMockModel(mp)
_, err := Generate[testPerson](context.Background(), model, "Tell me about Frank") _, _, err := Generate[testPerson](context.Background(), model, "Tell me about Frank")
if err == nil { if err == nil {
t.Fatal("expected error, got nil") t.Fatal("expected error, got nil")
} }
@@ -239,3 +239,44 @@ func TestGenerate_WrongToolName(t *testing.T) {
t.Errorf("expected ErrNoStructuredOutput, got %v", err) t.Errorf("expected ErrNoStructuredOutput, got %v", err)
} }
} }
func TestGenerate_ReturnsUsage(t *testing.T) {
mp := newMockProvider(provider.Response{
ToolCalls: []provider.ToolCall{
{
ID: "call_1",
Name: "structured_output",
Arguments: `{"name":"Grace","age":22}`,
},
},
Usage: &provider.Usage{
InputTokens: 50,
OutputTokens: 20,
TotalTokens: 70,
Details: map[string]int{
"reasoning_tokens": 5,
},
},
})
model := newMockModel(mp)
result, usage, err := Generate[testPerson](context.Background(), model, "Tell me about Grace")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Name != "Grace" {
t.Errorf("expected name 'Grace', got %q", result.Name)
}
if usage == nil {
t.Fatal("expected usage, got nil")
}
if usage.InputTokens != 50 {
t.Errorf("expected input 50, got %d", usage.InputTokens)
}
if usage.OutputTokens != 20 {
t.Errorf("expected output 20, got %d", usage.OutputTokens)
}
if usage.Details["reasoning_tokens"] != 5 {
t.Errorf("expected reasoning_tokens=5, got %d", usage.Details["reasoning_tokens"])
}
}

View File

@@ -59,12 +59,32 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
var fullText strings.Builder var fullText strings.Builder
var toolCalls []provider.ToolCall var toolCalls []provider.ToolCall
var usage *provider.Usage
for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) { for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
if err != nil { if err != nil {
return fmt.Errorf("google stream error: %w", err) return fmt.Errorf("google stream error: %w", err)
} }
// Track usage from the last chunk (final chunk has cumulative counts)
if resp.UsageMetadata != nil {
usage = &provider.Usage{
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
}
details := map[string]int{}
if resp.UsageMetadata.CachedContentTokenCount > 0 {
details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount)
}
if resp.UsageMetadata.ThoughtsTokenCount > 0 {
details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount)
}
if len(details) > 0 {
usage.Details = details
}
}
for _, c := range resp.Candidates { for _, c := range resp.Candidates {
if c.Content == nil { if c.Content == nil {
continue continue
@@ -105,6 +125,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
Response: &provider.Response{ Response: &provider.Response{
Text: fullText.String(), Text: fullText.String(),
ToolCalls: toolCalls, ToolCalls: toolCalls,
Usage: usage,
}, },
} }
@@ -284,6 +305,16 @@ func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provide
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount), OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
TotalTokens: int(resp.UsageMetadata.TotalTokenCount), TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
} }
details := map[string]int{}
if resp.UsageMetadata.CachedContentTokenCount > 0 {
details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount)
}
if resp.UsageMetadata.ThoughtsTokenCount > 0 {
details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount)
}
if len(details) > 0 {
res.Usage.Details = details
}
} }
return res, nil return res, nil

View File

@@ -177,6 +177,7 @@ func convertProviderResponse(resp provider.Response) Response {
InputTokens: resp.Usage.InputTokens, InputTokens: resp.Usage.InputTokens,
OutputTokens: resp.Usage.OutputTokens, OutputTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.TotalTokens, TotalTokens: resp.Usage.TotalTokens,
Details: resp.Usage.Details,
} }
} }

View File

@@ -82,6 +82,7 @@ type UsageTracker struct {
TotalInput int64 TotalInput int64
TotalOutput int64 TotalOutput int64
TotalRequests int64 TotalRequests int64
TotalDetails map[string]int64
} }
// Add records usage from a single request. // Add records usage from a single request.
@@ -94,6 +95,14 @@ func (ut *UsageTracker) Add(u *Usage) {
ut.TotalInput += int64(u.InputTokens) ut.TotalInput += int64(u.InputTokens)
ut.TotalOutput += int64(u.OutputTokens) ut.TotalOutput += int64(u.OutputTokens)
ut.TotalRequests++ ut.TotalRequests++
if len(u.Details) > 0 {
if ut.TotalDetails == nil {
ut.TotalDetails = make(map[string]int64)
}
for k, v := range u.Details {
ut.TotalDetails[k] += int64(v)
}
}
} }
// Summary returns the accumulated totals. // Summary returns the accumulated totals.
@@ -103,6 +112,20 @@ func (ut *UsageTracker) Summary() (input, output, requests int64) {
return ut.TotalInput, ut.TotalOutput, ut.TotalRequests return ut.TotalInput, ut.TotalOutput, ut.TotalRequests
} }
// Details returns a copy of the accumulated detail totals, or nil when no
// detail tokens have been recorded. The returned map is safe for the caller
// to mutate without affecting the tracker's internal state.
func (ut *UsageTracker) Details() map[string]int64 {
	ut.mu.Lock()
	defer ut.mu.Unlock()
	if ut.TotalDetails == nil {
		return nil
	}
	out := make(map[string]int64, len(ut.TotalDetails))
	for key, total := range ut.TotalDetails {
		out[key] = total
	}
	return out
}
// WithUsageTracking returns middleware that accumulates token usage across calls. // WithUsageTracking returns middleware that accumulates token usage across calls.
func WithUsageTracking(tracker *UsageTracker) Middleware { func WithUsageTracking(tracker *UsageTracker) Middleware {
return func(next CompletionFunc) CompletionFunc { return func(next CompletionFunc) CompletionFunc {

View File

@@ -280,3 +280,80 @@ func TestWithLogging_Error(t *testing.T) {
t.Errorf("expected provider error, got %v", err) t.Errorf("expected provider error, got %v", err)
} }
} }
// TestUsageTracker_Details verifies that detail tokens accumulate across
// multiple Add calls and that Details returns an independent copy.
// Uses the exported UsageDetail* key constants for consistency with the
// pricing tests rather than raw string literals.
func TestUsageTracker_Details(t *testing.T) {
	var tracker UsageTracker

	first := &Usage{
		InputTokens:  100,
		OutputTokens: 50,
		TotalTokens:  150,
		Details: map[string]int{
			UsageDetailCachedInputTokens: 20,
			UsageDetailReasoningTokens:   10,
		},
	}
	second := &Usage{
		InputTokens:  80,
		OutputTokens: 40,
		TotalTokens:  120,
		Details: map[string]int{
			UsageDetailCachedInputTokens: 15,
		},
	}
	tracker.Add(first)
	tracker.Add(second)

	details := tracker.Details()
	if details == nil {
		t.Fatal("expected details, got nil")
	}
	if details[UsageDetailCachedInputTokens] != 35 {
		t.Errorf("expected cached_input_tokens=35, got %d", details[UsageDetailCachedInputTokens])
	}
	if details[UsageDetailReasoningTokens] != 10 {
		t.Errorf("expected reasoning_tokens=10, got %d", details[UsageDetailReasoningTokens])
	}

	// Mutating the returned map must not leak back into the tracker.
	details[UsageDetailCachedInputTokens] = 999
	fresh := tracker.Details()
	if fresh[UsageDetailCachedInputTokens] != 35 {
		t.Error("Details() did not return a copy")
	}
}
// TestUsageTracker_Details_Nil checks that Details reports nil when every
// recorded Usage carried no provider-specific detail breakdown.
func TestUsageTracker_Details_Nil(t *testing.T) {
	var tracker UsageTracker
	tracker.Add(&Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15})
	if details := tracker.Details(); details != nil {
		t.Errorf("expected nil details for usage without details, got %v", details)
	}
}
func TestWithUsageTracking_WithDetails(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "ok",
Usage: &provider.Usage{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
Details: map[string]int{
"cached_input_tokens": 30,
},
},
})
tracker := &UsageTracker{}
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
details := tracker.Details()
if details["cached_input_tokens"] != 30 {
t.Errorf("expected cached_input_tokens=30, got %d", details["cached_input_tokens"])
}
}

View File

@@ -58,15 +58,30 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
cl := openai.NewClient(opts...) cl := openai.NewClient(opts...)
oaiReq := p.buildRequest(req) oaiReq := p.buildRequest(req)
oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
IncludeUsage: openai.Bool(true),
}
stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq) stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)
var fullText strings.Builder var fullText strings.Builder
var toolCalls []provider.ToolCall var toolCalls []provider.ToolCall
toolCallArgs := map[int]*strings.Builder{} toolCallArgs := map[int]*strings.Builder{}
var usage *provider.Usage
for stream.Next() { for stream.Next() {
chunk := stream.Current() chunk := stream.Current()
// Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true)
if chunk.Usage.TotalTokens > 0 {
usage = &provider.Usage{
InputTokens: int(chunk.Usage.PromptTokens),
OutputTokens: int(chunk.Usage.CompletionTokens),
TotalTokens: int(chunk.Usage.TotalTokens),
Details: extractUsageDetails(chunk.Usage),
}
}
for _, choice := range chunk.Choices { for _, choice := range chunk.Choices {
// Text delta // Text delta
if choice.Delta.Content != "" { if choice.Delta.Content != "" {
@@ -138,6 +153,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
Response: &provider.Response{ Response: &provider.Response{
Text: fullText.String(), Text: fullText.String(),
ToolCalls: toolCalls, ToolCalls: toolCalls,
Usage: usage,
}, },
} }
@@ -363,6 +379,7 @@ func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Respons
OutputTokens: int(resp.Usage.CompletionTokens), OutputTokens: int(resp.Usage.CompletionTokens),
TotalTokens: int(resp.Usage.TotalTokens), TotalTokens: int(resp.Usage.TotalTokens),
} }
res.Usage.Details = extractUsageDetails(resp.Usage)
} }
return res return res
@@ -381,6 +398,27 @@ func audioFormat(contentType string) string {
} }
} }
// extractUsageDetails extracts provider-specific detail tokens from an OpenAI
// CompletionUsage, mapping them onto the standardized provider.UsageDetail*
// keys. Returns nil when no detail counts are present so callers can store
// the result directly in Usage.Details.
func extractUsageDetails(usage openai.CompletionUsage) map[string]int {
	var details map[string]int
	record := func(key string, count int64) {
		if count <= 0 {
			return
		}
		if details == nil {
			details = make(map[string]int)
		}
		details[key] = int(count)
	}
	record(provider.UsageDetailReasoningTokens, usage.CompletionTokensDetails.ReasoningTokens)
	record(provider.UsageDetailAudioOutputTokens, usage.CompletionTokensDetails.AudioTokens)
	record(provider.UsageDetailCachedInputTokens, usage.PromptTokensDetails.CachedTokens)
	record(provider.UsageDetailAudioInputTokens, usage.PromptTokensDetails.AudioTokens)
	return details
}
// audioFormatFromURL guesses the audio format from a URL's file extension. // audioFormatFromURL guesses the audio format from a URL's file extension.
func audioFormatFromURL(u string) string { func audioFormatFromURL(u string) string {
ext := strings.ToLower(path.Ext(u)) ext := strings.ToLower(path.Ext(u))

83
v2/pricing.go Normal file
View File

@@ -0,0 +1,83 @@
package llm
import "sync"
// ModelPricing defines per-token pricing for a model. All prices are in USD
// per single token (e.g. $3 per million tokens is 0.000003).
type ModelPricing struct {
	InputPricePerToken  float64 // USD per input token
	OutputPricePerToken float64 // USD per output token
	// USD per cached input token; 0 means cached input is billed at
	// InputPricePerToken (see Cost).
	CachedInputPricePerToken float64
}
// Cost computes the total USD cost from a Usage.
// When CachedInputPricePerToken is set and the usage includes cached_input_tokens,
// those tokens are charged at the cached rate instead of the regular input rate.
// A nil usage costs nothing.
func (mp ModelPricing) Cost(u *Usage) float64 {
	if u == nil {
		return 0
	}
	cached := 0
	if u.Details != nil {
		cached = u.Details[UsageDetailCachedInputTokens]
	}
	outputCost := float64(u.OutputTokens) * mp.OutputPricePerToken

	// Without a cached rate (or without any cached tokens), every input
	// token is billed at the regular rate.
	if mp.CachedInputPricePerToken <= 0 || cached <= 0 {
		return float64(u.InputTokens)*mp.InputPricePerToken + outputCost
	}

	// Cached tokens are a subset of InputTokens; clamp in case a provider
	// reports more cached tokens than total input.
	regular := u.InputTokens - cached
	if regular < 0 {
		regular = 0
	}
	return float64(regular)*mp.InputPricePerToken +
		float64(cached)*mp.CachedInputPricePerToken +
		outputCost
}
// PricingRegistry maps model names to their pricing.
// Callers populate it with the models and prices relevant to their use case.
// All methods are safe for concurrent use.
type PricingRegistry struct {
	mu     sync.RWMutex // guards models
	models map[string]ModelPricing
}
// NewPricingRegistry creates an empty pricing registry, ready for Set calls.
func NewPricingRegistry() *PricingRegistry {
	pr := new(PricingRegistry)
	pr.models = make(map[string]ModelPricing)
	return pr
}
// Set registers pricing for a model, replacing any existing entry.
func (pr *PricingRegistry) Set(model string, pricing ModelPricing) {
	pr.mu.Lock()
	pr.models[model] = pricing
	pr.mu.Unlock()
}
// Has returns true if pricing is registered for the given model.
func (pr *PricingRegistry) Has(model string) bool {
	pr.mu.RLock()
	defer pr.mu.RUnlock()
	_, found := pr.models[model]
	return found
}
// Cost computes the USD cost for the given model and usage.
// Returns 0 if the model is not registered.
func (pr *PricingRegistry) Cost(model string, u *Usage) float64 {
	pr.mu.RLock()
	defer pr.mu.RUnlock()
	if pricing, ok := pr.models[model]; ok {
		return pricing.Cost(u)
	}
	return 0
}

128
v2/pricing_test.go Normal file
View File

@@ -0,0 +1,128 @@
package llm
import (
"math"
"testing"
)
// TestModelPricing_Cost checks the plain input+output pricing path with no
// cached-token discount involved.
func TestModelPricing_Cost(t *testing.T) {
	pricing := ModelPricing{
		InputPricePerToken:  0.000003, // $3/MTok
		OutputPricePerToken: 0.000015, // $15/MTok
	}
	got := pricing.Cost(&Usage{
		InputTokens:  1000,
		OutputTokens: 500,
		TotalTokens:  1500,
	})
	want := 1000*0.000003 + 500*0.000015
	if math.Abs(got-want) > 1e-10 {
		t.Errorf("expected cost %f, got %f", want, got)
	}
}
// TestModelPricing_Cost_WithCachedTokens checks that cached input tokens are
// billed at the discounted cached rate while the remainder uses the regular
// input rate.
func TestModelPricing_Cost_WithCachedTokens(t *testing.T) {
	pricing := ModelPricing{
		InputPricePerToken:       0.000003,  // $3/MTok
		OutputPricePerToken:      0.000015,  // $15/MTok
		CachedInputPricePerToken: 0.0000015, // $1.50/MTok (50% discount)
	}
	got := pricing.Cost(&Usage{
		InputTokens:  1000,
		OutputTokens: 500,
		TotalTokens:  1500,
		Details:      map[string]int{UsageDetailCachedInputTokens: 400},
	})
	// 600 regular input tokens + 400 cached tokens + 500 output tokens
	want := 600*0.000003 + 400*0.0000015 + 500*0.000015
	if math.Abs(got-want) > 1e-10 {
		t.Errorf("expected cost %f, got %f", want, got)
	}
}
// TestModelPricing_Cost_NilUsage checks that a nil Usage costs nothing.
func TestModelPricing_Cost_NilUsage(t *testing.T) {
	pricing := ModelPricing{
		InputPricePerToken:  0.000003,
		OutputPricePerToken: 0.000015,
	}
	if got := pricing.Cost(nil); got != 0 {
		t.Errorf("expected 0 for nil usage, got %f", got)
	}
}
// TestModelPricing_Cost_NoCachedPrice checks the fallback: with no cached
// rate configured, every input token is billed at the regular input rate
// even when the usage reports cached tokens.
func TestModelPricing_Cost_NoCachedPrice(t *testing.T) {
	// When CachedInputPricePerToken is 0, all input tokens use InputPricePerToken
	pricing := ModelPricing{
		InputPricePerToken:  0.000003,
		OutputPricePerToken: 0.000015,
	}
	got := pricing.Cost(&Usage{
		InputTokens:  1000,
		OutputTokens: 500,
		TotalTokens:  1500,
		Details:      map[string]int{UsageDetailCachedInputTokens: 400},
	})
	want := 1000*0.000003 + 500*0.000015
	if math.Abs(got-want) > 1e-10 {
		t.Errorf("expected cost %f, got %f", want, got)
	}
}
// TestPricingRegistry covers registration, membership lookup, cost
// computation, and the zero-cost behavior for unregistered models.
func TestPricingRegistry(t *testing.T) {
	registry := NewPricingRegistry()
	registry.Set("gpt-4o", ModelPricing{
		InputPricePerToken:  0.0000025,
		OutputPricePerToken: 0.00001,
	})

	if !registry.Has("gpt-4o") {
		t.Error("expected Has('gpt-4o') to be true")
	}
	if registry.Has("gpt-3.5-turbo") {
		t.Error("expected Has('gpt-3.5-turbo') to be false")
	}

	usage := &Usage{InputTokens: 1000, OutputTokens: 200, TotalTokens: 1200}
	want := 1000*0.0000025 + 200*0.00001
	if got := registry.Cost("gpt-4o", usage); math.Abs(got-want) > 1e-10 {
		t.Errorf("expected cost %f, got %f", want, got)
	}

	// Unknown model returns 0
	if got := registry.Cost("unknown-model", usage); got != 0 {
		t.Errorf("expected 0 for unknown model, got %f", got)
	}
}
// TestPricingRegistry_Override checks that a second Set for the same model
// replaces the earlier pricing.
func TestPricingRegistry_Override(t *testing.T) {
	registry := NewPricingRegistry()
	registry.Set("model-a", ModelPricing{InputPricePerToken: 0.001, OutputPricePerToken: 0.002})
	registry.Set("model-a", ModelPricing{InputPricePerToken: 0.003, OutputPricePerToken: 0.004})

	usage := &Usage{InputTokens: 100, OutputTokens: 50, TotalTokens: 150}
	want := 100*0.003 + 50*0.004
	if got := registry.Cost("model-a", usage); math.Abs(got-want) > 1e-10 {
		t.Errorf("expected overridden cost %f, got %f", want, got)
	}
}

View File

@@ -64,8 +64,19 @@ type Usage struct {
InputTokens int InputTokens int
OutputTokens int OutputTokens int
TotalTokens int TotalTokens int
Details map[string]int // provider-specific breakdown (e.g., cached, reasoning tokens)
} }
// Standardized detail keys for provider-specific token breakdowns stored in
// Usage.Details. Providers map their native usage fields onto these keys so
// callers can read breakdowns without provider-specific code.
const (
	UsageDetailReasoningTokens   = "reasoning_tokens"    // reasoning output tokens (set by the OpenAI provider)
	UsageDetailCachedInputTokens = "cached_input_tokens" // prompt tokens served from a provider-side cache (OpenAI, Google)
	// Tokens written to a provider cache — presumably for Anthropic prompt
	// caching; no provider in view sets it yet. TODO confirm.
	UsageDetailCacheCreationTokens = "cache_creation_tokens"
	UsageDetailAudioInputTokens    = "audio_input_tokens"  // audio tokens in the prompt (OpenAI)
	UsageDetailAudioOutputTokens   = "audio_output_tokens" // audio tokens in the completion (OpenAI)
	UsageDetailThoughtsTokens      = "thoughts_tokens"     // Google thinking-token count (ThoughtsTokenCount)
)
// StreamEventType identifies the kind of stream event. // StreamEventType identifies the kind of stream event.
type StreamEventType int type StreamEventType int

View File

@@ -1,5 +1,7 @@
package llm package llm
import "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
// Response represents the result of a completion request. // Response represents the result of a completion request.
type Response struct { type Response struct {
// Text is the assistant's text content. Empty if only tool calls. // Text is the assistant's text content. Empty if only tool calls.
@@ -31,4 +33,45 @@ type Usage struct {
InputTokens int InputTokens int
OutputTokens int OutputTokens int
TotalTokens int TotalTokens int
Details map[string]int // provider-specific breakdown (e.g., cached, reasoning tokens)
} }
// addUsage merges usage u into the receiver, accumulating token counts and details.
// If total is nil, a fresh Usage is allocated (so the result is a deep copy of u
// and later merges cannot mutate the caller's value). If u is nil, total is
// returned unchanged.
func addUsage(total *Usage, u *Usage) *Usage {
	if u == nil {
		return total
	}
	if total == nil {
		total = new(Usage)
	}
	total.InputTokens += u.InputTokens
	total.OutputTokens += u.OutputTokens
	total.TotalTokens += u.TotalTokens
	if u.Details != nil {
		if total.Details == nil {
			total.Details = make(map[string]int, len(u.Details))
		}
		for key, count := range u.Details {
			total.Details[key] += count
		}
	}
	return total
}
// Re-export detail key constants from provider package for convenience, so
// callers can index Usage.Details without importing the provider package.
// The values mirror provider's standardized keys one-to-one.
const (
	UsageDetailReasoningTokens     = provider.UsageDetailReasoningTokens
	UsageDetailCachedInputTokens   = provider.UsageDetailCachedInputTokens
	UsageDetailCacheCreationTokens = provider.UsageDetailCacheCreationTokens
	UsageDetailAudioInputTokens    = provider.UsageDetailAudioInputTokens
	UsageDetailAudioOutputTokens   = provider.UsageDetailAudioOutputTokens
	UsageDetailThoughtsTokens      = provider.UsageDetailThoughtsTokens
)