Introduces v2/agent with a minimal API: Agent, New(), Run(), and AsTool(). An agent wraps a model, a system prompt, and a set of tools. AsTool() turns an agent into an llm.Tool, so parent agents can delegate to sub-agents through the normal tool-call loop — no channels, pools, or orchestration needed. Also exports NewClient(provider.Provider) for custom provider integration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
216 lines
6.0 KiB
Go
216 lines
6.0 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
|
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
|
)
|
|
|
|
func TestModel_Complete(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{
|
|
Text: "Hello!",
|
|
Usage: &provider.Usage{
|
|
InputTokens: 10,
|
|
OutputTokens: 5,
|
|
TotalTokens: 15,
|
|
},
|
|
})
|
|
model := newMockModel(mp)
|
|
|
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.Text != "Hello!" {
|
|
t.Errorf("expected text 'Hello!', got %q", resp.Text)
|
|
}
|
|
if resp.Usage == nil {
|
|
t.Fatal("expected usage, got nil")
|
|
}
|
|
if resp.Usage.InputTokens != 10 {
|
|
t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens)
|
|
}
|
|
if resp.Usage.OutputTokens != 5 {
|
|
t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens)
|
|
}
|
|
if resp.Usage.TotalTokens != 15 {
|
|
t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens)
|
|
}
|
|
}
|
|
|
|
func TestModel_Complete_WithOptions(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
|
model := newMockModel(mp)
|
|
|
|
temp := 0.7
|
|
maxTok := 100
|
|
topP := 0.9
|
|
|
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")},
|
|
WithTemperature(temp),
|
|
WithMaxTokens(maxTok),
|
|
WithTopP(topP),
|
|
WithStop("STOP", "END"),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
req := mp.lastRequest()
|
|
if req.Temperature == nil || *req.Temperature != temp {
|
|
t.Errorf("expected temperature %v, got %v", temp, req.Temperature)
|
|
}
|
|
if req.MaxTokens == nil || *req.MaxTokens != maxTok {
|
|
t.Errorf("expected maxTokens %v, got %v", maxTok, req.MaxTokens)
|
|
}
|
|
if req.TopP == nil || *req.TopP != topP {
|
|
t.Errorf("expected topP %v, got %v", topP, req.TopP)
|
|
}
|
|
if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" {
|
|
t.Errorf("expected stop [STOP END], got %v", req.Stop)
|
|
}
|
|
}
|
|
|
|
func TestModel_Complete_Error(t *testing.T) {
|
|
wantErr := errors.New("provider error")
|
|
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
|
return provider.Response{}, wantErr
|
|
})
|
|
model := newMockModel(mp)
|
|
|
|
_, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if !errors.Is(err, wantErr) {
|
|
t.Errorf("expected error %v, got %v", wantErr, err)
|
|
}
|
|
}
|
|
|
|
func TestModel_Complete_WithTools(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "done"})
|
|
model := newMockModel(mp)
|
|
|
|
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
|
return "hello", nil
|
|
})
|
|
tb := NewToolBox(tool)
|
|
|
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")}, WithTools(tb))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
req := mp.lastRequest()
|
|
if len(req.Tools) != 1 {
|
|
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
|
}
|
|
if req.Tools[0].Name != "greet" {
|
|
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
|
|
}
|
|
if req.Tools[0].Description != "Says hello" {
|
|
t.Errorf("expected tool description 'Says hello', got %q", req.Tools[0].Description)
|
|
}
|
|
}
|
|
|
|
func TestClient_Model(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "hi"})
|
|
client := NewClient(mp)
|
|
model := client.Model("test-model")
|
|
|
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.Text != "hi" {
|
|
t.Errorf("expected 'hi', got %q", resp.Text)
|
|
}
|
|
|
|
req := mp.lastRequest()
|
|
if req.Model != "test-model" {
|
|
t.Errorf("expected model 'test-model', got %q", req.Model)
|
|
}
|
|
}
|
|
|
|
func TestClient_WithMiddleware(t *testing.T) {
|
|
var called bool
|
|
mw := func(next CompletionFunc) CompletionFunc {
|
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
|
called = true
|
|
return next(ctx, model, messages, cfg)
|
|
}
|
|
}
|
|
|
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
|
client := NewClient(mp).WithMiddleware(mw)
|
|
model := client.Model("test-model")
|
|
|
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !called {
|
|
t.Error("middleware was not called")
|
|
}
|
|
}
|
|
|
|
func TestModel_WithMiddleware(t *testing.T) {
|
|
var order []string
|
|
mw1 := func(next CompletionFunc) CompletionFunc {
|
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
|
order = append(order, "mw1")
|
|
return next(ctx, model, messages, cfg)
|
|
}
|
|
}
|
|
mw2 := func(next CompletionFunc) CompletionFunc {
|
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
|
order = append(order, "mw2")
|
|
return next(ctx, model, messages, cfg)
|
|
}
|
|
}
|
|
|
|
mp := newMockProvider(provider.Response{Text: "ok"})
|
|
model := newMockModel(mp).WithMiddleware(mw1).WithMiddleware(mw2)
|
|
|
|
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(order) != 2 || order[0] != "mw1" || order[1] != "mw2" {
|
|
t.Errorf("expected middleware order [mw1 mw2], got %v", order)
|
|
}
|
|
}
|
|
|
|
func TestModel_Complete_NoUsage(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "no usage"})
|
|
model := newMockModel(mp)
|
|
|
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp.Usage != nil {
|
|
t.Errorf("expected nil usage, got %+v", resp.Usage)
|
|
}
|
|
}
|
|
|
|
func TestModel_Complete_ResponseMessage(t *testing.T) {
|
|
mp := newMockProvider(provider.Response{Text: "response text"})
|
|
model := newMockModel(mp)
|
|
|
|
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
msg := resp.Message()
|
|
if msg.Role != RoleAssistant {
|
|
t.Errorf("expected role assistant, got %v", msg.Role)
|
|
}
|
|
if msg.Content.Text != "response text" {
|
|
t.Errorf("expected text 'response text', got %q", msg.Content.Text)
|
|
}
|
|
}
|