Files
go-llm/v2/model_test.go
Steve Dudenhoeffer 6a7eeef619
All checks were successful
CI / Lint (push) Successful in 9m36s
CI / V2 Module (push) Successful in 11m33s
CI / Root Module (push) Successful in 11m35s
Add comprehensive test suite for v2 module with mock provider
Cover all core library logic (Client, Model, Chat, middleware, streaming,
message conversion, request building) using a configurable mock provider
that avoids real API calls. ~50 tests across 7 files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 22:00:49 -05:00

216 lines
6.0 KiB
Go

package llm
import (
"context"
"errors"
"testing"
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
)
func TestModel_Complete(t *testing.T) {
mp := newMockProvider(provider.Response{
Text: "Hello!",
Usage: &provider.Usage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Text != "Hello!" {
t.Errorf("expected text 'Hello!', got %q", resp.Text)
}
if resp.Usage == nil {
t.Fatal("expected usage, got nil")
}
if resp.Usage.InputTokens != 10 {
t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens)
}
if resp.Usage.OutputTokens != 5 {
t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens)
}
if resp.Usage.TotalTokens != 15 {
t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens)
}
}
// TestModel_Complete_WithOptions checks that the temperature, max-token,
// top-p and stop options passed to Complete appear on the request returned
// by mp.lastRequest().
func TestModel_Complete_WithOptions(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "ok"})
	model := newMockModel(mp)

	temp := 0.7
	maxTok := 100
	topP := 0.9

	_, err := model.Complete(context.Background(), []Message{UserMessage("test")},
		WithTemperature(temp),
		WithMaxTokens(maxTok),
		WithTopP(topP),
		WithStop("STOP", "END"),
	)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	req := mp.lastRequest()

	// Temperature, MaxTokens and TopP are pointer fields. Formatting the
	// pointer itself with %v would print an address, so the nil case is
	// reported separately and the value is dereferenced before printing.
	switch {
	case req.Temperature == nil:
		t.Errorf("expected temperature %v, got nil", temp)
	case *req.Temperature != temp:
		t.Errorf("expected temperature %v, got %v", temp, *req.Temperature)
	}

	switch {
	case req.MaxTokens == nil:
		t.Errorf("expected maxTokens %v, got nil", maxTok)
	case *req.MaxTokens != maxTok:
		t.Errorf("expected maxTokens %v, got %v", maxTok, *req.MaxTokens)
	}

	switch {
	case req.TopP == nil:
		t.Errorf("expected topP %v, got nil", topP)
	case *req.TopP != topP:
		t.Errorf("expected topP %v, got %v", topP, *req.TopP)
	}

	if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" {
		t.Errorf("expected stop [STOP END], got %v", req.Stop)
	}
}
func TestModel_Complete_Error(t *testing.T) {
wantErr := errors.New("provider error")
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
return provider.Response{}, wantErr
})
model := newMockModel(mp)
_, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, wantErr) {
t.Errorf("expected error %v, got %v", wantErr, err)
}
}
// TestModel_Complete_WithTools checks that a toolbox supplied through
// WithTools yields exactly one tool on the request returned by
// mp.lastRequest(), carrying the name and description it was defined with.
func TestModel_Complete_WithTools(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "done"})
	m := newMockModel(mp)

	greet := DefineSimple("greet", "Says hello", func(context.Context) (string, error) {
		return "hello", nil
	})
	box := NewToolBox(greet)

	if _, err := m.Complete(context.Background(), []Message{UserMessage("test")}, WithTools(box)); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	tools := mp.lastRequest().Tools
	if n := len(tools); n != 1 {
		t.Fatalf("expected 1 tool, got %d", n)
	}
	if name := tools[0].Name; name != "greet" {
		t.Errorf("expected tool name 'greet', got %q", name)
	}
	if desc := tools[0].Description; desc != "Says hello" {
		t.Errorf("expected tool description 'Says hello', got %q", desc)
	}
}
// TestClient_Model checks that a model obtained from client.Model completes
// successfully and that the requested model name appears on the request
// returned by mp.lastRequest().
func TestClient_Model(t *testing.T) {
	mp := newMockProvider(provider.Response{Text: "hi"})
	c := newClient(mp)
	m := c.Model("test-model")

	got, err := m.Complete(context.Background(), []Message{UserMessage("test")})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got.Text != "hi" {
		t.Errorf("expected 'hi', got %q", got.Text)
	}

	if name := mp.lastRequest().Model; name != "test-model" {
		t.Errorf("expected model 'test-model', got %q", name)
	}
}
// TestClient_WithMiddleware checks that a middleware attached to the client
// runs when a model obtained from that client performs a completion.
func TestClient_WithMiddleware(t *testing.T) {
	invoked := false

	// Middleware that records its invocation and then delegates unchanged.
	mark := func(next CompletionFunc) CompletionFunc {
		return func(ctx context.Context, name string, msgs []Message, cfg *requestConfig) (Response, error) {
			invoked = true
			return next(ctx, name, msgs, cfg)
		}
	}

	mp := newMockProvider(provider.Response{Text: "ok"})
	c := newClient(mp).WithMiddleware(mark)
	m := c.Model("test-model")

	if _, err := m.Complete(context.Background(), []Message{UserMessage("test")}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if !invoked {
		t.Error("middleware was not called")
	}
}
// TestModel_WithMiddleware checks that two middlewares chained onto a model
// both run during a completion, in the order they were attached.
func TestModel_WithMiddleware(t *testing.T) {
	var order []string

	// tracer builds a middleware that appends label to order when invoked
	// and then forwards the call untouched.
	tracer := func(label string) func(CompletionFunc) CompletionFunc {
		return func(next CompletionFunc) CompletionFunc {
			return func(ctx context.Context, name string, msgs []Message, cfg *requestConfig) (Response, error) {
				order = append(order, label)
				return next(ctx, name, msgs, cfg)
			}
		}
	}

	mp := newMockProvider(provider.Response{Text: "ok"})
	m := newMockModel(mp).WithMiddleware(tracer("mw1")).WithMiddleware(tracer("mw2"))

	if _, err := m.Complete(context.Background(), []Message{UserMessage("test")}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	inOrder := len(order) == 2 && order[0] == "mw1" && order[1] == "mw2"
	if !inOrder {
		t.Errorf("expected middleware order [mw1 mw2], got %v", order)
	}
}
func TestModel_Complete_NoUsage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "no usage"})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Usage != nil {
t.Errorf("expected nil usage, got %+v", resp.Usage)
}
}
func TestModel_Complete_ResponseMessage(t *testing.T) {
mp := newMockProvider(provider.Response{Text: "response text"})
model := newMockModel(mp)
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
msg := resp.Message()
if msg.Role != RoleAssistant {
t.Errorf("expected role assistant, got %v", msg.Role)
}
if msg.Content.Text != "response text" {
t.Errorf("expected text 'response text', got %q", msg.Content.Text)
}
}