From 5b687839b2f89e71dc02efd5adacb597da52d848 Mon Sep 17 00:00:00 2001 From: Steve Dudenhoeffer Date: Mon, 2 Mar 2026 04:33:18 +0000 Subject: [PATCH] feat: comprehensive token usage tracking for V2 Add provider-specific usage details, fix streaming usage, and return usage from all high-level APIs (Chat.Send, Generate[T], Agent.Run). Breaking changes: - Chat.Send/SendMessage/SendWithImages now return (string, *Usage, error) - Generate[T]/GenerateWith[T] now return (T, *Usage, error) - Agent.Run/RunMessages now return (string, *Usage, error) New features: - Usage.Details map for provider-specific token breakdowns (reasoning, cached, audio, thoughts tokens) - OpenAI streaming now captures usage via StreamOptions.IncludeUsage - Google streaming now captures UsageMetadata from final chunk - UsageTracker.Details() for accumulated detail totals - ModelPricing and PricingRegistry for cost computation Closes #2 Co-Authored-By: Claude Opus 4.6 --- v2/agent/agent.go | 11 ++-- v2/agent/agent_test.go | 56 +++++++++++----- v2/agent/example_test.go | 6 +- v2/anthropic/anthropic.go | 10 +++ v2/chat.go | 24 ++++--- v2/chat_test.go | 132 ++++++++++++++++++++++++++++++++++---- v2/generate.go | 14 ++-- v2/generate_test.go | 57 +++++++++++++--- v2/google/google.go | 31 +++++++++ v2/llm.go | 1 + v2/middleware.go | 23 +++++++ v2/middleware_test.go | 77 ++++++++++++++++++++++ v2/openai/openai.go | 38 +++++++++++ v2/pricing.go | 83 ++++++++++++++++++++++++ v2/pricing_test.go | 128 ++++++++++++++++++++++++++++++++++++ v2/provider/provider.go | 11 ++++ v2/response.go | 43 +++++++++++++ 17 files changed, 684 insertions(+), 61 deletions(-) create mode 100644 v2/pricing.go create mode 100644 v2/pricing_test.go diff --git a/v2/agent/agent.go b/v2/agent/agent.go index a8ce1c5..5a9113f 100644 --- a/v2/agent/agent.go +++ b/v2/agent/agent.go @@ -18,7 +18,7 @@ // coder.AsTool("code", "Write and run code"), // )), // ) -// result, err := orchestrator.Run(ctx, "Build a fibonacci function in Go") +// 
result, _, err := orchestrator.Run(ctx, "Build a fibonacci function in Go") package agent import ( @@ -64,13 +64,15 @@ func New(model *llm.Model, system string, opts ...Option) *Agent { // Run executes the agent with a user prompt. Each call is a fresh conversation. // The agent loops tool calls automatically until it produces a text response. -func (a *Agent) Run(ctx context.Context, prompt string) (string, error) { +// Returns the text response, accumulated token usage, and any error. +func (a *Agent) Run(ctx context.Context, prompt string) (string, *llm.Usage, error) { return a.RunMessages(ctx, []llm.Message{llm.UserMessage(prompt)}) } // RunMessages executes the agent with full message control. // Each call is a fresh conversation. The agent loops tool calls automatically. -func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, error) { +// Returns the text response, accumulated token usage, and any error. +func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, *llm.Usage, error) { chat := llm.NewChat(a.model, a.reqOpts...) 
if a.system != "" { chat.SetSystem(a.system) @@ -107,7 +109,8 @@ type delegateParams struct { func (a *Agent) AsTool(name, description string) llm.Tool { return llm.Define[delegateParams](name, description, func(ctx context.Context, p delegateParams) (string, error) { - return a.Run(ctx, p.Input) + text, _, err := a.Run(ctx, p.Input) + return text, err }, ) } diff --git a/v2/agent/agent_test.go b/v2/agent/agent_test.go index f201fff..b84206c 100644 --- a/v2/agent/agent_test.go +++ b/v2/agent/agent_test.go @@ -29,15 +29,6 @@ func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events return nil } -func (m *mockProvider) lastRequest() provider.Request { - m.mu.Lock() - defer m.mu.Unlock() - if len(m.requests) == 0 { - return provider.Request{} - } - return m.requests[len(m.requests)-1] -} - func newMockModel(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *llm.Model { mp := &mockProvider{completeFunc: fn} return llm.NewClient(mp).Model("mock-model") @@ -53,7 +44,7 @@ func TestAgent_Run(t *testing.T) { model := newSimpleMockModel("Hello from agent!") a := New(model, "You are a helpful assistant.") - result, err := a.Run(context.Background(), "Say hello") + result, _, err := a.Run(context.Background(), "Say hello") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -83,7 +74,7 @@ func TestAgent_Run_WithTools(t *testing.T) { }) a := New(model, "You are helpful.", WithTools(llm.NewToolBox(tool))) - result, err := a.Run(context.Background(), "Use the greet tool") + result, _, err := a.Run(context.Background(), "Use the greet tool") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -147,7 +138,7 @@ func TestAgent_AsTool_ParentChild(t *testing.T) { )), ) - result, err := parent.Run(context.Background(), "Tell me about Go generics") + result, _, err := parent.Run(context.Background(), "Tell me about Go generics") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -169,7 +160,7 @@ func 
TestAgent_RunMessages(t *testing.T) { llm.UserMessage("Follow up"), } - result, err := a.RunMessages(context.Background(), messages) + result, _, err := a.RunMessages(context.Background(), messages) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -187,7 +178,7 @@ func TestAgent_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately - _, err := a.Run(ctx, "This should fail") + _, _, err := a.Run(ctx, "This should fail") if err == nil { t.Fatal("expected error from cancelled context") } @@ -204,7 +195,7 @@ func TestAgent_WithRequestOptions(t *testing.T) { WithRequestOptions(llm.WithTemperature(0.3), llm.WithMaxTokens(100)), ) - _, err := a.Run(context.Background(), "test") + _, _, err := a.Run(context.Background(), "test") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -224,7 +215,7 @@ func TestAgent_Run_Error(t *testing.T) { }) a := New(model, "You are helpful.") - _, err := a.Run(context.Background(), "test") + _, _, err := a.Run(context.Background(), "test") if err == nil { t.Fatal("expected error, got nil") } @@ -234,7 +225,7 @@ func TestAgent_EmptySystem(t *testing.T) { model := newSimpleMockModel("no system prompt") a := New(model, "") // Empty system prompt - result, err := a.Run(context.Background(), "test") + result, _, err := a.Run(context.Background(), "test") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -242,3 +233,34 @@ func TestAgent_EmptySystem(t *testing.T) { t.Errorf("unexpected result: %q", result) } } + +func TestAgent_Run_ReturnsUsage(t *testing.T) { + model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) { + return provider.Response{ + Text: "result", + Usage: &provider.Usage{ + InputTokens: 100, + OutputTokens: 50, + TotalTokens: 150, + }, + }, nil + }) + + a := New(model, "You are helpful.") + result, usage, err := a.Run(context.Background(), "test") + if err != nil { + t.Fatalf("unexpected error: 
%v", err) + } + if result != "result" { + t.Errorf("expected 'result', got %q", result) + } + if usage == nil { + t.Fatal("expected usage, got nil") + } + if usage.InputTokens != 100 { + t.Errorf("expected input 100, got %d", usage.InputTokens) + } + if usage.OutputTokens != 50 { + t.Errorf("expected output 50, got %d", usage.OutputTokens) + } +} diff --git a/v2/agent/example_test.go b/v2/agent/example_test.go index a08a56f..04cd3f9 100644 --- a/v2/agent/example_test.go +++ b/v2/agent/example_test.go @@ -25,7 +25,7 @@ func Example_researcher() { agent.WithRequestOptions(llm.WithTemperature(0.3)), ) - result, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?") + result, _, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?") if err != nil { fmt.Println("Error:", err) return @@ -50,7 +50,7 @@ func Example_coder() { )), ) - result, err := coder.Run(context.Background(), + result, _, err := coder.Run(context.Background(), "Create a Go program that prints the first 10 Fibonacci numbers. 
Save it and run it.") if err != nil { fmt.Println("Error:", err) @@ -97,7 +97,7 @@ func Example_orchestrator() { )), ) - result, err := orchestrator.Run(context.Background(), + result, _, err := orchestrator.Run(context.Background(), "Research how to implement a binary search tree in Go, then create one with insert and search operations.") if err != nil { fmt.Println("Error:", err) diff --git a/v2/anthropic/anthropic.go b/v2/anthropic/anthropic.go index 201092b..0a477e3 100644 --- a/v2/anthropic/anthropic.go +++ b/v2/anthropic/anthropic.go @@ -270,6 +270,16 @@ func (p *Provider) convertResponse(resp anth.MessagesResponse) provider.Response OutputTokens: resp.Usage.OutputTokens, TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, } + details := map[string]int{} + if resp.Usage.CacheCreationInputTokens > 0 { + details[provider.UsageDetailCacheCreationTokens] = resp.Usage.CacheCreationInputTokens + } + if resp.Usage.CacheReadInputTokens > 0 { + details[provider.UsageDetailCachedInputTokens] = resp.Usage.CacheReadInputTokens + } + if len(details) > 0 { + res.Usage.Details = details + } return res } diff --git a/v2/chat.go b/v2/chat.go index 47d2f64..4b3a92e 100644 --- a/v2/chat.go +++ b/v2/chat.go @@ -38,44 +38,50 @@ func (c *Chat) SetTools(tb *ToolBox) { c.tools = tb } -// Send sends a user message and returns the assistant's text response. +// Send sends a user message and returns the assistant's text response along with +// accumulated token usage from all iterations of the tool-call loop. // If the model calls tools, they are executed automatically and the loop // continues until the model produces a text response (the "agent loop"). -func (c *Chat) Send(ctx context.Context, text string) (string, error) { +func (c *Chat) Send(ctx context.Context, text string) (string, *Usage, error) { return c.SendMessage(ctx, UserMessage(text)) } // SendWithImages sends a user message with images attached. 
-func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, error) { +func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, *Usage, error) { return c.SendMessage(ctx, UserMessageWithImages(text, images...)) } -// SendMessage sends an arbitrary message and returns the final text response. +// SendMessage sends an arbitrary message and returns the final text response along with +// accumulated token usage from all iterations of the tool-call loop. // Handles the full tool-call loop automatically. -func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, error) { +func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, *Usage, error) { c.messages = append(c.messages, msg) opts := c.buildOpts() + var totalUsage *Usage + for { resp, err := c.model.Complete(ctx, c.messages, opts...) if err != nil { - return "", fmt.Errorf("completion failed: %w", err) + return "", totalUsage, fmt.Errorf("completion failed: %w", err) } + totalUsage = addUsage(totalUsage, resp.Usage) + c.messages = append(c.messages, resp.Message()) if !resp.HasToolCalls() { - return resp.Text, nil + return resp.Text, totalUsage, nil } if c.tools == nil { - return "", ErrNoToolsConfigured + return "", totalUsage, ErrNoToolsConfigured } toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls) if err != nil { - return "", fmt.Errorf("tool execution failed: %w", err) + return "", totalUsage, fmt.Errorf("tool execution failed: %w", err) } c.messages = append(c.messages, toolResults...) 
diff --git a/v2/chat_test.go b/v2/chat_test.go index be1864f..a201890 100644 --- a/v2/chat_test.go +++ b/v2/chat_test.go @@ -14,7 +14,7 @@ func TestChat_Send(t *testing.T) { model := newMockModel(mp) chat := NewChat(model) - text, err := chat.Send(context.Background(), "Hi") + text, _, err := chat.Send(context.Background(), "Hi") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -28,7 +28,7 @@ func TestChat_SendMessage(t *testing.T) { model := newMockModel(mp) chat := NewChat(model) - _, err := chat.SendMessage(context.Background(), UserMessage("msg1")) + _, _, err := chat.SendMessage(context.Background(), UserMessage("msg1")) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -79,7 +79,7 @@ func TestChat_SetSystem(t *testing.T) { } // System message stays first even after adding other messages - _, _ = chat.Send(context.Background(), "Hi") + _, _, _ = chat.Send(context.Background(), "Hi") chat.SetSystem("New system") msgs = chat.Messages() if msgs[0].Role != RoleSystem { @@ -113,7 +113,7 @@ func TestChat_ToolCallLoop(t *testing.T) { }) chat.SetTools(NewToolBox(tool)) - text, err := chat.Send(context.Background(), "test") + text, _, err := chat.Send(context.Background(), "test") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -158,7 +158,7 @@ func TestChat_ToolCallLoop_NoTools(t *testing.T) { model := newMockModel(mp) chat := NewChat(model) - _, err := chat.Send(context.Background(), "test") + _, _, err := chat.Send(context.Background(), "test") if !errors.Is(err, ErrNoToolsConfigured) { t.Errorf("expected ErrNoToolsConfigured, got %v", err) } @@ -248,7 +248,7 @@ func TestChat_Messages(t *testing.T) { model := newMockModel(mp) chat := NewChat(model) - _, _ = chat.Send(context.Background(), "test") + _, _, _ = chat.Send(context.Background(), "test") msgs := chat.Messages() // Verify it's a copy — modifying returned slice shouldn't affect chat @@ -265,7 +265,7 @@ func TestChat_Reset(t *testing.T) { model := newMockModel(mp) chat := 
NewChat(model) - _, _ = chat.Send(context.Background(), "test") + _, _, _ = chat.Send(context.Background(), "test") if len(chat.Messages()) == 0 { t.Fatal("expected messages before reset") } @@ -281,7 +281,7 @@ func TestChat_Fork(t *testing.T) { model := newMockModel(mp) chat := NewChat(model) - _, _ = chat.Send(context.Background(), "msg1") + _, _, _ = chat.Send(context.Background(), "msg1") fork := chat.Fork() @@ -291,14 +291,14 @@ func TestChat_Fork(t *testing.T) { } // Adding to fork should not affect original - _, _ = fork.Send(context.Background(), "msg2") + _, _, _ = fork.Send(context.Background(), "msg2") if len(fork.Messages()) == len(chat.Messages()) { t.Error("fork messages should be independent of original") } // Adding to original should not affect fork originalLen := len(chat.Messages()) - _, _ = chat.Send(context.Background(), "msg3") + _, _, _ = chat.Send(context.Background(), "msg3") if len(chat.Messages()) == originalLen { t.Error("original should have more messages after send") } @@ -310,7 +310,7 @@ func TestChat_SendWithImages(t *testing.T) { chat := NewChat(model) img := Image{URL: "https://example.com/image.png"} - text, err := chat.SendWithImages(context.Background(), "What's in this image?", img) + text, _, err := chat.SendWithImages(context.Background(), "What's in this image?", img) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -355,7 +355,7 @@ func TestChat_MultipleToolCallRounds(t *testing.T) { }) chat.SetTools(NewToolBox(tool)) - text, err := chat.Send(context.Background(), "count three times") + text, _, err := chat.Send(context.Background(), "count three times") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -378,7 +378,7 @@ func TestChat_SendError(t *testing.T) { model := newMockModel(mp) chat := NewChat(model) - _, err := chat.Send(context.Background(), "test") + _, _, err := chat.Send(context.Background(), "test") if err == nil { t.Fatal("expected error, got nil") } @@ -392,7 +392,7 @@ func 
TestChat_WithRequestOptions(t *testing.T) { model := newMockModel(mp) chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200)) - _, err := chat.Send(context.Background(), "test") + _, _, err := chat.Send(context.Background(), "test") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -405,3 +405,107 @@ func TestChat_WithRequestOptions(t *testing.T) { t.Errorf("expected maxTokens 200, got %v", req.MaxTokens) } } + +func TestChat_Send_UsageAccumulation(t *testing.T) { + var callCount int32 + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + n := atomic.AddInt32(&callCount, 1) + if n == 1 { + return provider.Response{ + ToolCalls: []provider.ToolCall{ + {ID: "tc1", Name: "greet", Arguments: "{}"}, + }, + Usage: &provider.Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + }, nil + } + return provider.Response{ + Text: "done", + Usage: &provider.Usage{InputTokens: 20, OutputTokens: 8, TotalTokens: 28}, + }, nil + }) + model := newMockModel(mp) + chat := NewChat(model) + + tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) { + return "hello!", nil + }) + chat.SetTools(NewToolBox(tool)) + + text, usage, err := chat.Send(context.Background(), "test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if text != "done" { + t.Errorf("expected 'done', got %q", text) + } + if usage == nil { + t.Fatal("expected usage, got nil") + } + if usage.InputTokens != 30 { + t.Errorf("expected accumulated input 30, got %d", usage.InputTokens) + } + if usage.OutputTokens != 13 { + t.Errorf("expected accumulated output 13, got %d", usage.OutputTokens) + } + if usage.TotalTokens != 43 { + t.Errorf("expected accumulated total 43, got %d", usage.TotalTokens) + } +} + +func TestChat_Send_UsageWithDetails(t *testing.T) { + var callCount int32 + mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) { + n := 
atomic.AddInt32(&callCount, 1) + if n == 1 { + return provider.Response{ + ToolCalls: []provider.ToolCall{ + {ID: "tc1", Name: "greet", Arguments: "{}"}, + }, + Usage: &provider.Usage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + Details: map[string]int{ + "cached_input_tokens": 3, + }, + }, + }, nil + } + return provider.Response{ + Text: "done", + Usage: &provider.Usage{ + InputTokens: 20, + OutputTokens: 8, + TotalTokens: 28, + Details: map[string]int{ + "cached_input_tokens": 7, + "reasoning_tokens": 2, + }, + }, + }, nil + }) + model := newMockModel(mp) + chat := NewChat(model) + + tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) { + return "hello!", nil + }) + chat.SetTools(NewToolBox(tool)) + + _, usage, err := chat.Send(context.Background(), "test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if usage == nil { + t.Fatal("expected usage, got nil") + } + if usage.Details == nil { + t.Fatal("expected usage details, got nil") + } + if usage.Details["cached_input_tokens"] != 10 { + t.Errorf("expected cached_input_tokens=10, got %d", usage.Details["cached_input_tokens"]) + } + if usage.Details["reasoning_tokens"] != 2 { + t.Errorf("expected reasoning_tokens=2, got %d", usage.Details["reasoning_tokens"]) + } +} diff --git a/v2/generate.go b/v2/generate.go index c6af254..60dc243 100644 --- a/v2/generate.go +++ b/v2/generate.go @@ -13,14 +13,16 @@ const structuredOutputToolName = "structured_output" // Generate sends a single user prompt to the model and parses the response into T. // T must be a struct. The model is forced to return structured output matching T's schema // by using a hidden tool call internally. -func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, error) { +// Returns the parsed value, token usage, and any error. 
+func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, *Usage, error) { return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...) } // GenerateWith sends the given messages to the model and parses the response into T. // T must be a struct. The model is forced to return structured output matching T's schema // by using a hidden tool call internally. -func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, error) { +// Returns the parsed value, token usage, and any error. +func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, *Usage, error) { var zero T s := schema.FromStruct(zero) @@ -36,7 +38,7 @@ func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, resp, err := model.Complete(ctx, messages, opts...) if err != nil { - return zero, err + return zero, nil, err } // Find the structured_output tool call in the response. 
@@ -44,11 +46,11 @@ func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, if tc.Name == structuredOutputToolName { var result T if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil { - return zero, fmt.Errorf("failed to parse structured output: %w", err) + return zero, resp.Usage, fmt.Errorf("failed to parse structured output: %w", err) } - return result, nil + return result, resp.Usage, nil } } - return zero, ErrNoStructuredOutput + return zero, resp.Usage, ErrNoStructuredOutput } diff --git a/v2/generate_test.go b/v2/generate_test.go index 7c4de1d..3eda5a9 100644 --- a/v2/generate_test.go +++ b/v2/generate_test.go @@ -25,7 +25,7 @@ func TestGenerate(t *testing.T) { }) model := newMockModel(mp) - result, err := Generate[testPerson](context.Background(), model, "Tell me about Alice") + result, _, err := Generate[testPerson](context.Background(), model, "Tell me about Alice") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -63,7 +63,7 @@ func TestGenerateWith(t *testing.T) { UserMessage("Tell me about Bob"), } - result, err := GenerateWith[testPerson](context.Background(), model, messages) + result, _, err := GenerateWith[testPerson](context.Background(), model, messages) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -90,7 +90,7 @@ func TestGenerate_NoToolCall(t *testing.T) { }) model := newMockModel(mp) - _, err := Generate[testPerson](context.Background(), model, "Tell me about someone") + _, _, err := Generate[testPerson](context.Background(), model, "Tell me about someone") if err == nil { t.Fatal("expected error, got nil") } @@ -111,7 +111,7 @@ func TestGenerate_InvalidJSON(t *testing.T) { }) model := newMockModel(mp) - _, err := Generate[testPerson](context.Background(), model, "Tell me about someone") + _, _, err := Generate[testPerson](context.Background(), model, "Tell me about someone") if err == nil { t.Fatal("expected error, got nil") } @@ -143,7 +143,7 @@ func TestGenerate_NestedStruct(t 
*testing.T) { }) model := newMockModel(mp) - result, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol") + result, _, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -170,7 +170,7 @@ func TestGenerate_WithOptions(t *testing.T) { }) model := newMockModel(mp) - _, err := Generate[testPerson](context.Background(), model, "Tell me about Dave", + _, _, err := Generate[testPerson](context.Background(), model, "Tell me about Dave", WithTemperature(0.5), WithMaxTokens(200), ) @@ -207,7 +207,7 @@ func TestGenerate_WithMiddleware(t *testing.T) { }) model := newMockModel(mp).WithMiddleware(mw) - result, err := Generate[testPerson](context.Background(), model, "Tell me about Eve") + result, _, err := Generate[testPerson](context.Background(), model, "Tell me about Eve") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -231,7 +231,7 @@ func TestGenerate_WrongToolName(t *testing.T) { }) model := newMockModel(mp) - _, err := Generate[testPerson](context.Background(), model, "Tell me about Frank") + _, _, err := Generate[testPerson](context.Background(), model, "Tell me about Frank") if err == nil { t.Fatal("expected error, got nil") } @@ -239,3 +239,44 @@ func TestGenerate_WrongToolName(t *testing.T) { t.Errorf("expected ErrNoStructuredOutput, got %v", err) } } + +func TestGenerate_ReturnsUsage(t *testing.T) { + mp := newMockProvider(provider.Response{ + ToolCalls: []provider.ToolCall{ + { + ID: "call_1", + Name: "structured_output", + Arguments: `{"name":"Grace","age":22}`, + }, + }, + Usage: &provider.Usage{ + InputTokens: 50, + OutputTokens: 20, + TotalTokens: 70, + Details: map[string]int{ + "reasoning_tokens": 5, + }, + }, + }) + model := newMockModel(mp) + + result, usage, err := Generate[testPerson](context.Background(), model, "Tell me about Grace") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if 
result.Name != "Grace" { + t.Errorf("expected name 'Grace', got %q", result.Name) + } + if usage == nil { + t.Fatal("expected usage, got nil") + } + if usage.InputTokens != 50 { + t.Errorf("expected input 50, got %d", usage.InputTokens) + } + if usage.OutputTokens != 20 { + t.Errorf("expected output 20, got %d", usage.OutputTokens) + } + if usage.Details["reasoning_tokens"] != 5 { + t.Errorf("expected reasoning_tokens=5, got %d", usage.Details["reasoning_tokens"]) + } +} diff --git a/v2/google/google.go b/v2/google/google.go index 2f99060..5003bf2 100644 --- a/v2/google/google.go +++ b/v2/google/google.go @@ -59,12 +59,32 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan var fullText strings.Builder var toolCalls []provider.ToolCall + var usage *provider.Usage for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) { if err != nil { return fmt.Errorf("google stream error: %w", err) } + // Track usage from the last chunk (final chunk has cumulative counts) + if resp.UsageMetadata != nil { + usage = &provider.Usage{ + InputTokens: int(resp.UsageMetadata.PromptTokenCount), + OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount), + TotalTokens: int(resp.UsageMetadata.TotalTokenCount), + } + details := map[string]int{} + if resp.UsageMetadata.CachedContentTokenCount > 0 { + details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount) + } + if resp.UsageMetadata.ThoughtsTokenCount > 0 { + details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount) + } + if len(details) > 0 { + usage.Details = details + } + } + for _, c := range resp.Candidates { if c.Content == nil { continue @@ -105,6 +125,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan Response: &provider.Response{ Text: fullText.String(), ToolCalls: toolCalls, + Usage: usage, }, } @@ -284,6 +305,16 @@ func (p *Provider) convertResponse(resp 
*genai.GenerateContentResponse) (provide OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount), TotalTokens: int(resp.UsageMetadata.TotalTokenCount), } + details := map[string]int{} + if resp.UsageMetadata.CachedContentTokenCount > 0 { + details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount) + } + if resp.UsageMetadata.ThoughtsTokenCount > 0 { + details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount) + } + if len(details) > 0 { + res.Usage.Details = details + } } return res, nil diff --git a/v2/llm.go b/v2/llm.go index 388565d..b34fba1 100644 --- a/v2/llm.go +++ b/v2/llm.go @@ -177,6 +177,7 @@ func convertProviderResponse(resp provider.Response) Response { InputTokens: resp.Usage.InputTokens, OutputTokens: resp.Usage.OutputTokens, TotalTokens: resp.Usage.TotalTokens, + Details: resp.Usage.Details, } } diff --git a/v2/middleware.go b/v2/middleware.go index 73e1620..8f3743f 100644 --- a/v2/middleware.go +++ b/v2/middleware.go @@ -82,6 +82,7 @@ type UsageTracker struct { TotalInput int64 TotalOutput int64 TotalRequests int64 + TotalDetails map[string]int64 } // Add records usage from a single request. @@ -94,6 +95,14 @@ func (ut *UsageTracker) Add(u *Usage) { ut.TotalInput += int64(u.InputTokens) ut.TotalOutput += int64(u.OutputTokens) ut.TotalRequests++ + if len(u.Details) > 0 { + if ut.TotalDetails == nil { + ut.TotalDetails = make(map[string]int64) + } + for k, v := range u.Details { + ut.TotalDetails[k] += int64(v) + } + } } // Summary returns the accumulated totals. @@ -103,6 +112,20 @@ func (ut *UsageTracker) Summary() (input, output, requests int64) { return ut.TotalInput, ut.TotalOutput, ut.TotalRequests } +// Details returns a copy of the accumulated detail totals. 
+func (ut *UsageTracker) Details() map[string]int64 { + ut.mu.Lock() + defer ut.mu.Unlock() + if ut.TotalDetails == nil { + return nil + } + cp := make(map[string]int64, len(ut.TotalDetails)) + for k, v := range ut.TotalDetails { + cp[k] = v + } + return cp +} + // WithUsageTracking returns middleware that accumulates token usage across calls. func WithUsageTracking(tracker *UsageTracker) Middleware { return func(next CompletionFunc) CompletionFunc { diff --git a/v2/middleware_test.go b/v2/middleware_test.go index 652606c..50fb74f 100644 --- a/v2/middleware_test.go +++ b/v2/middleware_test.go @@ -280,3 +280,80 @@ func TestWithLogging_Error(t *testing.T) { t.Errorf("expected provider error, got %v", err) } } + +func TestUsageTracker_Details(t *testing.T) { + tracker := &UsageTracker{} + + tracker.Add(&Usage{ + InputTokens: 100, + OutputTokens: 50, + TotalTokens: 150, + Details: map[string]int{ + "cached_input_tokens": 20, + "reasoning_tokens": 10, + }, + }) + + tracker.Add(&Usage{ + InputTokens: 80, + OutputTokens: 40, + TotalTokens: 120, + Details: map[string]int{ + "cached_input_tokens": 15, + }, + }) + + details := tracker.Details() + if details == nil { + t.Fatal("expected details, got nil") + } + if details["cached_input_tokens"] != 35 { + t.Errorf("expected cached_input_tokens=35, got %d", details["cached_input_tokens"]) + } + if details["reasoning_tokens"] != 10 { + t.Errorf("expected reasoning_tokens=10, got %d", details["reasoning_tokens"]) + } + + // Verify returned map is a copy + details["cached_input_tokens"] = 999 + fresh := tracker.Details() + if fresh["cached_input_tokens"] != 35 { + t.Error("Details() did not return a copy") + } +} + +func TestUsageTracker_Details_Nil(t *testing.T) { + tracker := &UsageTracker{} + tracker.Add(&Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}) + + details := tracker.Details() + if details != nil { + t.Errorf("expected nil details for usage without details, got %v", details) + } +} + +func 
TestWithUsageTracking_WithDetails(t *testing.T) { + mp := newMockProvider(provider.Response{ + Text: "ok", + Usage: &provider.Usage{ + InputTokens: 100, + OutputTokens: 50, + TotalTokens: 150, + Details: map[string]int{ + "cached_input_tokens": 30, + }, + }, + }) + tracker := &UsageTracker{} + model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker)) + + _, err := model.Complete(context.Background(), []Message{UserMessage("test")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + details := tracker.Details() + if details["cached_input_tokens"] != 30 { + t.Errorf("expected cached_input_tokens=30, got %d", details["cached_input_tokens"]) + } +} diff --git a/v2/openai/openai.go b/v2/openai/openai.go index 2676de6..4193542 100644 --- a/v2/openai/openai.go +++ b/v2/openai/openai.go @@ -58,15 +58,30 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan cl := openai.NewClient(opts...) oaiReq := p.buildRequest(req) + oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), + } stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq) var fullText strings.Builder var toolCalls []provider.ToolCall toolCallArgs := map[int]*strings.Builder{} + var usage *provider.Usage for stream.Next() { chunk := stream.Current() + + // Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true) + if chunk.Usage.TotalTokens > 0 { + usage = &provider.Usage{ + InputTokens: int(chunk.Usage.PromptTokens), + OutputTokens: int(chunk.Usage.CompletionTokens), + TotalTokens: int(chunk.Usage.TotalTokens), + Details: extractUsageDetails(chunk.Usage), + } + } + for _, choice := range chunk.Choices { // Text delta if choice.Delta.Content != "" { @@ -138,6 +153,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan Response: &provider.Response{ Text: fullText.String(), ToolCalls: toolCalls, + Usage: usage, }, } @@ -363,6 +379,7 @@ func (p *Provider) 
convertResponse(resp *openai.ChatCompletion) provider.Respons OutputTokens: int(resp.Usage.CompletionTokens), TotalTokens: int(resp.Usage.TotalTokens), } + res.Usage.Details = extractUsageDetails(resp.Usage) } return res @@ -381,6 +398,27 @@ func audioFormat(contentType string) string { } } +// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage. +func extractUsageDetails(usage openai.CompletionUsage) map[string]int { + details := map[string]int{} + if usage.CompletionTokensDetails.ReasoningTokens > 0 { + details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens) + } + if usage.CompletionTokensDetails.AudioTokens > 0 { + details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens) + } + if usage.PromptTokensDetails.CachedTokens > 0 { + details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens) + } + if usage.PromptTokensDetails.AudioTokens > 0 { + details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens) + } + if len(details) == 0 { + return nil + } + return details +} + // audioFormatFromURL guesses the audio format from a URL's file extension. func audioFormatFromURL(u string) string { ext := strings.ToLower(path.Ext(u)) diff --git a/v2/pricing.go b/v2/pricing.go new file mode 100644 index 0000000..aefd02f --- /dev/null +++ b/v2/pricing.go @@ -0,0 +1,83 @@ +package llm + +import "sync" + +// ModelPricing defines per-token pricing for a model. +type ModelPricing struct { + InputPricePerToken float64 // USD per input token + OutputPricePerToken float64 // USD per output token + CachedInputPricePerToken float64 // USD per cached input token (0 = same as input) +} + +// Cost computes the total USD cost from a Usage. +// When CachedInputPricePerToken is set and the usage includes cached_input_tokens, +// those tokens are charged at the cached rate instead of the regular input rate. 
+func (mp ModelPricing) Cost(u *Usage) float64 { + if u == nil { + return 0 + } + + inputTokens := u.InputTokens + cachedTokens := 0 + if u.Details != nil { + cachedTokens = u.Details[UsageDetailCachedInputTokens] + } + + var cost float64 + + if mp.CachedInputPricePerToken > 0 && cachedTokens > 0 { + regularInput := inputTokens - cachedTokens + if regularInput < 0 { + regularInput = 0 + } + cost += float64(regularInput) * mp.InputPricePerToken + cost += float64(cachedTokens) * mp.CachedInputPricePerToken + } else { + cost += float64(inputTokens) * mp.InputPricePerToken + } + + cost += float64(u.OutputTokens) * mp.OutputPricePerToken + + return cost +} + +// PricingRegistry maps model names to their pricing. +// Callers populate it with the models and prices relevant to their use case. +type PricingRegistry struct { + mu sync.RWMutex + models map[string]ModelPricing +} + +// NewPricingRegistry creates an empty pricing registry. +func NewPricingRegistry() *PricingRegistry { + return &PricingRegistry{ + models: make(map[string]ModelPricing), + } +} + +// Set registers pricing for a model. +func (pr *PricingRegistry) Set(model string, pricing ModelPricing) { + pr.mu.Lock() + defer pr.mu.Unlock() + pr.models[model] = pricing +} + +// Has returns true if pricing is registered for the given model. +func (pr *PricingRegistry) Has(model string) bool { + pr.mu.RLock() + defer pr.mu.RUnlock() + _, ok := pr.models[model] + return ok +} + +// Cost computes the USD cost for the given model and usage. +// Returns 0 if the model is not registered. 
+func (pr *PricingRegistry) Cost(model string, u *Usage) float64 { + pr.mu.RLock() + pricing, ok := pr.models[model] + pr.mu.RUnlock() + if !ok { + return 0 + } + return pricing.Cost(u) +} diff --git a/v2/pricing_test.go b/v2/pricing_test.go new file mode 100644 index 0000000..471db2d --- /dev/null +++ b/v2/pricing_test.go @@ -0,0 +1,128 @@ +package llm + +import ( + "math" + "testing" +) + +func TestModelPricing_Cost(t *testing.T) { + pricing := ModelPricing{ + InputPricePerToken: 0.000003, // $3/MTok + OutputPricePerToken: 0.000015, // $15/MTok + } + + usage := &Usage{ + InputTokens: 1000, + OutputTokens: 500, + TotalTokens: 1500, + } + + cost := pricing.Cost(usage) + expected := 1000*0.000003 + 500*0.000015 + if math.Abs(cost-expected) > 1e-10 { + t.Errorf("expected cost %f, got %f", expected, cost) + } +} + +func TestModelPricing_Cost_WithCachedTokens(t *testing.T) { + pricing := ModelPricing{ + InputPricePerToken: 0.000003, // $3/MTok + OutputPricePerToken: 0.000015, // $15/MTok + CachedInputPricePerToken: 0.0000015, // $1.50/MTok (50% discount) + } + + usage := &Usage{ + InputTokens: 1000, + OutputTokens: 500, + TotalTokens: 1500, + Details: map[string]int{ + UsageDetailCachedInputTokens: 400, + }, + } + + cost := pricing.Cost(usage) + // 600 regular input tokens + 400 cached tokens + 500 output tokens + expected := 600*0.000003 + 400*0.0000015 + 500*0.000015 + if math.Abs(cost-expected) > 1e-10 { + t.Errorf("expected cost %f, got %f", expected, cost) + } +} + +func TestModelPricing_Cost_NilUsage(t *testing.T) { + pricing := ModelPricing{ + InputPricePerToken: 0.000003, + OutputPricePerToken: 0.000015, + } + + cost := pricing.Cost(nil) + if cost != 0 { + t.Errorf("expected 0 for nil usage, got %f", cost) + } +} + +func TestModelPricing_Cost_NoCachedPrice(t *testing.T) { + // When CachedInputPricePerToken is 0, all input tokens use InputPricePerToken + pricing := ModelPricing{ + InputPricePerToken: 0.000003, + OutputPricePerToken: 0.000015, + } + + usage := 
&Usage{ + InputTokens: 1000, + OutputTokens: 500, + TotalTokens: 1500, + Details: map[string]int{ + UsageDetailCachedInputTokens: 400, + }, + } + + cost := pricing.Cost(usage) + expected := 1000*0.000003 + 500*0.000015 + if math.Abs(cost-expected) > 1e-10 { + t.Errorf("expected cost %f, got %f", expected, cost) + } +} + +func TestPricingRegistry(t *testing.T) { + registry := NewPricingRegistry() + + registry.Set("gpt-4o", ModelPricing{ + InputPricePerToken: 0.0000025, + OutputPricePerToken: 0.00001, + }) + + if !registry.Has("gpt-4o") { + t.Error("expected Has('gpt-4o') to be true") + } + if registry.Has("gpt-3.5-turbo") { + t.Error("expected Has('gpt-3.5-turbo') to be false") + } + + usage := &Usage{InputTokens: 1000, OutputTokens: 200, TotalTokens: 1200} + + cost := registry.Cost("gpt-4o", usage) + expected := 1000*0.0000025 + 200*0.00001 + if math.Abs(cost-expected) > 1e-10 { + t.Errorf("expected cost %f, got %f", expected, cost) + } + + // Unknown model returns 0 + cost = registry.Cost("unknown-model", usage) + if cost != 0 { + t.Errorf("expected 0 for unknown model, got %f", cost) + } +} + +func TestPricingRegistry_Override(t *testing.T) { + registry := NewPricingRegistry() + + registry.Set("model-a", ModelPricing{InputPricePerToken: 0.001, OutputPricePerToken: 0.002}) + registry.Set("model-a", ModelPricing{InputPricePerToken: 0.003, OutputPricePerToken: 0.004}) + + usage := &Usage{InputTokens: 100, OutputTokens: 50, TotalTokens: 150} + cost := registry.Cost("model-a", usage) + expected := 100*0.003 + 50*0.004 + if math.Abs(cost-expected) > 1e-10 { + t.Errorf("expected overridden cost %f, got %f", expected, cost) + } +} diff --git a/v2/provider/provider.go b/v2/provider/provider.go index 084ee83..5a7774a 100644 --- a/v2/provider/provider.go +++ b/v2/provider/provider.go @@ -64,8 +64,19 @@ type Usage struct { InputTokens int OutputTokens int TotalTokens int + Details map[string]int // provider-specific breakdown (e.g., cached, reasoning tokens) } +// 
Standardized detail keys for provider-specific token breakdowns. +const ( + UsageDetailReasoningTokens = "reasoning_tokens" + UsageDetailCachedInputTokens = "cached_input_tokens" + UsageDetailCacheCreationTokens = "cache_creation_tokens" + UsageDetailAudioInputTokens = "audio_input_tokens" + UsageDetailAudioOutputTokens = "audio_output_tokens" + UsageDetailThoughtsTokens = "thoughts_tokens" +) + // StreamEventType identifies the kind of stream event. type StreamEventType int diff --git a/v2/response.go b/v2/response.go index a19c397..e7ae0ad 100644 --- a/v2/response.go +++ b/v2/response.go @@ -1,5 +1,7 @@ package llm +import "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" + // Response represents the result of a completion request. type Response struct { // Text is the assistant's text content. Empty if only tool calls. @@ -31,4 +33,45 @@ type Usage struct { InputTokens int OutputTokens int TotalTokens int + Details map[string]int // provider-specific breakdown (e.g., cached, reasoning tokens) } + +// addUsage merges usage u into total in place, accumulating token counts and details. +// If total is nil, it returns a copy of u. If u is nil, it returns total unchanged. +func addUsage(total *Usage, u *Usage) *Usage { + if u == nil { + return total + } + if total == nil { + cp := *u + if u.Details != nil { + cp.Details = make(map[string]int, len(u.Details)) + for k, v := range u.Details { + cp.Details[k] = v + } + } + return &cp + } + total.InputTokens += u.InputTokens + total.OutputTokens += u.OutputTokens + total.TotalTokens += u.TotalTokens + if u.Details != nil { + if total.Details == nil { + total.Details = make(map[string]int, len(u.Details)) + } + for k, v := range u.Details { + total.Details[k] += v + } + } + return total +} + +// Re-export detail key constants from provider package for convenience. 
+const ( + UsageDetailReasoningTokens = provider.UsageDetailReasoningTokens + UsageDetailCachedInputTokens = provider.UsageDetailCachedInputTokens + UsageDetailCacheCreationTokens = provider.UsageDetailCacheCreationTokens + UsageDetailAudioInputTokens = provider.UsageDetailAudioInputTokens + UsageDetailAudioOutputTokens = provider.UsageDetailAudioOutputTokens + UsageDetailThoughtsTokens = provider.UsageDetailThoughtsTokens +) -- 2.49.1