Add comprehensive test suite for v2 module with mock provider
Cover all core library logic (Client, Model, Chat, middleware, streaming, message conversion, request building) using a configurable mock provider that avoids real API calls. ~50 tests across 7 files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
215
v2/model_test.go
Normal file
215
v2/model_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||
)
|
||||
|
||||
func TestModel_Complete(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{
|
||||
Text: "Hello!",
|
||||
Usage: &provider.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalTokens: 15,
|
||||
},
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "Hello!" {
|
||||
t.Errorf("expected text 'Hello!', got %q", resp.Text)
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("expected usage, got nil")
|
||||
}
|
||||
if resp.Usage.InputTokens != 10 {
|
||||
t.Errorf("expected input tokens 10, got %d", resp.Usage.InputTokens)
|
||||
}
|
||||
if resp.Usage.OutputTokens != 5 {
|
||||
t.Errorf("expected output tokens 5, got %d", resp.Usage.OutputTokens)
|
||||
}
|
||||
if resp.Usage.TotalTokens != 15 {
|
||||
t.Errorf("expected total tokens 15, got %d", resp.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_WithOptions(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp)
|
||||
|
||||
temp := 0.7
|
||||
maxTok := 100
|
||||
topP := 0.9
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")},
|
||||
WithTemperature(temp),
|
||||
WithMaxTokens(maxTok),
|
||||
WithTopP(topP),
|
||||
WithStop("STOP", "END"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
req := mp.lastRequest()
|
||||
if req.Temperature == nil || *req.Temperature != temp {
|
||||
t.Errorf("expected temperature %v, got %v", temp, req.Temperature)
|
||||
}
|
||||
if req.MaxTokens == nil || *req.MaxTokens != maxTok {
|
||||
t.Errorf("expected maxTokens %v, got %v", maxTok, req.MaxTokens)
|
||||
}
|
||||
if req.TopP == nil || *req.TopP != topP {
|
||||
t.Errorf("expected topP %v, got %v", topP, req.TopP)
|
||||
}
|
||||
if len(req.Stop) != 2 || req.Stop[0] != "STOP" || req.Stop[1] != "END" {
|
||||
t.Errorf("expected stop [STOP END], got %v", req.Stop)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_Error(t *testing.T) {
|
||||
wantErr := errors.New("provider error")
|
||||
mp := newMockProviderFunc(func(ctx context.Context, req provider.Request) (provider.Response, error) {
|
||||
return provider.Response{}, wantErr
|
||||
})
|
||||
model := newMockModel(mp)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("Hi")})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Errorf("expected error %v, got %v", wantErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_WithTools(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "done"})
|
||||
model := newMockModel(mp)
|
||||
|
||||
tool := DefineSimple("greet", "Says hello", func(ctx context.Context) (string, error) {
|
||||
return "hello", nil
|
||||
})
|
||||
tb := NewToolBox(tool)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")}, WithTools(tb))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
req := mp.lastRequest()
|
||||
if len(req.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
||||
}
|
||||
if req.Tools[0].Name != "greet" {
|
||||
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
|
||||
}
|
||||
if req.Tools[0].Description != "Says hello" {
|
||||
t.Errorf("expected tool description 'Says hello', got %q", req.Tools[0].Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Model(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "hi"})
|
||||
client := newClient(mp)
|
||||
model := client.Model("test-model")
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Text != "hi" {
|
||||
t.Errorf("expected 'hi', got %q", resp.Text)
|
||||
}
|
||||
|
||||
req := mp.lastRequest()
|
||||
if req.Model != "test-model" {
|
||||
t.Errorf("expected model 'test-model', got %q", req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_WithMiddleware(t *testing.T) {
|
||||
var called bool
|
||||
mw := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
called = true
|
||||
return next(ctx, model, messages, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
client := newClient(mp).WithMiddleware(mw)
|
||||
model := client.Model("test-model")
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("middleware was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_WithMiddleware(t *testing.T) {
|
||||
var order []string
|
||||
mw1 := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
order = append(order, "mw1")
|
||||
return next(ctx, model, messages, cfg)
|
||||
}
|
||||
}
|
||||
mw2 := func(next CompletionFunc) CompletionFunc {
|
||||
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||
order = append(order, "mw2")
|
||||
return next(ctx, model, messages, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
mp := newMockProvider(provider.Response{Text: "ok"})
|
||||
model := newMockModel(mp).WithMiddleware(mw1).WithMiddleware(mw2)
|
||||
|
||||
_, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(order) != 2 || order[0] != "mw1" || order[1] != "mw2" {
|
||||
t.Errorf("expected middleware order [mw1 mw2], got %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_NoUsage(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "no usage"})
|
||||
model := newMockModel(mp)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Errorf("expected nil usage, got %+v", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_Complete_ResponseMessage(t *testing.T) {
|
||||
mp := newMockProvider(provider.Response{Text: "response text"})
|
||||
model := newMockModel(mp)
|
||||
|
||||
resp, err := model.Complete(context.Background(), []Message{UserMessage("test")})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
msg := resp.Message()
|
||||
if msg.Role != RoleAssistant {
|
||||
t.Errorf("expected role assistant, got %v", msg.Role)
|
||||
}
|
||||
if msg.Content.Text != "response text" {
|
||||
t.Errorf("expected text 'response text', got %q", msg.Content.Text)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user