Files
go-llm/v2/request_test.go
Steve Dudenhoeffer 4b401fcc0d
All checks were successful
CI / Lint (push) Successful in 9m36s
CI / Root Module (push) Successful in 10m55s
CI / V2 Module (push) Successful in 11m14s
feat(v2): populate CacheHints on provider.Request when caching enabled
buildProviderRequest now computes cache-breakpoint positions automatically
when the WithPromptCaching() option is set. It places up to 3 hints:
tools, system, and the index of the last non-system message. Providers
that don't support caching (OpenAI, Google) ignore the field.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 19:22:00 +00:00

243 lines
6.8 KiB
Go

package llm
import (
"context"
"testing"
)
// TestWithTemperature verifies that the WithTemperature option stores the
// supplied value in requestConfig.temperature.
func TestWithTemperature(t *testing.T) {
	cfg := &requestConfig{}
	WithTemperature(0.7)(cfg)

	// temperature is a pointer; dereference it before formatting so a failure
	// reports the stored value rather than a memory address.
	switch {
	case cfg.temperature == nil:
		t.Error("expected temperature 0.7, got nil")
	case *cfg.temperature != 0.7:
		t.Errorf("expected temperature 0.7, got %v", *cfg.temperature)
	}
}
// TestWithMaxTokens verifies that the WithMaxTokens option stores the supplied
// limit in requestConfig.maxTokens.
func TestWithMaxTokens(t *testing.T) {
	cfg := &requestConfig{}
	WithMaxTokens(256)(cfg)

	// maxTokens is a pointer; dereference it before formatting so a failure
	// reports the stored value rather than a memory address.
	switch {
	case cfg.maxTokens == nil:
		t.Error("expected maxTokens 256, got nil")
	case *cfg.maxTokens != 256:
		t.Errorf("expected maxTokens 256, got %v", *cfg.maxTokens)
	}
}
// TestWithTopP verifies that the WithTopP option stores the supplied value in
// requestConfig.topP.
func TestWithTopP(t *testing.T) {
	cfg := &requestConfig{}
	WithTopP(0.95)(cfg)

	// topP is a pointer; dereference it before formatting so a failure reports
	// the stored value rather than a memory address.
	switch {
	case cfg.topP == nil:
		t.Error("expected topP 0.95, got nil")
	case *cfg.topP != 0.95:
		t.Errorf("expected topP 0.95, got %v", *cfg.topP)
	}
}
// TestWithStop verifies that WithStop records every stop sequence it is given,
// in the order supplied.
func TestWithStop(t *testing.T) {
	want := []string{"END", "STOP", "###"}

	cfg := &requestConfig{}
	WithStop("END", "STOP", "###")(cfg)

	// Bail out on a length mismatch so the element comparison below cannot
	// index out of range.
	if got := len(cfg.stop); got != len(want) {
		t.Fatalf("expected 3 stop sequences, got %d", got)
	}
	for i, seq := range want {
		if cfg.stop[i] != seq {
			// Report once with the whole slice, then stop comparing.
			t.Errorf("unexpected stop sequences: %v", cfg.stop)
			break
		}
	}
}
func TestWithTools(t *testing.T) {
tool := DefineSimple("test", "A test tool", func(ctx context.Context) (string, error) {
return "ok", nil
})
tb := NewToolBox(tool)
cfg := &requestConfig{}
WithTools(tb)(cfg)
if cfg.tools == nil {
t.Fatal("expected tools to be set")
}
if len(cfg.tools.AllTools()) != 1 {
t.Errorf("expected 1 tool, got %d", len(cfg.tools.AllTools()))
}
}
// TestBuildProviderRequest verifies that buildProviderRequest carries the
// model name, messages, sampling options, stop sequences, and tools from a
// fully populated requestConfig into the resulting provider request.
func TestBuildProviderRequest(t *testing.T) {
	tool := DefineSimple("greet", "Greets", func(ctx context.Context) (string, error) {
		return "hi", nil
	})
	tb := NewToolBox(tool)

	temp := 0.8
	maxTok := 512
	topP := 0.9
	cfg := &requestConfig{
		tools:       tb,
		temperature: &temp,
		maxTokens:   &maxTok,
		topP:        &topP,
		stop:        []string{"END"},
	}
	msgs := []Message{
		SystemMessage("be nice"),
		UserMessage("hello"),
	}

	req := buildProviderRequest("test-model", msgs, cfg)

	if req.Model != "test-model" {
		t.Errorf("expected model 'test-model', got %q", req.Model)
	}

	// Fatal on a length mismatch: the role checks below index into Messages.
	if len(req.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(req.Messages))
	}
	if req.Messages[0].Role != "system" {
		t.Errorf("expected first message role='system', got %q", req.Messages[0].Role)
	}
	if req.Messages[1].Role != "user" {
		t.Errorf("expected second message role='user', got %q", req.Messages[1].Role)
	}

	// The sampling options are pointers. Dereference them before formatting so
	// a failure reports the actual value instead of a memory address.
	switch {
	case req.Temperature == nil:
		t.Error("expected temperature 0.8, got nil")
	case *req.Temperature != 0.8:
		t.Errorf("expected temperature 0.8, got %v", *req.Temperature)
	}
	switch {
	case req.MaxTokens == nil:
		t.Error("expected maxTokens 512, got nil")
	case *req.MaxTokens != 512:
		t.Errorf("expected maxTokens 512, got %v", *req.MaxTokens)
	}
	switch {
	case req.TopP == nil:
		t.Error("expected topP 0.9, got nil")
	case *req.TopP != 0.9:
		t.Errorf("expected topP 0.9, got %v", *req.TopP)
	}

	if len(req.Stop) != 1 || req.Stop[0] != "END" {
		t.Errorf("expected stop=[END], got %v", req.Stop)
	}

	// Fatal on a length mismatch: the name check below indexes into Tools.
	if len(req.Tools) != 1 {
		t.Fatalf("expected 1 tool, got %d", len(req.Tools))
	}
	if req.Tools[0].Name != "greet" {
		t.Errorf("expected tool name 'greet', got %q", req.Tools[0].Name)
	}
}
// TestBuildProviderRequest_EmptyConfig verifies that a zero-value
// requestConfig yields a provider request with no sampling options, stop
// sequences, or tools.
func TestBuildProviderRequest_EmptyConfig(t *testing.T) {
	cfg := &requestConfig{}
	msgs := []Message{UserMessage("hi")}

	req := buildProviderRequest("model", msgs, cfg)

	// The optional fields are pointers. Inside each branch the pointer is known
	// to be non-nil, so dereference it to report the unexpected value rather
	// than a memory address.
	if req.Temperature != nil {
		t.Errorf("expected nil temperature, got %v", *req.Temperature)
	}
	if req.MaxTokens != nil {
		t.Errorf("expected nil maxTokens, got %v", *req.MaxTokens)
	}
	if req.TopP != nil {
		t.Errorf("expected nil topP, got %v", *req.TopP)
	}
	if len(req.Stop) != 0 {
		t.Errorf("expected no stop sequences, got %v", req.Stop)
	}
	if len(req.Tools) != 0 {
		t.Errorf("expected no tools, got %d", len(req.Tools))
	}
}
// TestWithPromptCaching checks that applying the WithPromptCaching option
// installs a cacheConfig on the request configuration and marks it enabled.
func TestWithPromptCaching(t *testing.T) {
	var cfg requestConfig
	apply := WithPromptCaching()
	apply(&cfg)

	cc := cfg.cacheConfig
	if cc == nil {
		t.Fatal("expected cacheConfig to be set after WithPromptCaching()")
	}
	if enabled := cc.enabled; !enabled {
		t.Error("expected cacheConfig.enabled to be true")
	}
}
// TestWithoutPromptCaching checks that a request configuration with no options
// applied has no cacheConfig.
func TestWithoutPromptCaching(t *testing.T) {
	// Deliberately left at its zero value: no option is applied.
	var cfg requestConfig

	if cc := cfg.cacheConfig; cc != nil {
		t.Error("expected cacheConfig to be nil when option not applied")
	}
}
func TestBuildProviderRequest_CachingDisabled(t *testing.T) {
cfg := &requestConfig{}
msgs := []Message{SystemMessage("sys"), UserMessage("hi")}
req := buildProviderRequest("m", msgs, cfg)
if req.CacheHints != nil {
t.Errorf("expected nil CacheHints when caching disabled, got %+v", req.CacheHints)
}
}
func TestBuildProviderRequest_CachingEnabled_AllSections(t *testing.T) {
tool := DefineSimple("greet", "greet", func(ctx context.Context) (string, error) { return "ok", nil })
tb := NewToolBox(tool)
cfg := &requestConfig{
tools: tb,
cacheConfig: &cacheConfig{enabled: true},
}
msgs := []Message{
SystemMessage("you are helpful"),
UserMessage("hello"),
AssistantMessage("hi"),
UserMessage("thanks"),
}
req := buildProviderRequest("m", msgs, cfg)
if req.CacheHints == nil {
t.Fatal("expected CacheHints to be set")
}
if !req.CacheHints.CacheTools {
t.Error("expected CacheTools=true")
}
if !req.CacheHints.CacheSystem {
t.Error("expected CacheSystem=true")
}
// Last non-system message index = 3 ("thanks")
if req.CacheHints.LastCacheableMessageIndex != 3 {
t.Errorf("expected LastCacheableMessageIndex=3, got %d", req.CacheHints.LastCacheableMessageIndex)
}
}
// TestBuildProviderRequest_CachingEnabled_NoTools checks the cache hints
// produced when caching is enabled but the configuration has no tools.
func TestBuildProviderRequest_CachingEnabled_NoTools(t *testing.T) {
	cfg := &requestConfig{
		cacheConfig: &cacheConfig{enabled: true},
	}
	conversation := []Message{
		SystemMessage("sys"),
		UserMessage("hi"),
	}

	hints := buildProviderRequest("m", conversation, cfg).CacheHints
	if hints == nil {
		t.Fatal("expected CacheHints")
	}

	if hints.CacheTools {
		t.Error("expected CacheTools=false when there are no tools")
	}
	if !hints.CacheSystem {
		t.Error("expected CacheSystem=true")
	}
	if got := hints.LastCacheableMessageIndex; got != 1 {
		t.Errorf("expected LastCacheableMessageIndex=1, got %d", got)
	}
}
// TestBuildProviderRequest_CachingEnabled_NoSystem checks the cache hints
// produced when caching is enabled and the only message is a user message.
func TestBuildProviderRequest_CachingEnabled_NoSystem(t *testing.T) {
	cfg := &requestConfig{
		cacheConfig: &cacheConfig{enabled: true},
	}
	conversation := []Message{UserMessage("hi")}

	hints := buildProviderRequest("m", conversation, cfg).CacheHints
	if hints == nil {
		t.Fatal("expected CacheHints")
	}

	if hints.CacheSystem {
		t.Error("expected CacheSystem=false when there is no system message")
	}
	if got := hints.LastCacheableMessageIndex; got != 0 {
		t.Errorf("expected LastCacheableMessageIndex=0, got %d", got)
	}
}
// TestBuildProviderRequest_CachingEnabled_OnlySystem checks the cache hints
// produced when caching is enabled and the only message is a system message:
// the last-cacheable-message index is expected to be -1.
func TestBuildProviderRequest_CachingEnabled_OnlySystem(t *testing.T) {
	cfg := &requestConfig{
		cacheConfig: &cacheConfig{enabled: true},
	}
	conversation := []Message{SystemMessage("sys")}

	hints := buildProviderRequest("m", conversation, cfg).CacheHints
	if hints == nil {
		t.Fatal("expected CacheHints")
	}

	if !hints.CacheSystem {
		t.Error("expected CacheSystem=true")
	}
	if got := hints.LastCacheableMessageIndex; got != -1 {
		t.Errorf("expected LastCacheableMessageIndex=-1 when no non-system messages, got %d", got)
	}
}