feat: comprehensive token usage tracking for V2
Add provider-specific usage details, fix streaming usage, and return usage from all high-level APIs (Chat.Send, Generate[T], Agent.Run). Breaking changes: - Chat.Send/SendMessage/SendWithImages now return (string, *Usage, error) - Generate[T]/GenerateWith[T] now return (T, *Usage, error) - Agent.Run/RunMessages now return (string, *Usage, error) New features: - Usage.Details map for provider-specific token breakdowns (reasoning, cached, audio, thoughts tokens) - OpenAI streaming now captures usage via StreamOptions.IncludeUsage - Google streaming now captures UsageMetadata from final chunk - UsageTracker.Details() for accumulated detail totals - ModelPricing and PricingRegistry for cost computation Closes #2 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
// coder.AsTool("code", "Write and run code"),
|
||||
// )),
|
||||
// )
|
||||
// result, err := orchestrator.Run(ctx, "Build a fibonacci function in Go")
|
||||
// result, _, err := orchestrator.Run(ctx, "Build a fibonacci function in Go")
|
||||
package agent
|
||||
|
||||
import (
|
||||
@@ -64,13 +64,15 @@ func New(model *llm.Model, system string, opts ...Option) *Agent {
|
||||
|
||||
// Run executes the agent with a user prompt. Each call is a fresh conversation.
|
||||
// The agent loops tool calls automatically until it produces a text response.
|
||||
func (a *Agent) Run(ctx context.Context, prompt string) (string, error) {
|
||||
// Returns the text response, accumulated token usage, and any error.
|
||||
func (a *Agent) Run(ctx context.Context, prompt string) (string, *llm.Usage, error) {
|
||||
return a.RunMessages(ctx, []llm.Message{llm.UserMessage(prompt)})
|
||||
}
|
||||
|
||||
// RunMessages executes the agent with full message control.
|
||||
// Each call is a fresh conversation. The agent loops tool calls automatically.
|
||||
func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, error) {
|
||||
// Returns the text response, accumulated token usage, and any error.
|
||||
func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, *llm.Usage, error) {
|
||||
chat := llm.NewChat(a.model, a.reqOpts...)
|
||||
if a.system != "" {
|
||||
chat.SetSystem(a.system)
|
||||
@@ -107,7 +109,8 @@ type delegateParams struct {
|
||||
func (a *Agent) AsTool(name, description string) llm.Tool {
|
||||
return llm.Define[delegateParams](name, description,
|
||||
func(ctx context.Context, p delegateParams) (string, error) {
|
||||
return a.Run(ctx, p.Input)
|
||||
text, _, err := a.Run(ctx, p.Input)
|
||||
return text, err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -29,15 +29,6 @@ func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) lastRequest() provider.Request {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if len(m.requests) == 0 {
|
||||
return provider.Request{}
|
||||
}
|
||||
return m.requests[len(m.requests)-1]
|
||||
}
|
||||
|
||||
func newMockModel(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *llm.Model {
|
||||
mp := &mockProvider{completeFunc: fn}
|
||||
return llm.NewClient(mp).Model("mock-model")
|
||||
@@ -53,7 +44,7 @@ func TestAgent_Run(t *testing.T) {
|
||||
model := newSimpleMockModel("Hello from agent!")
|
||||
a := New(model, "You are a helpful assistant.")
|
||||
|
||||
result, err := a.Run(context.Background(), "Say hello")
|
||||
result, _, err := a.Run(context.Background(), "Say hello")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -83,7 +74,7 @@ func TestAgent_Run_WithTools(t *testing.T) {
|
||||
})
|
||||
|
||||
a := New(model, "You are helpful.", WithTools(llm.NewToolBox(tool)))
|
||||
result, err := a.Run(context.Background(), "Use the greet tool")
|
||||
result, _, err := a.Run(context.Background(), "Use the greet tool")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -147,7 +138,7 @@ func TestAgent_AsTool_ParentChild(t *testing.T) {
|
||||
)),
|
||||
)
|
||||
|
||||
result, err := parent.Run(context.Background(), "Tell me about Go generics")
|
||||
result, _, err := parent.Run(context.Background(), "Tell me about Go generics")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -169,7 +160,7 @@ func TestAgent_RunMessages(t *testing.T) {
|
||||
llm.UserMessage("Follow up"),
|
||||
}
|
||||
|
||||
result, err := a.RunMessages(context.Background(), messages)
|
||||
result, _, err := a.RunMessages(context.Background(), messages)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -187,7 +178,7 @@ func TestAgent_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := a.Run(ctx, "This should fail")
|
||||
_, _, err := a.Run(ctx, "This should fail")
|
||||
if err == nil {
|
||||
t.Fatal("expected error from cancelled context")
|
||||
}
|
||||
@@ -204,7 +195,7 @@ func TestAgent_WithRequestOptions(t *testing.T) {
|
||||
WithRequestOptions(llm.WithTemperature(0.3), llm.WithMaxTokens(100)),
|
||||
)
|
||||
|
||||
_, err := a.Run(context.Background(), "test")
|
||||
_, _, err := a.Run(context.Background(), "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -224,7 +215,7 @@ func TestAgent_Run_Error(t *testing.T) {
|
||||
})
|
||||
a := New(model, "You are helpful.")
|
||||
|
||||
_, err := a.Run(context.Background(), "test")
|
||||
_, _, err := a.Run(context.Background(), "test")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
@@ -234,7 +225,7 @@ func TestAgent_EmptySystem(t *testing.T) {
|
||||
model := newSimpleMockModel("no system prompt")
|
||||
a := New(model, "") // Empty system prompt
|
||||
|
||||
result, err := a.Run(context.Background(), "test")
|
||||
result, _, err := a.Run(context.Background(), "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -242,3 +233,34 @@ func TestAgent_EmptySystem(t *testing.T) {
|
||||
t.Errorf("unexpected result: %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_Run_ReturnsUsage(t *testing.T) {
|
||||
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{
|
||||
Text: "result",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
TotalTokens: 150,
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
|
||||
a := New(model, "You are helpful.")
|
||||
result, usage, err := a.Run(context.Background(), "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "result" {
|
||||
t.Errorf("expected 'result', got %q", result)
|
||||
}
|
||||
if usage == nil {
|
||||
t.Fatal("expected usage, got nil")
|
||||
}
|
||||
if usage.InputTokens != 100 {
|
||||
t.Errorf("expected input 100, got %d", usage.InputTokens)
|
||||
}
|
||||
if usage.OutputTokens != 50 {
|
||||
t.Errorf("expected output 50, got %d", usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func Example_researcher() {
|
||||
agent.WithRequestOptions(llm.WithTemperature(0.3)),
|
||||
)
|
||||
|
||||
result, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?")
|
||||
result, _, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?")
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
return
|
||||
@@ -50,7 +50,7 @@ func Example_coder() {
|
||||
)),
|
||||
)
|
||||
|
||||
result, err := coder.Run(context.Background(),
|
||||
result, _, err := coder.Run(context.Background(),
|
||||
"Create a Go program that prints the first 10 Fibonacci numbers. Save it and run it.")
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
@@ -97,7 +97,7 @@ func Example_orchestrator() {
|
||||
)),
|
||||
)
|
||||
|
||||
result, err := orchestrator.Run(context.Background(),
|
||||
result, _, err := orchestrator.Run(context.Background(),
|
||||
"Research how to implement a binary search tree in Go, then create one with insert and search operations.")
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
|
||||
@@ -270,6 +270,16 @@ func (p *Provider) convertResponse(resp anth.MessagesResponse) provider.Response
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||
}
|
||||
details := map[string]int{}
|
||||
if resp.Usage.CacheCreationInputTokens > 0 {
|
||||
details[provider.UsageDetailCacheCreationTokens] = resp.Usage.CacheCreationInputTokens
|
||||
}
|
||||
if resp.Usage.CacheReadInputTokens > 0 {
|
||||
details[provider.UsageDetailCachedInputTokens] = resp.Usage.CacheReadInputTokens
|
||||
}
|
||||
if len(details) > 0 {
|
||||
res.Usage.Details = details
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
24
v2/chat.go
24
v2/chat.go
@@ -38,44 +38,50 @@ func (c *Chat) SetTools(tb *ToolBox) {
|
||||
c.tools = tb
|
||||
}
|
||||
|
||||
// Send sends a user message and returns the assistant's text response.
|
||||
// Send sends a user message and returns the assistant's text response along with
|
||||
// accumulated token usage from all iterations of the tool-call loop.
|
||||
// If the model calls tools, they are executed automatically and the loop
|
||||
// continues until the model produces a text response (the "agent loop").
|
||||
func (c *Chat) Send(ctx context.Context, text string) (string, error) {
|
||||
func (c *Chat) Send(ctx context.Context, text string) (string, *Usage, error) {
|
||||
return c.SendMessage(ctx, UserMessage(text))
|
||||
}
|
||||
|
||||
// SendWithImages sends a user message with images attached.
|
||||
func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, error) {
|
||||
func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, *Usage, error) {
|
||||
return c.SendMessage(ctx, UserMessageWithImages(text, images...))
|
||||
}
|
||||
|
||||
// SendMessage sends an arbitrary message and returns the final text response.
|
||||
// SendMessage sends an arbitrary message and returns the final text response along with
|
||||
// accumulated token usage from all iterations of the tool-call loop.
|
||||
// Handles the full tool-call loop automatically.
|
||||
func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, error) {
|
||||
func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, *Usage, error) {
|
||||
c.messages = append(c.messages, msg)
|
||||
|
||||
opts := c.buildOpts()
|
||||
|
||||
var totalUsage *Usage
|
||||
|
||||
for {
|
||||
resp, err := c.model.Complete(ctx, c.messages, opts...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("completion failed: %w", err)
|
||||
return "", totalUsage, fmt.Errorf("completion failed: %w", err)
|
||||
}
|
||||
|
||||
totalUsage = addUsage(totalUsage, resp.Usage)
|
||||
|
||||
c.messages = append(c.messages, resp.Message())
|
||||
|
||||
if !resp.HasToolCalls() {
|
||||
return resp.Text, nil
|
||||
return resp.Text, totalUsage, nil
|
||||
}
|
||||
|
||||
if c.tools == nil {
|
||||
return "", ErrNoToolsConfigured
|
||||
return "", totalUsage, ErrNoToolsConfigured
|
||||
}
|
||||
|
||||
toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("tool execution failed: %w", err)
|
||||
return "", totalUsage, fmt.Errorf("tool execution failed: %w", err)
|
||||
}
|
||||
|
||||
c.messages = append(c.messages, toolResults...)
|
||||
|
||||
132
v2/chat_test.go
132
v2/chat_test.go
@@ -14,7 +14,7 @@ func TestChat_Send(t *testing.T) {
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
text, err := chat.Send(context.Background(), "Hi")
|
||||
text, _, err := chat.Send(context.Background(), "Hi")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -28,7 +28,7 @@ func TestChat_SendMessage(t *testing.T) {
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
|
||||
_, _, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -79,7 +79,7 @@ func TestChat_SetSystem(t *testing.T) {
|
||||
}
|
||||
|
||||
// System message stays first even after adding other messages
|
||||
_, _ = chat.Send(context.Background(), "Hi")
|
||||
_, _, _ = chat.Send(context.Background(), "Hi")
|
||||
chat.SetSystem("New system")
|
||||
msgs = chat.Messages()
|
||||
if msgs[0].Role != RoleSystem {
|
||||
@@ -113,7 +113,7 @@ func TestChat_ToolCallLoop(t *testing.T) {
|
||||
})
|
||||
chat.SetTools(NewToolBox(tool))
|
||||
|
||||
text, err := chat.Send(context.Background(), "test")
|
||||
text, _, err := chat.Send(context.Background(), "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func TestChat_ToolCallLoop_NoTools(t *testing.T) {
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, err := chat.Send(context.Background(), "test")
|
||||
_, _, err := chat.Send(context.Background(), "test")
|
||||
if !errors.Is(err, ErrNoToolsConfigured) {
|
||||
t.Errorf("expected ErrNoToolsConfigured, got %v", err)
|
||||
}
|
||||
@@ -248,7 +248,7 @@ func TestChat_Messages(t *testing.T) {
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, _ = chat.Send(context.Background(), "test")
|
||||
_, _, _ = chat.Send(context.Background(), "test")
|
||||
|
||||
msgs := chat.Messages()
|
||||
// Verify it's a copy — modifying returned slice shouldn't affect chat
|
||||
@@ -265,7 +265,7 @@ func TestChat_Reset(t *testing.T) {
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, _ = chat.Send(context.Background(), "test")
|
||||
_, _, _ = chat.Send(context.Background(), "test")
|
||||
if len(chat.Messages()) == 0 {
|
||||
t.Fatal("expected messages before reset")
|
||||
}
|
||||
@@ -281,7 +281,7 @@ func TestChat_Fork(t *testing.T) {
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, _ = chat.Send(context.Background(), "msg1")
|
||||
_, _, _ = chat.Send(context.Background(), "msg1")
|
||||
|
||||
fork := chat.Fork()
|
||||
|
||||
@@ -291,14 +291,14 @@ func TestChat_Fork(t *testing.T) {
|
||||
}
|
||||
|
||||
// Adding to fork should not affect original
|
||||
_, _ = fork.Send(context.Background(), "msg2")
|
||||
_, _, _ = fork.Send(context.Background(), "msg2")
|
||||
if len(fork.Messages()) == len(chat.Messages()) {
|
||||
t.Error("fork messages should be independent of original")
|
||||
}
|
||||
|
||||
// Adding to original should not affect fork
|
||||
originalLen := len(chat.Messages())
|
||||
_, _ = chat.Send(context.Background(), "msg3")
|
||||
_, _, _ = chat.Send(context.Background(), "msg3")
|
||||
if len(chat.Messages()) == originalLen {
|
||||
t.Error("original should have more messages after send")
|
||||
}
|
||||
@@ -310,7 +310,7 @@ func TestChat_SendWithImages(t *testing.T) {
|
||||
chat := NewChat(model)
|
||||
|
||||
img := Image{URL: "https://example.com/image.png"}
|
||||
text, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
|
||||
text, _, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -355,7 +355,7 @@ func TestChat_MultipleToolCallRounds(t *testing.T) {
|
||||
})
|
||||
chat.SetTools(NewToolBox(tool))
|
||||
|
||||
text, err := chat.Send(context.Background(), "count three times")
|
||||
text, _, err := chat.Send(context.Background(), "count three times")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -378,7 +378,7 @@ func TestChat_SendError(t *testing.T) {
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
_, err := chat.Send(context.Background(), "test")
|
||||
_, _, err := chat.Send(context.Background(), "test")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
@@ -392,7 +392,7 @@ func TestChat_WithRequestOptions(t *testing.T) {
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))
|
||||
|
||||
_, err := chat.Send(context.Background(), "test")
|
||||
_, _, err := chat.Send(context.Background(), "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -405,3 +405,107 @@ func TestChat_WithRequestOptions(t *testing.T) {
|
||||
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_Send_UsageAccumulation(t *testing.T) {
|
||||
var callCount int32
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
n := atomic.AddInt32(&callCount, 1)
|
||||
if n == 1 {
|
||||
return provider.Response{
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{ID: "tc1", Name: "greet", Arguments: "{}"},
|
||||
},
|
||||
Usage: &provider.Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
|
||||
}, nil
|
||||
}
|
||||
return provider.Response{
|
||||
Text: "done",
|
||||
Usage: &provider.Usage{InputTokens: 20, OutputTokens: 8, TotalTokens: 28},
|
||||
}, nil
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||
return "hello!", nil
|
||||
})
|
||||
chat.SetTools(NewToolBox(tool))
|
||||
|
||||
text, usage, err := chat.Send(context.Background(), "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if text != "done" {
|
||||
t.Errorf("expected 'done', got %q", text)
|
||||
}
|
||||
if usage == nil {
|
||||
t.Fatal("expected usage, got nil")
|
||||
}
|
||||
if usage.InputTokens != 30 {
|
||||
t.Errorf("expected accumulated input 30, got %d", usage.InputTokens)
|
||||
}
|
||||
if usage.OutputTokens != 13 {
|
||||
t.Errorf("expected accumulated output 13, got %d", usage.OutputTokens)
|
||||
}
|
||||
if usage.TotalTokens != 43 {
|
||||
t.Errorf("expected accumulated total 43, got %d", usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChat_Send_UsageWithDetails(t *testing.T) {
|
||||
var callCount int32
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
n := atomic.AddInt32(&callCount, 1)
|
||||
if n == 1 {
|
||||
return provider.Response{
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{ID: "tc1", Name: "greet", Arguments: "{}"},
|
||||
},
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
Details: map[string]int{
|
||||
"cached_input_tokens": 3,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return provider.Response{
|
||||
Text: "done",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 8,
|
||||
TotalTokens: 28,
|
||||
Details: map[string]int{
|
||||
"cached_input_tokens": 7,
|
||||
"reasoning_tokens": 2,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
chat := NewChat(model)
|
||||
|
||||
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||
return "hello!", nil
|
||||
})
|
||||
chat.SetTools(NewToolBox(tool))
|
||||
|
||||
_, usage, err := chat.Send(context.Background(), "test")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if usage == nil {
|
||||
t.Fatal("expected usage, got nil")
|
||||
}
|
||||
if usage.Details == nil {
|
||||
t.Fatal("expected usage details, got nil")
|
||||
}
|
||||
if usage.Details["cached_input_tokens"] != 10 {
|
||||
t.Errorf("expected cached_input_tokens=10, got %d", usage.Details["cached_input_tokens"])
|
||||
}
|
||||
if usage.Details["reasoning_tokens"] != 2 {
|
||||
t.Errorf("expected reasoning_tokens=2, got %d", usage.Details["reasoning_tokens"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,14 +13,16 @@ const structuredOutputToolName = "structured_output"
|
||||
// Generate sends a single user prompt to the model and parses the response into T.
|
||||
// T must be a struct. The model is forced to return structured output matching T's schema
|
||||
// by using a hidden tool call internally.
|
||||
func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, error) {
|
||||
// Returns the parsed value, token usage, and any error.
|
||||
func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, *Usage, error) {
|
||||
return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...)
|
||||
}
|
||||
|
||||
// GenerateWith sends the given messages to the model and parses the response into T.
|
||||
// T must be a struct. The model is forced to return structured output matching T's schema
|
||||
// by using a hidden tool call internally.
|
||||
func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, error) {
|
||||
// Returns the parsed value, token usage, and any error.
|
||||
func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, *Usage, error) {
|
||||
var zero T
|
||||
|
||||
s := schema.FromStruct(zero)
|
||||
@@ -36,7 +38,7 @@ func GenerateWith[T any](ctx context.Context, model *Model, messages []Message,
|
||||
|
||||
resp, err := model.Complete(ctx, messages, opts...)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
return zero, nil, err
|
||||
}
|
||||
|
||||
// Find the structured_output tool call in the response.
|
||||
@@ -44,11 +46,11 @@ func GenerateWith[T any](ctx context.Context, model *Model, messages []Message,
|
||||
if tc.Name == structuredOutputToolName {
|
||||
var result T
|
||||
if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil {
|
||||
return zero, fmt.Errorf("failed to parse structured output: %w", err)
|
||||
return zero, resp.Usage, fmt.Errorf("failed to parse structured output: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
return result, resp.Usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
return zero, ErrNoStructuredOutput
|
||||
return zero, resp.Usage, ErrNoStructuredOutput
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestGenerate(t *testing.T) {
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
result, err := Generate[testPerson](context.Background(), model, "Tell me about Alice")
|
||||
result, _, err := Generate[testPerson](context.Background(), model, "Tell me about Alice")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func TestGenerateWith(t *testing.T) {
|
||||
UserMessage("Tell me about Bob"),
|
||||
}
|
||||
|
||||
result, err := GenerateWith[testPerson](context.Background(), model, messages)
|
||||
result, _, err := GenerateWith[testPerson](context.Background(), model, messages)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func TestGenerate_NoToolCall(t *testing.T) {
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||
_, _, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
@@ -111,7 +111,7 @@ func TestGenerate_InvalidJSON(t *testing.T) {
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||
_, _, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func TestGenerate_NestedStruct(t *testing.T) {
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
result, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol")
|
||||
result, _, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func TestGenerate_WithOptions(t *testing.T) {
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
_, err := Generate[testPerson](context.Background(), model, "Tell me about Dave",
|
||||
_, _, err := Generate[testPerson](context.Background(), model, "Tell me about Dave",
|
||||
WithTemperature(0.5),
|
||||
WithMaxTokens(200),
|
||||
)
|
||||
@@ -207,7 +207,7 @@ func TestGenerate_WithMiddleware(t *testing.T) {
|
||||
})
|
||||
model := newMockModel(mp).WithMiddleware(mw)
|
||||
|
||||
result, err := Generate[testPerson](context.Background(), model, "Tell me about Eve")
|
||||
result, _, err := Generate[testPerson](context.Background(), model, "Tell me about Eve")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -231,7 +231,7 @@ func TestGenerate_WrongToolName(t *testing.T) {
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
_, err := Generate[testPerson](context.Background(), model, "Tell me about Frank")
|
||||
_, _, err := Generate[testPerson](context.Background(), model, "Tell me about Frank")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
@@ -239,3 +239,44 @@ func TestGenerate_WrongToolName(t *testing.T) {
|
||||
t.Errorf("expected ErrNoStructuredOutput, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerate_ReturnsUsage(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{
|
||||
ToolCalls: []provider.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Name: "structured_output",
|
||||
Arguments: `{"name":"Grace","age":22}`,
|
||||
},
|
||||
},
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 50,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 70,
|
||||
Details: map[string]int{
|
||||
"reasoning_tokens": 5,
|
||||
},
|
||||
},
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
result, usage, err := Generate[testPerson](context.Background(), model, "Tell me about Grace")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Name != "Grace" {
|
||||
t.Errorf("expected name 'Grace', got %q", result.Name)
|
||||
}
|
||||
if usage == nil {
|
||||
t.Fatal("expected usage, got nil")
|
||||
}
|
||||
if usage.InputTokens != 50 {
|
||||
t.Errorf("expected input 50, got %d", usage.InputTokens)
|
||||
}
|
||||
if usage.OutputTokens != 20 {
|
||||
t.Errorf("expected output 20, got %d", usage.OutputTokens)
|
||||
}
|
||||
if usage.Details["reasoning_tokens"] != 5 {
|
||||
t.Errorf("expected reasoning_tokens=5, got %d", usage.Details["reasoning_tokens"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,12 +59,32 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
|
||||
|
||||
var fullText strings.Builder
|
||||
var toolCalls []provider.ToolCall
|
||||
var usage *provider.Usage
|
||||
|
||||
for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("google stream error: %w", err)
|
||||
}
|
||||
|
||||
// Track usage from the last chunk (final chunk has cumulative counts)
|
||||
if resp.UsageMetadata != nil {
|
||||
usage = &provider.Usage{
|
||||
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
|
||||
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
||||
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
||||
}
|
||||
details := map[string]int{}
|
||||
if resp.UsageMetadata.CachedContentTokenCount > 0 {
|
||||
details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount)
|
||||
}
|
||||
if resp.UsageMetadata.ThoughtsTokenCount > 0 {
|
||||
details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount)
|
||||
}
|
||||
if len(details) > 0 {
|
||||
usage.Details = details
|
||||
}
|
||||
}
|
||||
|
||||
for _, c := range resp.Candidates {
|
||||
if c.Content == nil {
|
||||
continue
|
||||
@@ -105,6 +125,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
|
||||
Response: &provider.Response{
|
||||
Text: fullText.String(),
|
||||
ToolCalls: toolCalls,
|
||||
Usage: usage,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -284,6 +305,16 @@ func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provide
|
||||
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
||||
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
||||
}
|
||||
details := map[string]int{}
|
||||
if resp.UsageMetadata.CachedContentTokenCount > 0 {
|
||||
details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount)
|
||||
}
|
||||
if resp.UsageMetadata.ThoughtsTokenCount > 0 {
|
||||
details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount)
|
||||
}
|
||||
if len(details) > 0 {
|
||||
res.Usage.Details = details
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
|
||||
@@ -177,6 +177,7 @@ func convertProviderResponse(resp provider.Response) Response {
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
Details: resp.Usage.Details,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ type UsageTracker struct {
|
||||
TotalInput int64
|
||||
TotalOutput int64
|
||||
TotalRequests int64
|
||||
TotalDetails map[string]int64
|
||||
}
|
||||
|
||||
// Add records usage from a single request.
|
||||
@@ -94,6 +95,14 @@ func (ut *UsageTracker) Add(u *Usage) {
|
||||
ut.TotalInput += int64(u.InputTokens)
|
||||
ut.TotalOutput += int64(u.OutputTokens)
|
||||
ut.TotalRequests++
|
||||
if len(u.Details) > 0 {
|
||||
if ut.TotalDetails == nil {
|
||||
ut.TotalDetails = make(map[string]int64)
|
||||
}
|
||||
for k, v := range u.Details {
|
||||
ut.TotalDetails[k] += int64(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Summary returns the accumulated totals.
|
||||
@@ -103,6 +112,20 @@ func (ut *UsageTracker) Summary() (input, output, requests int64) {
|
||||
return ut.TotalInput, ut.TotalOutput, ut.TotalRequests
|
||||
}
|
||||
|
||||
// Details returns a copy of the accumulated detail totals.
|
||||
func (ut *UsageTracker) Details() map[string]int64 {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
if ut.TotalDetails == nil {
|
||||
return nil
|
||||
}
|
||||
cp := make(map[string]int64, len(ut.TotalDetails))
|
||||
for k, v := range ut.TotalDetails {
|
||||
cp[k] = v
|
||||
}
|
||||
return cp
|
||||
}
|
||||
|
||||
// WithUsageTracking returns middleware that accumulates token usage across calls.
|
||||
func WithUsageTracking(tracker *UsageTracker) Middleware {
|
||||
return func(next CompletionFunc) CompletionFunc {
|
||||
|
||||
@@ -280,3 +280,80 @@ func TestWithLogging_Error(t *testing.T) {
|
||||
t.Errorf("expected provider error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTracker_Details(t *testing.T) {
|
||||
tracker := &UsageTracker{}
|
||||
|
||||
tracker.Add(&Usage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
TotalTokens: 150,
|
||||
Details: map[string]int{
|
||||
"cached_input_tokens": 20,
|
||||
"reasoning_tokens": 10,
|
||||
},
|
||||
})
|
||||
|
||||
tracker.Add(&Usage{
|
||||
InputTokens: 80,
|
||||
OutputTokens: 40,
|
||||
TotalTokens: 120,
|
||||
Details: map[string]int{
|
||||
"cached_input_tokens": 15,
|
||||
},
|
||||
})
|
||||
|
||||
details := tracker.Details()
|
||||
if details == nil {
|
||||
t.Fatal("expected details, got nil")
|
||||
}
|
||||
if details["cached_input_tokens"] != 35 {
|
||||
t.Errorf("expected cached_input_tokens=35, got %d", details["cached_input_tokens"])
|
||||
}
|
||||
if details["reasoning_tokens"] != 10 {
|
||||
t.Errorf("expected reasoning_tokens=10, got %d", details["reasoning_tokens"])
|
||||
}
|
||||
|
||||
// Verify returned map is a copy
|
||||
details["cached_input_tokens"] = 999
|
||||
fresh := tracker.Details()
|
||||
if fresh["cached_input_tokens"] != 35 {
|
||||
t.Error("Details() did not return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTracker_Details_Nil(t *testing.T) {
|
||||
tracker := &UsageTracker{}
|
||||
tracker.Add(&Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15})
|
||||
|
||||
details := tracker.Details()
|
||||
if details != nil {
|
||||
t.Errorf("expected nil details for usage without details, got %v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithUsageTracking_WithDetails(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{
|
||||
Text: "ok",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
TotalTokens: 150,
|
||||
Details: map[string]int{
|
||||
"cached_input_tokens": 30,
|
||||
},
|
||||
},
|
||||
})
|
||||
tracker := &UsageTracker{}
|
||||
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
details := tracker.Details()
|
||||
if details["cached_input_tokens"] != 30 {
|
||||
t.Errorf("expected cached_input_tokens=30, got %d", details["cached_input_tokens"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,15 +58,30 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
|
||||
|
||||
cl := openai.NewClient(opts...)
|
||||
oaiReq := p.buildRequest(req)
|
||||
oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
|
||||
IncludeUsage: openai.Bool(true),
|
||||
}
|
||||
|
||||
stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)
|
||||
|
||||
var fullText strings.Builder
|
||||
var toolCalls []provider.ToolCall
|
||||
toolCallArgs := map[int]*strings.Builder{}
|
||||
var usage *provider.Usage
|
||||
|
||||
for stream.Next() {
|
||||
chunk := stream.Current()
|
||||
|
||||
// Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true)
|
||||
if chunk.Usage.TotalTokens > 0 {
|
||||
usage = &provider.Usage{
|
||||
InputTokens: int(chunk.Usage.PromptTokens),
|
||||
OutputTokens: int(chunk.Usage.CompletionTokens),
|
||||
TotalTokens: int(chunk.Usage.TotalTokens),
|
||||
Details: extractUsageDetails(chunk.Usage),
|
||||
}
|
||||
}
|
||||
|
||||
for _, choice := range chunk.Choices {
|
||||
// Text delta
|
||||
if choice.Delta.Content != "" {
|
||||
@@ -138,6 +153,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
|
||||
Response: &provider.Response{
|
||||
Text: fullText.String(),
|
||||
ToolCalls: toolCalls,
|
||||
Usage: usage,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -363,6 +379,7 @@ func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Respons
|
||||
OutputTokens: int(resp.Usage.CompletionTokens),
|
||||
TotalTokens: int(resp.Usage.TotalTokens),
|
||||
}
|
||||
res.Usage.Details = extractUsageDetails(resp.Usage)
|
||||
}
|
||||
|
||||
return res
|
||||
@@ -381,6 +398,27 @@ func audioFormat(contentType string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage.
|
||||
func extractUsageDetails(usage openai.CompletionUsage) map[string]int {
|
||||
details := map[string]int{}
|
||||
if usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens)
|
||||
}
|
||||
if usage.CompletionTokensDetails.AudioTokens > 0 {
|
||||
details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens)
|
||||
}
|
||||
if usage.PromptTokensDetails.CachedTokens > 0 {
|
||||
details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
if usage.PromptTokensDetails.AudioTokens > 0 {
|
||||
details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens)
|
||||
}
|
||||
if len(details) == 0 {
|
||||
return nil
|
||||
}
|
||||
return details
|
||||
}
|
||||
|
||||
// audioFormatFromURL guesses the audio format from a URL's file extension.
|
||||
func audioFormatFromURL(u string) string {
|
||||
ext := strings.ToLower(path.Ext(u))
|
||||
|
||||
83
v2/pricing.go
Normal file
83
v2/pricing.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package llm
|
||||
|
||||
import "sync"
|
||||
|
||||
// ModelPricing defines per-token pricing for a model.
// All prices are expressed in USD per single token
// (e.g. $3 per million tokens == 0.000003).
type ModelPricing struct {
	InputPricePerToken       float64 // USD per input token
	OutputPricePerToken      float64 // USD per output token
	CachedInputPricePerToken float64 // USD per cached input token (0 = same as input)
}
|
||||
|
||||
// Cost computes the total USD cost from a Usage.
|
||||
// When CachedInputPricePerToken is set and the usage includes cached_input_tokens,
|
||||
// those tokens are charged at the cached rate instead of the regular input rate.
|
||||
func (mp ModelPricing) Cost(u *Usage) float64 {
|
||||
if u == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
inputTokens := u.InputTokens
|
||||
cachedTokens := 0
|
||||
if u.Details != nil {
|
||||
cachedTokens = u.Details[UsageDetailCachedInputTokens]
|
||||
}
|
||||
|
||||
var cost float64
|
||||
|
||||
if mp.CachedInputPricePerToken > 0 && cachedTokens > 0 {
|
||||
regularInput := inputTokens - cachedTokens
|
||||
if regularInput < 0 {
|
||||
regularInput = 0
|
||||
}
|
||||
cost += float64(regularInput) * mp.InputPricePerToken
|
||||
cost += float64(cachedTokens) * mp.CachedInputPricePerToken
|
||||
} else {
|
||||
cost += float64(inputTokens) * mp.InputPricePerToken
|
||||
}
|
||||
|
||||
cost += float64(u.OutputTokens) * mp.OutputPricePerToken
|
||||
|
||||
return cost
|
||||
}
|
||||
|
||||
// PricingRegistry maps model names to their pricing.
// Callers populate it with the models and prices relevant to their use case.
// All methods lock mu, so a registry is safe for concurrent use.
type PricingRegistry struct {
	mu     sync.RWMutex // guards models
	models map[string]ModelPricing
}
|
||||
|
||||
// NewPricingRegistry creates an empty pricing registry.
|
||||
func NewPricingRegistry() *PricingRegistry {
|
||||
return &PricingRegistry{
|
||||
models: make(map[string]ModelPricing),
|
||||
}
|
||||
}
|
||||
|
||||
// Set registers pricing for a model.
|
||||
func (pr *PricingRegistry) Set(model string, pricing ModelPricing) {
|
||||
pr.mu.Lock()
|
||||
defer pr.mu.Unlock()
|
||||
pr.models[model] = pricing
|
||||
}
|
||||
|
||||
// Has returns true if pricing is registered for the given model.
|
||||
func (pr *PricingRegistry) Has(model string) bool {
|
||||
pr.mu.RLock()
|
||||
defer pr.mu.RUnlock()
|
||||
_, ok := pr.models[model]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Cost computes the USD cost for the given model and usage.
|
||||
// Returns 0 if the model is not registered.
|
||||
func (pr *PricingRegistry) Cost(model string, u *Usage) float64 {
|
||||
pr.mu.RLock()
|
||||
pricing, ok := pr.models[model]
|
||||
pr.mu.RUnlock()
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return pricing.Cost(u)
|
||||
}
|
||||
128
v2/pricing_test.go
Normal file
128
v2/pricing_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModelPricing_Cost(t *testing.T) {
|
||||
pricing := ModelPricing{
|
||||
InputPricePerToken: 0.000003, // $3/MTok
|
||||
OutputPricePerToken: 0.000015, // $15/MTok
|
||||
}
|
||||
|
||||
usage := &Usage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 500,
|
||||
TotalTokens: 1500,
|
||||
}
|
||||
|
||||
cost := pricing.Cost(usage)
|
||||
expected := 1000*0.000003 + 500*0.000015
|
||||
if math.Abs(cost-expected) > 1e-10 {
|
||||
t.Errorf("expected cost %f, got %f", expected, cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelPricing_Cost_WithCachedTokens(t *testing.T) {
|
||||
pricing := ModelPricing{
|
||||
InputPricePerToken: 0.000003, // $3/MTok
|
||||
OutputPricePerToken: 0.000015, // $15/MTok
|
||||
CachedInputPricePerToken: 0.0000015, // $1.50/MTok (50% discount)
|
||||
}
|
||||
|
||||
usage := &Usage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 500,
|
||||
TotalTokens: 1500,
|
||||
Details: map[string]int{
|
||||
UsageDetailCachedInputTokens: 400,
|
||||
},
|
||||
}
|
||||
|
||||
cost := pricing.Cost(usage)
|
||||
// 600 regular input tokens + 400 cached tokens + 500 output tokens
|
||||
expected := 600*0.000003 + 400*0.0000015 + 500*0.000015
|
||||
if math.Abs(cost-expected) > 1e-10 {
|
||||
t.Errorf("expected cost %f, got %f", expected, cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelPricing_Cost_NilUsage(t *testing.T) {
|
||||
pricing := ModelPricing{
|
||||
InputPricePerToken: 0.000003,
|
||||
OutputPricePerToken: 0.000015,
|
||||
}
|
||||
|
||||
cost := pricing.Cost(nil)
|
||||
if cost != 0 {
|
||||
t.Errorf("expected 0 for nil usage, got %f", cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelPricing_Cost_NoCachedPrice(t *testing.T) {
|
||||
// When CachedInputPricePerToken is 0, all input tokens use InputPricePerToken
|
||||
pricing := ModelPricing{
|
||||
InputPricePerToken: 0.000003,
|
||||
OutputPricePerToken: 0.000015,
|
||||
}
|
||||
|
||||
usage := &Usage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 500,
|
||||
TotalTokens: 1500,
|
||||
Details: map[string]int{
|
||||
UsageDetailCachedInputTokens: 400,
|
||||
},
|
||||
}
|
||||
|
||||
cost := pricing.Cost(usage)
|
||||
expected := 1000*0.000003 + 500*0.000015
|
||||
if math.Abs(cost-expected) > 1e-10 {
|
||||
t.Errorf("expected cost %f, got %f", expected, cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPricingRegistry(t *testing.T) {
|
||||
registry := NewPricingRegistry()
|
||||
|
||||
registry.Set("gpt-4o", ModelPricing{
|
||||
InputPricePerToken: 0.0000025,
|
||||
OutputPricePerToken: 0.00001,
|
||||
})
|
||||
|
||||
if !registry.Has("gpt-4o") {
|
||||
t.Error("expected Has('gpt-4o') to be true")
|
||||
}
|
||||
if registry.Has("gpt-3.5-turbo") {
|
||||
t.Error("expected Has('gpt-3.5-turbo') to be false")
|
||||
}
|
||||
|
||||
usage := &Usage{InputTokens: 1000, OutputTokens: 200, TotalTokens: 1200}
|
||||
|
||||
cost := registry.Cost("gpt-4o", usage)
|
||||
expected := 1000*0.0000025 + 200*0.00001
|
||||
if math.Abs(cost-expected) > 1e-10 {
|
||||
t.Errorf("expected cost %f, got %f", expected, cost)
|
||||
}
|
||||
|
||||
// Unknown model returns 0
|
||||
cost = registry.Cost("unknown-model", usage)
|
||||
if cost != 0 {
|
||||
t.Errorf("expected 0 for unknown model, got %f", cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPricingRegistry_Override(t *testing.T) {
|
||||
registry := NewPricingRegistry()
|
||||
|
||||
registry.Set("model-a", ModelPricing{InputPricePerToken: 0.001, OutputPricePerToken: 0.002})
|
||||
registry.Set("model-a", ModelPricing{InputPricePerToken: 0.003, OutputPricePerToken: 0.004})
|
||||
|
||||
usage := &Usage{InputTokens: 100, OutputTokens: 50, TotalTokens: 150}
|
||||
cost := registry.Cost("model-a", usage)
|
||||
expected := 100*0.003 + 50*0.004
|
||||
if math.Abs(cost-expected) > 1e-10 {
|
||||
t.Errorf("expected overridden cost %f, got %f", expected, cost)
|
||||
}
|
||||
}
|
||||
@@ -64,8 +64,19 @@ type Usage struct {
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
TotalTokens int
|
||||
Details map[string]int // provider-specific breakdown (e.g., cached, reasoning tokens)
|
||||
}
|
||||
|
||||
// Standardized detail keys for provider-specific token breakdowns.
// Providers populate Usage.Details with these keys so callers can read
// breakdowns without provider-specific switch logic.
const (
	UsageDetailReasoningTokens     = "reasoning_tokens"      // reasoning tokens in the output breakdown
	UsageDetailCachedInputTokens   = "cached_input_tokens"   // input tokens served from a prompt cache
	UsageDetailCacheCreationTokens = "cache_creation_tokens" // presumably tokens written into a prompt cache — confirm against the provider that emits it
	UsageDetailAudioInputTokens    = "audio_input_tokens"    // audio tokens counted against input
	UsageDetailAudioOutputTokens   = "audio_output_tokens"   // audio tokens counted against output
	UsageDetailThoughtsTokens      = "thoughts_tokens"       // Google-style "thoughts" token count
)
|
||||
|
||||
// StreamEventType identifies the kind of stream event.
|
||||
type StreamEventType int
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package llm
|
||||
|
||||
import "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
|
||||
// Response represents the result of a completion request.
|
||||
type Response struct {
|
||||
// Text is the assistant's text content. Empty if only tool calls.
|
||||
@@ -31,4 +33,45 @@ type Usage struct {
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
TotalTokens int
|
||||
Details map[string]int // provider-specific breakdown (e.g., cached, reasoning tokens)
|
||||
}
|
||||
|
||||
// addUsage merges usage u into the receiver, accumulating token counts and details.
|
||||
// If the receiver is nil, it returns a copy of u. If u is nil, it returns the receiver unchanged.
|
||||
func addUsage(total *Usage, u *Usage) *Usage {
|
||||
if u == nil {
|
||||
return total
|
||||
}
|
||||
if total == nil {
|
||||
cp := *u
|
||||
if u.Details != nil {
|
||||
cp.Details = make(map[string]int, len(u.Details))
|
||||
for k, v := range u.Details {
|
||||
cp.Details[k] = v
|
||||
}
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
total.InputTokens += u.InputTokens
|
||||
total.OutputTokens += u.OutputTokens
|
||||
total.TotalTokens += u.TotalTokens
|
||||
if u.Details != nil {
|
||||
if total.Details == nil {
|
||||
total.Details = make(map[string]int, len(u.Details))
|
||||
}
|
||||
for k, v := range u.Details {
|
||||
total.Details[k] += v
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// Re-export detail key constants from provider package for convenience.
// This lets users of the high-level llm package index Usage.Details
// without importing the provider package directly.
const (
	UsageDetailReasoningTokens     = provider.UsageDetailReasoningTokens
	UsageDetailCachedInputTokens   = provider.UsageDetailCachedInputTokens
	UsageDetailCacheCreationTokens = provider.UsageDetailCacheCreationTokens
	UsageDetailAudioInputTokens    = provider.UsageDetailAudioInputTokens
	UsageDetailAudioOutputTokens   = provider.UsageDetailAudioOutputTokens
	UsageDetailThoughtsTokens      = provider.UsageDetailThoughtsTokens
)
|
||||
|
||||
Reference in New Issue
Block a user