Cover all core library logic (Client, Model, Chat, middleware, streaming, message conversion, request building) using a configurable mock provider that avoids real API calls. ~50 tests across 7 files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
138 lines
3.4 KiB
Go
138 lines
3.4 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
)
|
|
|
|
func TestWithTemperature(t *testing.T) {
|
|
cfg := &requestConfig{}
|
|
WithTemperature(0.7)(cfg)
|
|
if cfg.temperature == nil || *cfg.temperature != 0.7 {
|
|
t.Errorf("expected temperature 0.7, got %v", cfg.temperature)
|
|
}
|
|
}
|
|
|
|
func TestWithMaxTokens(t *testing.T) {
|
|
cfg := &requestConfig{}
|
|
WithMaxTokens(256)(cfg)
|
|
if cfg.maxTokens == nil || *cfg.maxTokens != 256 {
|
|
t.Errorf("expected maxTokens 256, got %v", cfg.maxTokens)
|
|
}
|
|
}
|
|
|
|
func TestWithTopP(t *testing.T) {
|
|
cfg := &requestConfig{}
|
|
WithTopP(0.95)(cfg)
|
|
if cfg.topP == nil || *cfg.topP != 0.95 {
|
|
t.Errorf("expected topP 0.95, got %v", cfg.topP)
|
|
}
|
|
}
|
|
|
|
func TestWithStop(t *testing.T) {
|
|
cfg := &requestConfig{}
|
|
WithStop("END", "STOP", "###")(cfg)
|
|
if len(cfg.stop) != 3 {
|
|
t.Fatalf("expected 3 stop sequences, got %d", len(cfg.stop))
|
|
}
|
|
if cfg.stop[0] != "END" || cfg.stop[1] != "STOP" || cfg.stop[2] != "###" {
|
|
t.Errorf("unexpected stop sequences: %v", cfg.stop)
|
|
}
|
|
}
|
|
|
|
func TestWithTools(t *testing.T) {
|
|
tool := DefineSimple("test", "A test tool", func(ctx context.Context) (string, error) {
|
|
return "ok", nil
|
|
})
|
|
tb := NewToolBox(tool)
|
|
|
|
cfg := &requestConfig{}
|
|
WithTools(tb)(cfg)
|
|
if cfg.tools == nil {
|
|
t.Fatal("expected tools to be set")
|
|
}
|
|
if len(cfg.tools.AllTools()) != 1 {
|
|
t.Errorf("expected 1 tool, got %d", len(cfg.tools.AllTools()))
|
|
}
|
|
}
|
|
|
|
func TestBuildProviderRequest(t *testing.T) {
|
|
tool := DefineSimple("greet", "Greets", func(ctx context.Context) (string, error) {
|
|
return "hi", nil
|
|
})
|
|
tb := NewToolBox(tool)
|
|
|
|
temp := 0.8
|
|
maxTok := 512
|
|
topP := 0.9
|
|
|
|
cfg := &requestConfig{
|
|
tools: tb,
|
|
temperature: &temp,
|
|
maxTokens: &maxTok,
|
|
topP: &topP,
|
|
stop: []string{"END"},
|
|
}
|
|
|
|
msgs := []Message{
|
|
SystemMessage("be nice"),
|
|
UserMessage("hello"),
|
|
}
|
|
|
|
req := buildProviderRequest("test-model", msgs, cfg)
|
|
|
|
if req.Model != "test-model" {
|
|
t.Errorf("expected model 'test-model', got %q", req.Model)
|
|
}
|
|
if len(req.Messages) != 2 {
|
|
t.Fatalf("expected 2 messages, got %d", len(req.Messages))
|
|
}
|
|
if req.Messages[0].Role != "system" {
|
|
t.Errorf("expected first message role='system', got %q", req.Messages[0].Role)
|
|
}
|
|
if req.Messages[1].Role != "user" {
|
|
t.Errorf("expected second message role='user', got %q", req.Messages[1].Role)
|
|
}
|
|
if req.Temperature == nil || *req.Temperature != 0.8 {
|
|
t.Errorf("expected temperature 0.8, got %v", req.Temperature)
|
|
}
|
|
if req.MaxTokens == nil || *req.MaxTokens != 512 {
|
|
t.Errorf("expected maxTokens 512, got %v", req.MaxTokens)
|
|
}
|
|
if req.TopP == nil || *req.TopP != 0.9 {
|
|
t.Errorf("expected topP 0.9, got %v", req.TopP)
|
|
}
|
|
if len(req.Stop) != 1 || req.Stop[0] != "END" {
|
|
t.Errorf("expected stop=[END], got %v", req.Stop)
|
|
}
|
|
if len(req.Tools) != 1 {
|
|
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
|
}
|
|
if req.Tools[0].Name != "greet" {
|
|
t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
|
|
}
|
|
}
|
|
|
|
func TestBuildProviderRequest_EmptyConfig(t *testing.T) {
|
|
cfg := &requestConfig{}
|
|
msgs := []Message{UserMessage("hi")}
|
|
|
|
req := buildProviderRequest("model", msgs, cfg)
|
|
|
|
if req.Temperature != nil {
|
|
t.Errorf("expected nil temperature, got %v", req.Temperature)
|
|
}
|
|
if req.MaxTokens != nil {
|
|
t.Errorf("expected nil maxTokens, got %v", req.MaxTokens)
|
|
}
|
|
if req.TopP != nil {
|
|
t.Errorf("expected nil topP, got %v", req.TopP)
|
|
}
|
|
if len(req.Stop) != 0 {
|
|
t.Errorf("expected no stop sequences, got %v", req.Stop)
|
|
}
|
|
if len(req.Tools) != 0 {
|
|
t.Errorf("expected no tools, got %d", len(req.Tools))
|
|
}
|
|
}
|