feat: comprehensive token usage tracking for V2 #3
@@ -18,7 +18,7 @@
|
|||||||
// coder.AsTool("code", "Write and run code"),
|
// coder.AsTool("code", "Write and run code"),
|
||||||
// )),
|
// )),
|
||||||
// )
|
// )
|
||||||
// result, err := orchestrator.Run(ctx, "Build a fibonacci function in Go")
|
// result, _, err := orchestrator.Run(ctx, "Build a fibonacci function in Go")
|
||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -64,13 +64,15 @@ func New(model *llm.Model, system string, opts ...Option) *Agent {
|
|||||||
|
|
||||||
// Run executes the agent with a user prompt. Each call is a fresh conversation.
|
// Run executes the agent with a user prompt. Each call is a fresh conversation.
|
||||||
// The agent loops tool calls automatically until it produces a text response.
|
// The agent loops tool calls automatically until it produces a text response.
|
||||||
func (a *Agent) Run(ctx context.Context, prompt string) (string, error) {
|
// Returns the text response, accumulated token usage, and any error.
|
||||||
|
func (a *Agent) Run(ctx context.Context, prompt string) (string, *llm.Usage, error) {
|
||||||
return a.RunMessages(ctx, []llm.Message{llm.UserMessage(prompt)})
|
return a.RunMessages(ctx, []llm.Message{llm.UserMessage(prompt)})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunMessages executes the agent with full message control.
|
// RunMessages executes the agent with full message control.
|
||||||
// Each call is a fresh conversation. The agent loops tool calls automatically.
|
// Each call is a fresh conversation. The agent loops tool calls automatically.
|
||||||
func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, error) {
|
// Returns the text response, accumulated token usage, and any error.
|
||||||
|
func (a *Agent) RunMessages(ctx context.Context, messages []llm.Message) (string, *llm.Usage, error) {
|
||||||
chat := llm.NewChat(a.model, a.reqOpts...)
|
chat := llm.NewChat(a.model, a.reqOpts...)
|
||||||
if a.system != "" {
|
if a.system != "" {
|
||||||
chat.SetSystem(a.system)
|
chat.SetSystem(a.system)
|
||||||
@@ -107,7 +109,8 @@ type delegateParams struct {
|
|||||||
func (a *Agent) AsTool(name, description string) llm.Tool {
|
func (a *Agent) AsTool(name, description string) llm.Tool {
|
||||||
return llm.Define[delegateParams](name, description,
|
return llm.Define[delegateParams](name, description,
|
||||||
func(ctx context.Context, p delegateParams) (string, error) {
|
func(ctx context.Context, p delegateParams) (string, error) {
|
||||||
return a.Run(ctx, p.Input)
|
text, _, err := a.Run(ctx, p.Input)
|
||||||
|
return text, err
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,15 +29,6 @@ func (m *mockProvider) Stream(ctx context.Context, req provider.Request, events
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockProvider) lastRequest() provider.Request {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
if len(m.requests) == 0 {
|
|
||||||
return provider.Request{}
|
|
||||||
}
|
|
||||||
return m.requests[len(m.requests)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockModel(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *llm.Model {
|
func newMockModel(fn func(ctx context.Context, req provider.Request) (provider.Response, error)) *llm.Model {
|
||||||
mp := &mockProvider{completeFunc: fn}
|
mp := &mockProvider{completeFunc: fn}
|
||||||
return llm.NewClient(mp).Model("mock-model")
|
return llm.NewClient(mp).Model("mock-model")
|
||||||
@@ -53,7 +44,7 @@ func TestAgent_Run(t *testing.T) {
|
|||||||
model := newSimpleMockModel("Hello from agent!")
|
model := newSimpleMockModel("Hello from agent!")
|
||||||
a := New(model, "You are a helpful assistant.")
|
a := New(model, "You are a helpful assistant.")
|
||||||
|
|
||||||
result, err := a.Run(context.Background(), "Say hello")
|
result, _, err := a.Run(context.Background(), "Say hello")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -83,7 +74,7 @@ func TestAgent_Run_WithTools(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
a := New(model, "You are helpful.", WithTools(llm.NewToolBox(tool)))
|
a := New(model, "You are helpful.", WithTools(llm.NewToolBox(tool)))
|
||||||
result, err := a.Run(context.Background(), "Use the greet tool")
|
result, _, err := a.Run(context.Background(), "Use the greet tool")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -147,7 +138,7 @@ func TestAgent_AsTool_ParentChild(t *testing.T) {
|
|||||||
)),
|
)),
|
||||||
)
|
)
|
||||||
|
|
||||||
result, err := parent.Run(context.Background(), "Tell me about Go generics")
|
result, _, err := parent.Run(context.Background(), "Tell me about Go generics")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -169,7 +160,7 @@ func TestAgent_RunMessages(t *testing.T) {
|
|||||||
llm.UserMessage("Follow up"),
|
llm.UserMessage("Follow up"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := a.RunMessages(context.Background(), messages)
|
result, _, err := a.RunMessages(context.Background(), messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -187,7 +178,7 @@ func TestAgent_ContextCancellation(t *testing.T) {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel() // Cancel immediately
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
_, err := a.Run(ctx, "This should fail")
|
_, _, err := a.Run(ctx, "This should fail")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error from cancelled context")
|
t.Fatal("expected error from cancelled context")
|
||||||
}
|
}
|
||||||
@@ -204,7 +195,7 @@ func TestAgent_WithRequestOptions(t *testing.T) {
|
|||||||
WithRequestOptions(llm.WithTemperature(0.3), llm.WithMaxTokens(100)),
|
WithRequestOptions(llm.WithTemperature(0.3), llm.WithMaxTokens(100)),
|
||||||
)
|
)
|
||||||
|
|
||||||
_, err := a.Run(context.Background(), "test")
|
_, _, err := a.Run(context.Background(), "test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -224,7 +215,7 @@ func TestAgent_Run_Error(t *testing.T) {
|
|||||||
})
|
})
|
||||||
a := New(model, "You are helpful.")
|
a := New(model, "You are helpful.")
|
||||||
|
|
||||||
_, err := a.Run(context.Background(), "test")
|
_, _, err := a.Run(context.Background(), "test")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -234,7 +225,7 @@ func TestAgent_EmptySystem(t *testing.T) {
|
|||||||
model := newSimpleMockModel("no system prompt")
|
model := newSimpleMockModel("no system prompt")
|
||||||
a := New(model, "") // Empty system prompt
|
a := New(model, "") // Empty system prompt
|
||||||
|
|
||||||
result, err := a.Run(context.Background(), "test")
|
result, _, err := a.Run(context.Background(), "test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -242,3 +233,34 @@ func TestAgent_EmptySystem(t *testing.T) {
|
|||||||
t.Errorf("unexpected result: %q", result)
|
t.Errorf("unexpected result: %q", result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAgent_Run_ReturnsUsage(t *testing.T) {
|
||||||
|
model := newMockModel(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
return provider.Response{
|
||||||
|
Text: "result",
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
TotalTokens: 150,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
a := New(model, "You are helpful.")
|
||||||
|
result, usage, err := a.Run(context.Background(), "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result != "result" {
|
||||||
|
t.Errorf("expected 'result', got %q", result)
|
||||||
|
}
|
||||||
|
if usage == nil {
|
||||||
|
t.Fatal("expected usage, got nil")
|
||||||
|
}
|
||||||
|
if usage.InputTokens != 100 {
|
||||||
|
t.Errorf("expected input 100, got %d", usage.InputTokens)
|
||||||
|
}
|
||||||
|
if usage.OutputTokens != 50 {
|
||||||
|
t.Errorf("expected output 50, got %d", usage.OutputTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func Example_researcher() {
|
|||||||
agent.WithRequestOptions(llm.WithTemperature(0.3)),
|
agent.WithRequestOptions(llm.WithTemperature(0.3)),
|
||||||
)
|
)
|
||||||
|
|
||||||
result, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?")
|
result, _, err := researcher.Run(context.Background(), "What are the latest developments in Go generics?")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error:", err)
|
fmt.Println("Error:", err)
|
||||||
return
|
return
|
||||||
@@ -50,7 +50,7 @@ func Example_coder() {
|
|||||||
)),
|
)),
|
||||||
)
|
)
|
||||||
|
|
||||||
result, err := coder.Run(context.Background(),
|
result, _, err := coder.Run(context.Background(),
|
||||||
"Create a Go program that prints the first 10 Fibonacci numbers. Save it and run it.")
|
"Create a Go program that prints the first 10 Fibonacci numbers. Save it and run it.")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error:", err)
|
fmt.Println("Error:", err)
|
||||||
@@ -97,7 +97,7 @@ func Example_orchestrator() {
|
|||||||
)),
|
)),
|
||||||
)
|
)
|
||||||
|
|
||||||
result, err := orchestrator.Run(context.Background(),
|
result, _, err := orchestrator.Run(context.Background(),
|
||||||
"Research how to implement a binary search tree in Go, then create one with insert and search operations.")
|
"Research how to implement a binary search tree in Go, then create one with insert and search operations.")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error:", err)
|
fmt.Println("Error:", err)
|
||||||
|
|||||||
@@ -270,6 +270,16 @@ func (p *Provider) convertResponse(resp anth.MessagesResponse) provider.Response
|
|||||||
OutputTokens: resp.Usage.OutputTokens,
|
OutputTokens: resp.Usage.OutputTokens,
|
||||||
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
|
||||||
}
|
}
|
||||||
|
details := map[string]int{}
|
||||||
|
if resp.Usage.CacheCreationInputTokens > 0 {
|
||||||
|
details[provider.UsageDetailCacheCreationTokens] = resp.Usage.CacheCreationInputTokens
|
||||||
|
}
|
||||||
|
if resp.Usage.CacheReadInputTokens > 0 {
|
||||||
|
details[provider.UsageDetailCachedInputTokens] = resp.Usage.CacheReadInputTokens
|
||||||
|
}
|
||||||
|
if len(details) > 0 {
|
||||||
|
res.Usage.Details = details
|
||||||
|
}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|||||||
24
v2/chat.go
24
v2/chat.go
@@ -38,44 +38,50 @@ func (c *Chat) SetTools(tb *ToolBox) {
|
|||||||
c.tools = tb
|
c.tools = tb
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send sends a user message and returns the assistant's text response.
|
// Send sends a user message and returns the assistant's text response along with
|
||||||
|
// accumulated token usage from all iterations of the tool-call loop.
|
||||||
// If the model calls tools, they are executed automatically and the loop
|
// If the model calls tools, they are executed automatically and the loop
|
||||||
// continues until the model produces a text response (the "agent loop").
|
// continues until the model produces a text response (the "agent loop").
|
||||||
func (c *Chat) Send(ctx context.Context, text string) (string, error) {
|
func (c *Chat) Send(ctx context.Context, text string) (string, *Usage, error) {
|
||||||
return c.SendMessage(ctx, UserMessage(text))
|
return c.SendMessage(ctx, UserMessage(text))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendWithImages sends a user message with images attached.
|
// SendWithImages sends a user message with images attached.
|
||||||
func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, error) {
|
func (c *Chat) SendWithImages(ctx context.Context, text string, images ...Image) (string, *Usage, error) {
|
||||||
return c.SendMessage(ctx, UserMessageWithImages(text, images...))
|
return c.SendMessage(ctx, UserMessageWithImages(text, images...))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessage sends an arbitrary message and returns the final text response.
|
// SendMessage sends an arbitrary message and returns the final text response along with
|
||||||
|
// accumulated token usage from all iterations of the tool-call loop.
|
||||||
// Handles the full tool-call loop automatically.
|
// Handles the full tool-call loop automatically.
|
||||||
func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, error) {
|
func (c *Chat) SendMessage(ctx context.Context, msg Message) (string, *Usage, error) {
|
||||||
c.messages = append(c.messages, msg)
|
c.messages = append(c.messages, msg)
|
||||||
|
|
||||||
opts := c.buildOpts()
|
opts := c.buildOpts()
|
||||||
|
|
||||||
|
var totalUsage *Usage
|
||||||
|
|
||||||
for {
|
for {
|
||||||
resp, err := c.model.Complete(ctx, c.messages, opts...)
|
resp, err := c.model.Complete(ctx, c.messages, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("completion failed: %w", err)
|
return "", totalUsage, fmt.Errorf("completion failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalUsage = addUsage(totalUsage, resp.Usage)
|
||||||
|
|
||||||
c.messages = append(c.messages, resp.Message())
|
c.messages = append(c.messages, resp.Message())
|
||||||
|
|
||||||
if !resp.HasToolCalls() {
|
if !resp.HasToolCalls() {
|
||||||
return resp.Text, nil
|
return resp.Text, totalUsage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.tools == nil {
|
if c.tools == nil {
|
||||||
return "", ErrNoToolsConfigured
|
return "", totalUsage, ErrNoToolsConfigured
|
||||||
}
|
}
|
||||||
|
|
||||||
toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls)
|
toolResults, err := c.tools.ExecuteAll(ctx, resp.ToolCalls)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("tool execution failed: %w", err)
|
return "", totalUsage, fmt.Errorf("tool execution failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.messages = append(c.messages, toolResults...)
|
c.messages = append(c.messages, toolResults...)
|
||||||
|
|||||||
132
v2/chat_test.go
132
v2/chat_test.go
@@ -14,7 +14,7 @@ func TestChat_Send(t *testing.T) {
|
|||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
chat := NewChat(model)
|
chat := NewChat(model)
|
||||||
|
|
||||||
text, err := chat.Send(context.Background(), "Hi")
|
text, _, err := chat.Send(context.Background(), "Hi")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -28,7 +28,7 @@ func TestChat_SendMessage(t *testing.T) {
|
|||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
chat := NewChat(model)
|
chat := NewChat(model)
|
||||||
|
|
||||||
_, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
|
_, _, err := chat.SendMessage(context.Background(), UserMessage("msg1"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -79,7 +79,7 @@ func TestChat_SetSystem(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// System message stays first even after adding other messages
|
// System message stays first even after adding other messages
|
||||||
_, _ = chat.Send(context.Background(), "Hi")
|
_, _, _ = chat.Send(context.Background(), "Hi")
|
||||||
chat.SetSystem("New system")
|
chat.SetSystem("New system")
|
||||||
msgs = chat.Messages()
|
msgs = chat.Messages()
|
||||||
if msgs[0].Role != RoleSystem {
|
if msgs[0].Role != RoleSystem {
|
||||||
@@ -113,7 +113,7 @@ func TestChat_ToolCallLoop(t *testing.T) {
|
|||||||
})
|
})
|
||||||
chat.SetTools(NewToolBox(tool))
|
chat.SetTools(NewToolBox(tool))
|
||||||
|
|
||||||
text, err := chat.Send(context.Background(), "test")
|
text, _, err := chat.Send(context.Background(), "test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -158,7 +158,7 @@ func TestChat_ToolCallLoop_NoTools(t *testing.T) {
|
|||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
chat := NewChat(model)
|
chat := NewChat(model)
|
||||||
|
|
||||||
_, err := chat.Send(context.Background(), "test")
|
_, _, err := chat.Send(context.Background(), "test")
|
||||||
if !errors.Is(err, ErrNoToolsConfigured) {
|
if !errors.Is(err, ErrNoToolsConfigured) {
|
||||||
t.Errorf("expected ErrNoToolsConfigured, got %v", err)
|
t.Errorf("expected ErrNoToolsConfigured, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -248,7 +248,7 @@ func TestChat_Messages(t *testing.T) {
|
|||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
chat := NewChat(model)
|
chat := NewChat(model)
|
||||||
|
|
||||||
_, _ = chat.Send(context.Background(), "test")
|
_, _, _ = chat.Send(context.Background(), "test")
|
||||||
|
|
||||||
msgs := chat.Messages()
|
msgs := chat.Messages()
|
||||||
// Verify it's a copy — modifying returned slice shouldn't affect chat
|
// Verify it's a copy — modifying returned slice shouldn't affect chat
|
||||||
@@ -265,7 +265,7 @@ func TestChat_Reset(t *testing.T) {
|
|||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
chat := NewChat(model)
|
chat := NewChat(model)
|
||||||
|
|
||||||
_, _ = chat.Send(context.Background(), "test")
|
_, _, _ = chat.Send(context.Background(), "test")
|
||||||
if len(chat.Messages()) == 0 {
|
if len(chat.Messages()) == 0 {
|
||||||
t.Fatal("expected messages before reset")
|
t.Fatal("expected messages before reset")
|
||||||
}
|
}
|
||||||
@@ -281,7 +281,7 @@ func TestChat_Fork(t *testing.T) {
|
|||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
chat := NewChat(model)
|
chat := NewChat(model)
|
||||||
|
|
||||||
_, _ = chat.Send(context.Background(), "msg1")
|
_, _, _ = chat.Send(context.Background(), "msg1")
|
||||||
|
|
||||||
fork := chat.Fork()
|
fork := chat.Fork()
|
||||||
|
|
||||||
@@ -291,14 +291,14 @@ func TestChat_Fork(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Adding to fork should not affect original
|
// Adding to fork should not affect original
|
||||||
_, _ = fork.Send(context.Background(), "msg2")
|
_, _, _ = fork.Send(context.Background(), "msg2")
|
||||||
if len(fork.Messages()) == len(chat.Messages()) {
|
if len(fork.Messages()) == len(chat.Messages()) {
|
||||||
t.Error("fork messages should be independent of original")
|
t.Error("fork messages should be independent of original")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adding to original should not affect fork
|
// Adding to original should not affect fork
|
||||||
originalLen := len(chat.Messages())
|
originalLen := len(chat.Messages())
|
||||||
_, _ = chat.Send(context.Background(), "msg3")
|
_, _, _ = chat.Send(context.Background(), "msg3")
|
||||||
if len(chat.Messages()) == originalLen {
|
if len(chat.Messages()) == originalLen {
|
||||||
t.Error("original should have more messages after send")
|
t.Error("original should have more messages after send")
|
||||||
}
|
}
|
||||||
@@ -310,7 +310,7 @@ func TestChat_SendWithImages(t *testing.T) {
|
|||||||
chat := NewChat(model)
|
chat := NewChat(model)
|
||||||
|
|
||||||
img := Image{URL: "https://example.com/image.png"}
|
img := Image{URL: "https://example.com/image.png"}
|
||||||
text, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
|
text, _, err := chat.SendWithImages(context.Background(), "What's in this image?", img)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -355,7 +355,7 @@ func TestChat_MultipleToolCallRounds(t *testing.T) {
|
|||||||
})
|
})
|
||||||
chat.SetTools(NewToolBox(tool))
|
chat.SetTools(NewToolBox(tool))
|
||||||
|
|
||||||
text, err := chat.Send(context.Background(), "count three times")
|
text, _, err := chat.Send(context.Background(), "count three times")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -378,7 +378,7 @@ func TestChat_SendError(t *testing.T) {
|
|||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
chat := NewChat(model)
|
chat := NewChat(model)
|
||||||
|
|
||||||
_, err := chat.Send(context.Background(), "test")
|
_, _, err := chat.Send(context.Background(), "test")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -392,7 +392,7 @@ func TestChat_WithRequestOptions(t *testing.T) {
|
|||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))
|
chat := NewChat(model, WithTemperature(0.5), WithMaxTokens(200))
|
||||||
|
|
||||||
_, err := chat.Send(context.Background(), "test")
|
_, _, err := chat.Send(context.Background(), "test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -405,3 +405,107 @@ func TestChat_WithRequestOptions(t *testing.T) {
|
|||||||
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
|
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChat_Send_UsageAccumulation(t *testing.T) {
|
||||||
|
var callCount int32
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
n := atomic.AddInt32(&callCount, 1)
|
||||||
|
if n == 1 {
|
||||||
|
return provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "greet", Arguments: "{}"},
|
||||||
|
},
|
||||||
|
Usage: &provider.Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return provider.Response{
|
||||||
|
Text: "done",
|
||||||
|
Usage: &provider.Usage{InputTokens: 20, OutputTokens: 8, TotalTokens: 28},
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||||
|
return "hello!", nil
|
||||||
|
})
|
||||||
|
chat.SetTools(NewToolBox(tool))
|
||||||
|
|
||||||
|
text, usage, err := chat.Send(context.Background(), "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if text != "done" {
|
||||||
|
t.Errorf("expected 'done', got %q", text)
|
||||||
|
}
|
||||||
|
if usage == nil {
|
||||||
|
t.Fatal("expected usage, got nil")
|
||||||
|
}
|
||||||
|
if usage.InputTokens != 30 {
|
||||||
|
t.Errorf("expected accumulated input 30, got %d", usage.InputTokens)
|
||||||
|
}
|
||||||
|
if usage.OutputTokens != 13 {
|
||||||
|
t.Errorf("expected accumulated output 13, got %d", usage.OutputTokens)
|
||||||
|
}
|
||||||
|
if usage.TotalTokens != 43 {
|
||||||
|
t.Errorf("expected accumulated total 43, got %d", usage.TotalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChat_Send_UsageWithDetails(t *testing.T) {
|
||||||
|
var callCount int32
|
||||||
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||||
|
n := atomic.AddInt32(&callCount, 1)
|
||||||
|
if n == 1 {
|
||||||
|
return provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{ID: "tc1", Name: "greet", Arguments: "{}"},
|
||||||
|
},
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalTokens: 15,
|
||||||
|
Details: map[string]int{
|
||||||
|
"cached_input_tokens": 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return provider.Response{
|
||||||
|
Text: "done",
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 20,
|
||||||
|
OutputTokens: 8,
|
||||||
|
TotalTokens: 28,
|
||||||
|
Details: map[string]int{
|
||||||
|
"cached_input_tokens": 7,
|
||||||
|
"reasoning_tokens": 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
chat := NewChat(model)
|
||||||
|
|
||||||
|
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||||
|
return "hello!", nil
|
||||||
|
})
|
||||||
|
chat.SetTools(NewToolBox(tool))
|
||||||
|
|
||||||
|
_, usage, err := chat.Send(context.Background(), "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if usage == nil {
|
||||||
|
t.Fatal("expected usage, got nil")
|
||||||
|
}
|
||||||
|
if usage.Details == nil {
|
||||||
|
t.Fatal("expected usage details, got nil")
|
||||||
|
}
|
||||||
|
if usage.Details["cached_input_tokens"] != 10 {
|
||||||
|
t.Errorf("expected cached_input_tokens=10, got %d", usage.Details["cached_input_tokens"])
|
||||||
|
}
|
||||||
|
if usage.Details["reasoning_tokens"] != 2 {
|
||||||
|
t.Errorf("expected reasoning_tokens=2, got %d", usage.Details["reasoning_tokens"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,14 +13,16 @@ const structuredOutputToolName = "structured_output"
|
|||||||
// Generate sends a single user prompt to the model and parses the response into T.
|
// Generate sends a single user prompt to the model and parses the response into T.
|
||||||
// T must be a struct. The model is forced to return structured output matching T's schema
|
// T must be a struct. The model is forced to return structured output matching T's schema
|
||||||
// by using a hidden tool call internally.
|
// by using a hidden tool call internally.
|
||||||
func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, error) {
|
// Returns the parsed value, token usage, and any error.
|
||||||
|
func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, *Usage, error) {
|
||||||
return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...)
|
return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateWith sends the given messages to the model and parses the response into T.
|
// GenerateWith sends the given messages to the model and parses the response into T.
|
||||||
// T must be a struct. The model is forced to return structured output matching T's schema
|
// T must be a struct. The model is forced to return structured output matching T's schema
|
||||||
// by using a hidden tool call internally.
|
// by using a hidden tool call internally.
|
||||||
func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, error) {
|
// Returns the parsed value, token usage, and any error.
|
||||||
|
func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, *Usage, error) {
|
||||||
var zero T
|
var zero T
|
||||||
|
|
||||||
s := schema.FromStruct(zero)
|
s := schema.FromStruct(zero)
|
||||||
@@ -36,7 +38,7 @@ func GenerateWith[T any](ctx context.Context, model *Model, messages []Message,
|
|||||||
|
|
||||||
resp, err := model.Complete(ctx, messages, opts...)
|
resp, err := model.Complete(ctx, messages, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return zero, err
|
return zero, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the structured_output tool call in the response.
|
// Find the structured_output tool call in the response.
|
||||||
@@ -44,11 +46,11 @@ func GenerateWith[T any](ctx context.Context, model *Model, messages []Message,
|
|||||||
if tc.Name == structuredOutputToolName {
|
if tc.Name == structuredOutputToolName {
|
||||||
var result T
|
var result T
|
||||||
if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil {
|
if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil {
|
||||||
return zero, fmt.Errorf("failed to parse structured output: %w", err)
|
return zero, resp.Usage, fmt.Errorf("failed to parse structured output: %w", err)
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, resp.Usage, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return zero, ErrNoStructuredOutput
|
return zero, resp.Usage, ErrNoStructuredOutput
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func TestGenerate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
|
|
||||||
result, err := Generate[testPerson](context.Background(), model, "Tell me about Alice")
|
result, _, err := Generate[testPerson](context.Background(), model, "Tell me about Alice")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -63,7 +63,7 @@ func TestGenerateWith(t *testing.T) {
|
|||||||
UserMessage("Tell me about Bob"),
|
UserMessage("Tell me about Bob"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := GenerateWith[testPerson](context.Background(), model, messages)
|
result, _, err := GenerateWith[testPerson](context.Background(), model, messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -90,7 +90,7 @@ func TestGenerate_NoToolCall(t *testing.T) {
|
|||||||
})
|
})
|
||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
|
|
||||||
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
_, _, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -111,7 +111,7 @@ func TestGenerate_InvalidJSON(t *testing.T) {
|
|||||||
})
|
})
|
||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
|
|
||||||
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
_, _, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -143,7 +143,7 @@ func TestGenerate_NestedStruct(t *testing.T) {
|
|||||||
})
|
})
|
||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
|
|
||||||
result, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol")
|
result, _, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -170,7 +170,7 @@ func TestGenerate_WithOptions(t *testing.T) {
|
|||||||
})
|
})
|
||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
|
|
||||||
_, err := Generate[testPerson](context.Background(), model, "Tell me about Dave",
|
_, _, err := Generate[testPerson](context.Background(), model, "Tell me about Dave",
|
||||||
WithTemperature(0.5),
|
WithTemperature(0.5),
|
||||||
WithMaxTokens(200),
|
WithMaxTokens(200),
|
||||||
)
|
)
|
||||||
@@ -207,7 +207,7 @@ func TestGenerate_WithMiddleware(t *testing.T) {
|
|||||||
})
|
})
|
||||||
model := newMockModel(mp).WithMiddleware(mw)
|
model := newMockModel(mp).WithMiddleware(mw)
|
||||||
|
|
||||||
result, err := Generate[testPerson](context.Background(), model, "Tell me about Eve")
|
result, _, err := Generate[testPerson](context.Background(), model, "Tell me about Eve")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -231,7 +231,7 @@ func TestGenerate_WrongToolName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
model := newMockModel(mp)
|
model := newMockModel(mp)
|
||||||
|
|
||||||
_, err := Generate[testPerson](context.Background(), model, "Tell me about Frank")
|
_, _, err := Generate[testPerson](context.Background(), model, "Tell me about Frank")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -239,3 +239,44 @@ func TestGenerate_WrongToolName(t *testing.T) {
|
|||||||
t.Errorf("expected ErrNoStructuredOutput, got %v", err)
|
t.Errorf("expected ErrNoStructuredOutput, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerate_ReturnsUsage(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Grace","age":22}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 50,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalTokens: 70,
|
||||||
|
Details: map[string]int{
|
||||||
|
"reasoning_tokens": 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
result, usage, err := Generate[testPerson](context.Background(), model, "Tell me about Grace")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != "Grace" {
|
||||||
|
t.Errorf("expected name 'Grace', got %q", result.Name)
|
||||||
|
}
|
||||||
|
if usage == nil {
|
||||||
|
t.Fatal("expected usage, got nil")
|
||||||
|
}
|
||||||
|
if usage.InputTokens != 50 {
|
||||||
|
t.Errorf("expected input 50, got %d", usage.InputTokens)
|
||||||
|
}
|
||||||
|
if usage.OutputTokens != 20 {
|
||||||
|
t.Errorf("expected output 20, got %d", usage.OutputTokens)
|
||||||
|
}
|
||||||
|
if usage.Details["reasoning_tokens"] != 5 {
|
||||||
|
t.Errorf("expected reasoning_tokens=5, got %d", usage.Details["reasoning_tokens"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -59,12 +59,32 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
|
|||||||
|
|
||||||
var fullText strings.Builder
|
var fullText strings.Builder
|
||||||
var toolCalls []provider.ToolCall
|
var toolCalls []provider.ToolCall
|
||||||
|
var usage *provider.Usage
|
||||||
|
|
||||||
for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
|
for resp, err := range cl.Models.GenerateContentStream(ctx, req.Model, contents, cfg) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("google stream error: %w", err)
|
return fmt.Errorf("google stream error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Track usage from the last chunk (final chunk has cumulative counts)
|
||||||
|
if resp.UsageMetadata != nil {
|
||||||
|
usage = &provider.Usage{
|
||||||
|
InputTokens: int(resp.UsageMetadata.PromptTokenCount),
|
||||||
|
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
||||||
|
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
||||||
|
}
|
||||||
|
details := map[string]int{}
|
||||||
|
if resp.UsageMetadata.CachedContentTokenCount > 0 {
|
||||||
|
details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount)
|
||||||
|
}
|
||||||
|
if resp.UsageMetadata.ThoughtsTokenCount > 0 {
|
||||||
|
details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount)
|
||||||
|
}
|
||||||
|
if len(details) > 0 {
|
||||||
|
usage.Details = details
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, c := range resp.Candidates {
|
for _, c := range resp.Candidates {
|
||||||
if c.Content == nil {
|
if c.Content == nil {
|
||||||
continue
|
continue
|
||||||
@@ -105,6 +125,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
|
|||||||
Response: &provider.Response{
|
Response: &provider.Response{
|
||||||
Text: fullText.String(),
|
Text: fullText.String(),
|
||||||
ToolCalls: toolCalls,
|
ToolCalls: toolCalls,
|
||||||
|
Usage: usage,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,6 +305,16 @@ func (p *Provider) convertResponse(resp *genai.GenerateContentResponse) (provide
|
|||||||
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
OutputTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
||||||
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
||||||
}
|
}
|
||||||
|
details := map[string]int{}
|
||||||
|
if resp.UsageMetadata.CachedContentTokenCount > 0 {
|
||||||
|
details[provider.UsageDetailCachedInputTokens] = int(resp.UsageMetadata.CachedContentTokenCount)
|
||||||
|
}
|
||||||
|
if resp.UsageMetadata.ThoughtsTokenCount > 0 {
|
||||||
|
details[provider.UsageDetailThoughtsTokens] = int(resp.UsageMetadata.ThoughtsTokenCount)
|
||||||
|
}
|
||||||
|
if len(details) > 0 {
|
||||||
|
res.Usage.Details = details
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
|
|||||||
@@ -177,6 +177,7 @@ func convertProviderResponse(resp provider.Response) Response {
|
|||||||
InputTokens: resp.Usage.InputTokens,
|
InputTokens: resp.Usage.InputTokens,
|
||||||
OutputTokens: resp.Usage.OutputTokens,
|
OutputTokens: resp.Usage.OutputTokens,
|
||||||
TotalTokens: resp.Usage.TotalTokens,
|
TotalTokens: resp.Usage.TotalTokens,
|
||||||
|
Details: resp.Usage.Details,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ type UsageTracker struct {
|
|||||||
TotalInput int64
|
TotalInput int64
|
||||||
TotalOutput int64
|
TotalOutput int64
|
||||||
TotalRequests int64
|
TotalRequests int64
|
||||||
|
TotalDetails map[string]int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add records usage from a single request.
|
// Add records usage from a single request.
|
||||||
@@ -94,6 +95,14 @@ func (ut *UsageTracker) Add(u *Usage) {
|
|||||||
ut.TotalInput += int64(u.InputTokens)
|
ut.TotalInput += int64(u.InputTokens)
|
||||||
ut.TotalOutput += int64(u.OutputTokens)
|
ut.TotalOutput += int64(u.OutputTokens)
|
||||||
ut.TotalRequests++
|
ut.TotalRequests++
|
||||||
|
if len(u.Details) > 0 {
|
||||||
|
if ut.TotalDetails == nil {
|
||||||
|
ut.TotalDetails = make(map[string]int64)
|
||||||
|
}
|
||||||
|
for k, v := range u.Details {
|
||||||
|
ut.TotalDetails[k] += int64(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Summary returns the accumulated totals.
|
// Summary returns the accumulated totals.
|
||||||
@@ -103,6 +112,20 @@ func (ut *UsageTracker) Summary() (input, output, requests int64) {
|
|||||||
return ut.TotalInput, ut.TotalOutput, ut.TotalRequests
|
return ut.TotalInput, ut.TotalOutput, ut.TotalRequests
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Details returns a copy of the accumulated detail totals.
|
||||||
|
func (ut *UsageTracker) Details() map[string]int64 {
|
||||||
|
ut.mu.Lock()
|
||||||
|
defer ut.mu.Unlock()
|
||||||
|
if ut.TotalDetails == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cp := make(map[string]int64, len(ut.TotalDetails))
|
||||||
|
for k, v := range ut.TotalDetails {
|
||||||
|
cp[k] = v
|
||||||
|
}
|
||||||
|
return cp
|
||||||
|
}
|
||||||
|
|
||||||
// WithUsageTracking returns middleware that accumulates token usage across calls.
|
// WithUsageTracking returns middleware that accumulates token usage across calls.
|
||||||
func WithUsageTracking(tracker *UsageTracker) Middleware {
|
func WithUsageTracking(tracker *UsageTracker) Middleware {
|
||||||
return func(next CompletionFunc) CompletionFunc {
|
return func(next CompletionFunc) CompletionFunc {
|
||||||
|
|||||||
@@ -280,3 +280,80 @@ func TestWithLogging_Error(t *testing.T) {
|
|||||||
t.Errorf("expected provider error, got %v", err)
|
t.Errorf("expected provider error, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageTracker_Details(t *testing.T) {
|
||||||
|
tracker := &UsageTracker{}
|
||||||
|
|
||||||
|
tracker.Add(&Usage{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
TotalTokens: 150,
|
||||||
|
Details: map[string]int{
|
||||||
|
"cached_input_tokens": 20,
|
||||||
|
"reasoning_tokens": 10,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
tracker.Add(&Usage{
|
||||||
|
InputTokens: 80,
|
||||||
|
OutputTokens: 40,
|
||||||
|
TotalTokens: 120,
|
||||||
|
Details: map[string]int{
|
||||||
|
"cached_input_tokens": 15,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
details := tracker.Details()
|
||||||
|
if details == nil {
|
||||||
|
t.Fatal("expected details, got nil")
|
||||||
|
}
|
||||||
|
if details["cached_input_tokens"] != 35 {
|
||||||
|
t.Errorf("expected cached_input_tokens=35, got %d", details["cached_input_tokens"])
|
||||||
|
}
|
||||||
|
if details["reasoning_tokens"] != 10 {
|
||||||
|
t.Errorf("expected reasoning_tokens=10, got %d", details["reasoning_tokens"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify returned map is a copy
|
||||||
|
details["cached_input_tokens"] = 999
|
||||||
|
fresh := tracker.Details()
|
||||||
|
if fresh["cached_input_tokens"] != 35 {
|
||||||
|
t.Error("Details() did not return a copy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageTracker_Details_Nil(t *testing.T) {
|
||||||
|
tracker := &UsageTracker{}
|
||||||
|
tracker.Add(&Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15})
|
||||||
|
|
||||||
|
details := tracker.Details()
|
||||||
|
if details != nil {
|
||||||
|
t.Errorf("expected nil details for usage without details, got %v", details)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithUsageTracking_WithDetails(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
Text: "ok",
|
||||||
|
Usage: &provider.Usage{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
TotalTokens: 150,
|
||||||
|
Details: map[string]int{
|
||||||
|
"cached_input_tokens": 30,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
tracker := &UsageTracker{}
|
||||||
|
model := newMockModel(mp).WithMiddleware(WithUsageTracking(tracker))
|
||||||
|
|
||||||
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
details := tracker.Details()
|
||||||
|
if details["cached_input_tokens"] != 30 {
|
||||||
|
t.Errorf("expected cached_input_tokens=30, got %d", details["cached_input_tokens"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -58,15 +58,30 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
|
|||||||
|
|
||||||
cl := openai.NewClient(opts...)
|
cl := openai.NewClient(opts...)
|
||||||
oaiReq := p.buildRequest(req)
|
oaiReq := p.buildRequest(req)
|
||||||
|
oaiReq.StreamOptions = openai.ChatCompletionStreamOptionsParam{
|
||||||
|
IncludeUsage: openai.Bool(true),
|
||||||
|
}
|
||||||
|
|
||||||
stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)
|
stream := cl.Chat.Completions.NewStreaming(ctx, oaiReq)
|
||||||
|
|
||||||
var fullText strings.Builder
|
var fullText strings.Builder
|
||||||
var toolCalls []provider.ToolCall
|
var toolCalls []provider.ToolCall
|
||||||
toolCallArgs := map[int]*strings.Builder{}
|
toolCallArgs := map[int]*strings.Builder{}
|
||||||
|
var usage *provider.Usage
|
||||||
|
|
||||||
for stream.Next() {
|
for stream.Next() {
|
||||||
chunk := stream.Current()
|
chunk := stream.Current()
|
||||||
|
|
||||||
|
// Capture usage from the final chunk (present when StreamOptions.IncludeUsage is true)
|
||||||
|
if chunk.Usage.TotalTokens > 0 {
|
||||||
|
usage = &provider.Usage{
|
||||||
|
InputTokens: int(chunk.Usage.PromptTokens),
|
||||||
|
OutputTokens: int(chunk.Usage.CompletionTokens),
|
||||||
|
TotalTokens: int(chunk.Usage.TotalTokens),
|
||||||
|
Details: extractUsageDetails(chunk.Usage),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, choice := range chunk.Choices {
|
for _, choice := range chunk.Choices {
|
||||||
// Text delta
|
// Text delta
|
||||||
if choice.Delta.Content != "" {
|
if choice.Delta.Content != "" {
|
||||||
@@ -138,6 +153,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request, events chan
|
|||||||
Response: &provider.Response{
|
Response: &provider.Response{
|
||||||
Text: fullText.String(),
|
Text: fullText.String(),
|
||||||
ToolCalls: toolCalls,
|
ToolCalls: toolCalls,
|
||||||
|
Usage: usage,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,6 +379,7 @@ func (p *Provider) convertResponse(resp *openai.ChatCompletion) provider.Respons
|
|||||||
OutputTokens: int(resp.Usage.CompletionTokens),
|
OutputTokens: int(resp.Usage.CompletionTokens),
|
||||||
TotalTokens: int(resp.Usage.TotalTokens),
|
TotalTokens: int(resp.Usage.TotalTokens),
|
||||||
}
|
}
|
||||||
|
res.Usage.Details = extractUsageDetails(resp.Usage)
|
||||||
}
|
}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
@@ -381,6 +398,27 @@ func audioFormat(contentType string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractUsageDetails extracts provider-specific detail tokens from an OpenAI CompletionUsage.
|
||||||
|
func extractUsageDetails(usage openai.CompletionUsage) map[string]int {
|
||||||
|
details := map[string]int{}
|
||||||
|
if usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||||
|
details[provider.UsageDetailReasoningTokens] = int(usage.CompletionTokensDetails.ReasoningTokens)
|
||||||
|
}
|
||||||
|
if usage.CompletionTokensDetails.AudioTokens > 0 {
|
||||||
|
details[provider.UsageDetailAudioOutputTokens] = int(usage.CompletionTokensDetails.AudioTokens)
|
||||||
|
}
|
||||||
|
if usage.PromptTokensDetails.CachedTokens > 0 {
|
||||||
|
details[provider.UsageDetailCachedInputTokens] = int(usage.PromptTokensDetails.CachedTokens)
|
||||||
|
}
|
||||||
|
if usage.PromptTokensDetails.AudioTokens > 0 {
|
||||||
|
details[provider.UsageDetailAudioInputTokens] = int(usage.PromptTokensDetails.AudioTokens)
|
||||||
|
}
|
||||||
|
if len(details) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return details
|
||||||
|
}
|
||||||
|
|
||||||
// audioFormatFromURL guesses the audio format from a URL's file extension.
|
// audioFormatFromURL guesses the audio format from a URL's file extension.
|
||||||
func audioFormatFromURL(u string) string {
|
func audioFormatFromURL(u string) string {
|
||||||
ext := strings.ToLower(path.Ext(u))
|
ext := strings.ToLower(path.Ext(u))
|
||||||
|
|||||||
83
v2/pricing.go
Normal file
83
v2/pricing.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// ModelPricing holds the per-token USD prices used to cost one model's usage.
type ModelPricing struct {
	// InputPricePerToken is the USD price of a single input token.
	InputPricePerToken float64

	// OutputPricePerToken is the USD price of a single output token.
	OutputPricePerToken float64

	// CachedInputPricePerToken is the USD price of a single cached input
	// token. Zero means cached tokens are billed at InputPricePerToken.
	CachedInputPricePerToken float64
}
|
||||||
|
|
||||||
|
// Cost computes the total USD cost from a Usage.
|
||||||
|
// When CachedInputPricePerToken is set and the usage includes cached_input_tokens,
|
||||||
|
// those tokens are charged at the cached rate instead of the regular input rate.
|
||||||
|
func (mp ModelPricing) Cost(u *Usage) float64 {
|
||||||
|
if u == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
inputTokens := u.InputTokens
|
||||||
|
cachedTokens := 0
|
||||||
|
if u.Details != nil {
|
||||||
|
cachedTokens = u.Details[UsageDetailCachedInputTokens]
|
||||||
|
}
|
||||||
|
|
||||||
|
var cost float64
|
||||||
|
|
||||||
|
if mp.CachedInputPricePerToken > 0 && cachedTokens > 0 {
|
||||||
|
regularInput := inputTokens - cachedTokens
|
||||||
|
if regularInput < 0 {
|
||||||
|
regularInput = 0
|
||||||
|
}
|
||||||
|
cost += float64(regularInput) * mp.InputPricePerToken
|
||||||
|
cost += float64(cachedTokens) * mp.CachedInputPricePerToken
|
||||||
|
} else {
|
||||||
|
cost += float64(inputTokens) * mp.InputPricePerToken
|
||||||
|
}
|
||||||
|
|
||||||
|
cost += float64(u.OutputTokens) * mp.OutputPricePerToken
|
||||||
|
|
||||||
|
return cost
|
||||||
|
}
|
||||||
|
|
||||||
|
// PricingRegistry maps model names to their pricing.
|
||||||
|
// Callers populate it with the models and prices relevant to their use case.
|
||||||
|
type PricingRegistry struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
models map[string]ModelPricing
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPricingRegistry creates an empty pricing registry.
|
||||||
|
func NewPricingRegistry() *PricingRegistry {
|
||||||
|
return &PricingRegistry{
|
||||||
|
models: make(map[string]ModelPricing),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set registers pricing for a model.
|
||||||
|
func (pr *PricingRegistry) Set(model string, pricing ModelPricing) {
|
||||||
|
pr.mu.Lock()
|
||||||
|
defer pr.mu.Unlock()
|
||||||
|
pr.models[model] = pricing
|
||||||
|
}
|
||||||
|
|
||||||
|
// Has returns true if pricing is registered for the given model.
|
||||||
|
func (pr *PricingRegistry) Has(model string) bool {
|
||||||
|
pr.mu.RLock()
|
||||||
|
defer pr.mu.RUnlock()
|
||||||
|
_, ok := pr.models[model]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cost computes the USD cost for the given model and usage.
|
||||||
|
// Returns 0 if the model is not registered.
|
||||||
|
func (pr *PricingRegistry) Cost(model string, u *Usage) float64 {
|
||||||
|
pr.mu.RLock()
|
||||||
|
pricing, ok := pr.models[model]
|
||||||
|
pr.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return pricing.Cost(u)
|
||||||
|
}
|
||||||
128
v2/pricing_test.go
Normal file
128
v2/pricing_test.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModelPricing_Cost(t *testing.T) {
|
||||||
|
pricing := ModelPricing{
|
||||||
|
InputPricePerToken: 0.000003, // $3/MTok
|
||||||
|
OutputPricePerToken: 0.000015, // $15/MTok
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &Usage{
|
||||||
|
InputTokens: 1000,
|
||||||
|
OutputTokens: 500,
|
||||||
|
TotalTokens: 1500,
|
||||||
|
}
|
||||||
|
|
||||||
|
cost := pricing.Cost(usage)
|
||||||
|
expected := 1000*0.000003 + 500*0.000015
|
||||||
|
if math.Abs(cost-expected) > 1e-10 {
|
||||||
|
t.Errorf("expected cost %f, got %f", expected, cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelPricing_Cost_WithCachedTokens(t *testing.T) {
|
||||||
|
pricing := ModelPricing{
|
||||||
|
InputPricePerToken: 0.000003, // $3/MTok
|
||||||
|
OutputPricePerToken: 0.000015, // $15/MTok
|
||||||
|
CachedInputPricePerToken: 0.0000015, // $1.50/MTok (50% discount)
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &Usage{
|
||||||
|
InputTokens: 1000,
|
||||||
|
OutputTokens: 500,
|
||||||
|
TotalTokens: 1500,
|
||||||
|
Details: map[string]int{
|
||||||
|
UsageDetailCachedInputTokens: 400,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cost := pricing.Cost(usage)
|
||||||
|
// 600 regular input tokens + 400 cached tokens + 500 output tokens
|
||||||
|
expected := 600*0.000003 + 400*0.0000015 + 500*0.000015
|
||||||
|
if math.Abs(cost-expected) > 1e-10 {
|
||||||
|
t.Errorf("expected cost %f, got %f", expected, cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelPricing_Cost_NilUsage(t *testing.T) {
|
||||||
|
pricing := ModelPricing{
|
||||||
|
InputPricePerToken: 0.000003,
|
||||||
|
OutputPricePerToken: 0.000015,
|
||||||
|
}
|
||||||
|
|
||||||
|
cost := pricing.Cost(nil)
|
||||||
|
if cost != 0 {
|
||||||
|
t.Errorf("expected 0 for nil usage, got %f", cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelPricing_Cost_NoCachedPrice(t *testing.T) {
|
||||||
|
// When CachedInputPricePerToken is 0, all input tokens use InputPricePerToken
|
||||||
|
pricing := ModelPricing{
|
||||||
|
InputPricePerToken: 0.000003,
|
||||||
|
OutputPricePerToken: 0.000015,
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &Usage{
|
||||||
|
InputTokens: 1000,
|
||||||
|
OutputTokens: 500,
|
||||||
|
TotalTokens: 1500,
|
||||||
|
Details: map[string]int{
|
||||||
|
UsageDetailCachedInputTokens: 400,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cost := pricing.Cost(usage)
|
||||||
|
expected := 1000*0.000003 + 500*0.000015
|
||||||
|
if math.Abs(cost-expected) > 1e-10 {
|
||||||
|
t.Errorf("expected cost %f, got %f", expected, cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPricingRegistry(t *testing.T) {
|
||||||
|
registry := NewPricingRegistry()
|
||||||
|
|
||||||
|
registry.Set("gpt-4o", ModelPricing{
|
||||||
|
InputPricePerToken: 0.0000025,
|
||||||
|
OutputPricePerToken: 0.00001,
|
||||||
|
})
|
||||||
|
|
||||||
|
if !registry.Has("gpt-4o") {
|
||||||
|
t.Error("expected Has('gpt-4o') to be true")
|
||||||
|
}
|
||||||
|
if registry.Has("gpt-3.5-turbo") {
|
||||||
|
t.Error("expected Has('gpt-3.5-turbo') to be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &Usage{InputTokens: 1000, OutputTokens: 200, TotalTokens: 1200}
|
||||||
|
|
||||||
|
cost := registry.Cost("gpt-4o", usage)
|
||||||
|
expected := 1000*0.0000025 + 200*0.00001
|
||||||
|
if math.Abs(cost-expected) > 1e-10 {
|
||||||
|
t.Errorf("expected cost %f, got %f", expected, cost)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown model returns 0
|
||||||
|
cost = registry.Cost("unknown-model", usage)
|
||||||
|
if cost != 0 {
|
||||||
|
t.Errorf("expected 0 for unknown model, got %f", cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPricingRegistry_Override(t *testing.T) {
|
||||||
|
registry := NewPricingRegistry()
|
||||||
|
|
||||||
|
registry.Set("model-a", ModelPricing{InputPricePerToken: 0.001, OutputPricePerToken: 0.002})
|
||||||
|
registry.Set("model-a", ModelPricing{InputPricePerToken: 0.003, OutputPricePerToken: 0.004})
|
||||||
|
|
||||||
|
usage := &Usage{InputTokens: 100, OutputTokens: 50, TotalTokens: 150}
|
||||||
|
cost := registry.Cost("model-a", usage)
|
||||||
|
expected := 100*0.003 + 50*0.004
|
||||||
|
if math.Abs(cost-expected) > 1e-10 {
|
||||||
|
t.Errorf("expected overridden cost %f, got %f", expected, cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -64,8 +64,19 @@ type Usage struct {
|
|||||||
InputTokens int
|
InputTokens int
|
||||||
OutputTokens int
|
OutputTokens int
|
||||||
TotalTokens int
|
TotalTokens int
|
||||||
|
Details map[string]int // provider-specific breakdown (e.g., cached, reasoning tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Keys used in Usage.Details for provider-specific token breakdowns.
// Providers share these names so consumers can read the map without knowing
// which backend produced it.
const (
	// Input-side breakdown.
	UsageDetailCachedInputTokens   = "cached_input_tokens"
	UsageDetailCacheCreationTokens = "cache_creation_tokens"
	UsageDetailAudioInputTokens    = "audio_input_tokens"

	// Output-side breakdown.
	UsageDetailReasoningTokens   = "reasoning_tokens"
	UsageDetailThoughtsTokens    = "thoughts_tokens"
	UsageDetailAudioOutputTokens = "audio_output_tokens"
)
|
||||||
|
|
||||||
// StreamEventType identifies the kind of stream event.
|
// StreamEventType identifies the kind of stream event.
|
||||||
type StreamEventType int
|
type StreamEventType int
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package llm
|
package llm
|
||||||
|
|
||||||
|
import "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
|
||||||
// Response represents the result of a completion request.
|
// Response represents the result of a completion request.
|
||||||
type Response struct {
|
type Response struct {
|
||||||
// Text is the assistant's text content. Empty if only tool calls.
|
// Text is the assistant's text content. Empty if only tool calls.
|
||||||
@@ -31,4 +33,45 @@ type Usage struct {
|
|||||||
InputTokens int
|
InputTokens int
|
||||||
OutputTokens int
|
OutputTokens int
|
||||||
TotalTokens int
|
TotalTokens int
|
||||||
|
Details map[string]int // provider-specific breakdown (e.g., cached, reasoning tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addUsage merges usage u into the receiver, accumulating token counts and details.
|
||||||
|
// If the receiver is nil, it returns a copy of u. If u is nil, it returns the receiver unchanged.
|
||||||
|
func addUsage(total *Usage, u *Usage) *Usage {
|
||||||
|
if u == nil {
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
if total == nil {
|
||||||
|
cp := *u
|
||||||
|
if u.Details != nil {
|
||||||
|
cp.Details = make(map[string]int, len(u.Details))
|
||||||
|
for k, v := range u.Details {
|
||||||
|
cp.Details[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &cp
|
||||||
|
}
|
||||||
|
total.InputTokens += u.InputTokens
|
||||||
|
total.OutputTokens += u.OutputTokens
|
||||||
|
total.TotalTokens += u.TotalTokens
|
||||||
|
if u.Details != nil {
|
||||||
|
if total.Details == nil {
|
||||||
|
total.Details = make(map[string]int, len(u.Details))
|
||||||
|
}
|
||||||
|
for k, v := range u.Details {
|
||||||
|
total.Details[k] += v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-export detail key constants from provider package for convenience.
|
||||||
|
const (
|
||||||
|
UsageDetailReasoningTokens = provider.UsageDetailReasoningTokens
|
||||||
|
UsageDetailCachedInputTokens = provider.UsageDetailCachedInputTokens
|
||||||
|
UsageDetailCacheCreationTokens = provider.UsageDetailCacheCreationTokens
|
||||||
|
UsageDetailAudioInputTokens = provider.UsageDetailAudioInputTokens
|
||||||
|
UsageDetailAudioOutputTokens = provider.UsageDetailAudioOutputTokens
|
||||||
|
UsageDetailThoughtsTokens = provider.UsageDetailThoughtsTokens
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user