Files
go-llm/v2/model_test.go
Steve Dudenhoeffer 87ec56a2be
All checks were successful
CI / Lint (push) Successful in 9m46s
CI / V2 Module (push) Successful in 12m5s
CI / Root Module (push) Successful in 12m6s
Add agent sub-package for composable LLM agents
Introduces v2/agent with a minimal API: Agent, New(), Run(), and AsTool().
Agents wrap a model + system prompt + tools. AsTool() turns an agent into
a llm.Tool, enabling parent agents to delegate to sub-agents through the
normal tool-call loop — no channels, pools, or orchestration needed.

Also exports NewClient(provider.Provider) for custom provider integration.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 23:17:19 -05:00

216 lines
6.0 KiB
Go

package llm
import (
"context"
"errors"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestModel_Complete(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "Hello!",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "Hello!" {
t.Errorf("expected text 'Hello!', got %q", resp.Text)
}
if resp.Usage == nil {
t.Fatal("expected usage, got nil")
}
if resp.Usage.InputTokens != 10 {
t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens)
}
if resp.Usage.OutputTokens != 5 {
t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens)
}
if resp.Usage.TotalTokens != 15 {
t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens)
}
}
// TestModel_Complete_WithOptions checks that the per-call options passed to
// Complete (temperature, max tokens, top-p and stop sequences) appear on the
// request recorded by the mock provider.
func TestModel_Complete_WithOptions(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "ok"})
	model := newMockModel(mp)

	temp := 0.7
	maxTok := 100
	topP := 0.9

	_, err := model.Complete(context.Background(), []Message{UserMessage("test")},
		WithTemperature(temp),
		WithMaxTokens(maxTok),
		WithTopP(topP),
		WithStop("STOP", "END"),
	)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := mp.lastRequest()

	// The optional request fields are pointers. Formatting a non-nil pointer
	// with %v prints its address, so report the nil case separately and
	// dereference before formatting to show the actual value on mismatch.
	switch {
	case req.Temperature == nil:
		t.Errorf("expected temperature %v, got nil", temp)
	case *req.Temperature != temp:
		t.Errorf("expected temperature %v, got %v", temp, *req.Temperature)
	}

	switch {
	case req.MaxTokens == nil:
		t.Errorf("expected maxTokens %v, got nil", maxTok)
	case *req.MaxTokens != maxTok:
		t.Errorf("expected maxTokens %v, got %v", maxTok, *req.MaxTokens)
	}

	switch {
	case req.TopP == nil:
		t.Errorf("expected topP %v, got nil", topP)
	case *req.TopP != topP:
		t.Errorf("expected topP %v, got %v", topP, *req.TopP)
	}

	if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" {
		t.Errorf("expected stop [STOP END], got %v", req.Stop)
	}
}
// TestModel_Complete_Error checks that an error returned by the provider
// function is reported by Complete and can be matched with errors.Is.
func TestModel_Complete_Error(t *testing.T) {
	sentinel := errors.New("provider error")
	failing := newMockProviderFunc(func(_ context.Context, _ provider.Request) (provider.Response, error) {
		return provider.Response{}, sentinel
	})
	m := newMockModel(failing)

	_, err := m.Complete(context.Background(), []Message{UserMessage("Hi")})
	switch {
	case err == nil:
		t.Fatal("expected error, got nil")
	case !errors.Is(err, sentinel):
		t.Errorf("expected error %v, got %v", sentinel, err)
	}
}
// TestModel_Complete_WithTools checks that a toolbox supplied through
// WithTools shows up on the recorded request as a single tool with the
// expected name and description.
func TestModel_Complete_WithTools(t *testing.T) {
	recorder := newMockProvider(provider.Response{Text: "done"})
	m := newMockModel(recorder)

	greet := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
		return "hello", nil
	})
	opts := WithTools(NewToolBox(greet))

	if _, err := m.Complete(context.Background(), []Message{UserMessage("test")}, opts); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	sent := recorder.lastRequest()
	// Abort on a count mismatch: the checks below index the first tool.
	if n := len(sent.Tools); n != 1 {
		t.Fatalf("expected 1 tool, got %d", n)
	}
	if name := sent.Tools[0].Name; name != "greet" {
		t.Errorf("expected tool name 'greet', got %q", name)
	}
	if desc := sent.Tools[0].Description; desc != "Says hello" {
		t.Errorf("expected tool description 'Says hello', got %q", desc)
	}
}
// TestClient_Model checks that a model obtained from Client.Model returns the
// provider's text and passes the requested model name on the recorded request.
func TestClient_Model(t *testing.T) {
	recorder := newMockProvider(provider.Response{Text: "hi"})
	c := NewClient(recorder)
	m := c.Model("test-model")

	got, err := m.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if text := got.Text; text != "hi" {
		t.Errorf("expected 'hi', got %q", text)
	}

	if name := recorder.lastRequest().Model; name != "test-model" {
		t.Errorf("expected model 'test-model', got %q", name)
	}
}
// TestClient_WithMiddleware checks that middleware attached to a client runs
// when a model obtained from that client completes a request.
func TestClient_WithMiddleware(t *testing.T) {
	invoked := false
	// spy flags that it ran, then delegates to the next handler unchanged.
	spy := func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
			invoked = true
			return next(ctx, model, messages, cfg)
		}
	}

	backend := newMockProvider(provider.Response{Text: "ok"})
	c := NewClient(backend).WithMiddleware(spy)
	m := c.Model("test-model")

	if _, err := m.Complete(context.Background(), []Message{UserMessage("test")}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !invoked {
		t.Error("middleware was not called")
	}
}
// TestModel_WithMiddleware checks that two middlewares attached to a model via
// successive WithMiddleware calls both run, with "mw1" recorded before "mw2".
func TestModel_WithMiddleware(t *testing.T) {
	var trace []string

	// tagging builds a middleware that appends label to trace and then
	// delegates to the next handler unchanged.
	tagging := func(label string) func(next CompletionFunc) CompletionFunc {
		return func(next CompletionFunc) CompletionFunc {
			return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
				trace = append(trace, label)
				return next(ctx, model, messages, cfg)
			}
		}
	}
	first := tagging("mw1")
	second := tagging("mw2")

	backend := newMockProvider(provider.Response{Text: "ok"})
	m := newMockModel(backend).WithMiddleware(first).WithMiddleware(second)

	if _, err := m.Complete(context.Background(), []Message{UserMessage("test")}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	inOrder := len(trace) == 2 && trace[0] == "mw1" && trace[1] == "mw2"
	if !inOrder {
		t.Errorf("expected middleware order [mw1 mw2], got %v", trace)
	}
}
func TestModel_Complete_NoUsage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "no usage"})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Usage != nil {
t.Errorf("expected nil usage, got %+v", resp.Usage)
}
}
// TestModel_Complete_ResponseMessage checks that Response.Message yields an
// assistant-role message whose content text matches the provider's text.
func TestModel_Complete_ResponseMessage(t *testing.T) {
	backend := newMockProvider(provider.Response{Text: "response text"})
	m := newMockModel(backend)

	got, err := m.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	reply := got.Message()
	if role := reply.Role; role != RoleAssistant {
		t.Errorf("expected role assistant, got %v", role)
	}
	if text := reply.Content.Text; text != "response text" {
		t.Errorf("expected text 'response text', got %q", text)
	}
}