diff --git a/v2/llm.go b/v2/llm.go index b34fba1..df94905 100644 --- a/v2/llm.go +++ b/v2/llm.go @@ -122,6 +122,26 @@ func buildProviderRequest(model string, messages []Message, cfg *requestConfig) } } + if cfg.cacheConfig != nil && cfg.cacheConfig.enabled { + hints := &provider.CacheHints{LastCacheableMessageIndex: -1} + if len(req.Tools) > 0 { + hints.CacheTools = true + } + for _, m := range messages { + if m.Role == RoleSystem { + hints.CacheSystem = true + break + } + } + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role != RoleSystem { + hints.LastCacheableMessageIndex = i + break + } + } + req.CacheHints = hints + } + return req } diff --git a/v2/request_test.go b/v2/request_test.go index b2682f1..dfb0d54 100644 --- a/v2/request_test.go +++ b/v2/request_test.go @@ -154,3 +154,89 @@ func TestWithoutPromptCaching(t *testing.T) { t.Error("expected cacheConfig to be nil when option not applied") } } + +func TestBuildProviderRequest_CachingDisabled(t *testing.T) { + cfg := &requestConfig{} + msgs := []Message{SystemMessage("sys"), UserMessage("hi")} + req := buildProviderRequest("m", msgs, cfg) + if req.CacheHints != nil { + t.Errorf("expected nil CacheHints when caching disabled, got %+v", req.CacheHints) + } +} + +func TestBuildProviderRequest_CachingEnabled_AllSections(t *testing.T) { + tool := DefineSimple("greet", "greet", func(ctx context.Context) (string, error) { return "ok", nil }) + tb := NewToolBox(tool) + cfg := &requestConfig{ + tools: tb, + cacheConfig: &cacheConfig{enabled: true}, + } + msgs := []Message{ + SystemMessage("you are helpful"), + UserMessage("hello"), + AssistantMessage("hi"), + UserMessage("thanks"), + } + req := buildProviderRequest("m", msgs, cfg) + if req.CacheHints == nil { + t.Fatal("expected CacheHints to be set") + } + if !req.CacheHints.CacheTools { + t.Error("expected CacheTools=true") + } + if !req.CacheHints.CacheSystem { + t.Error("expected CacheSystem=true") + } + // Last non-system message index = 3 ("thanks") + if req.CacheHints.LastCacheableMessageIndex != 3 { + t.Errorf("expected LastCacheableMessageIndex=3, got %d", req.CacheHints.LastCacheableMessageIndex) + } +} + +func TestBuildProviderRequest_CachingEnabled_NoTools(t *testing.T) { + cfg := &requestConfig{cacheConfig: &cacheConfig{enabled: true}} + msgs := []Message{SystemMessage("sys"), UserMessage("hi")} + req := buildProviderRequest("m", msgs, cfg) + if req.CacheHints == nil { + t.Fatal("expected CacheHints") + } + if req.CacheHints.CacheTools { + t.Error("expected CacheTools=false when there are no tools") + } + if !req.CacheHints.CacheSystem { + t.Error("expected CacheSystem=true") + } + if req.CacheHints.LastCacheableMessageIndex != 1 { + t.Errorf("expected LastCacheableMessageIndex=1, got %d", req.CacheHints.LastCacheableMessageIndex) + } +} + +func TestBuildProviderRequest_CachingEnabled_NoSystem(t *testing.T) { + cfg := &requestConfig{cacheConfig: &cacheConfig{enabled: true}} + msgs := []Message{UserMessage("hi")} + req := buildProviderRequest("m", msgs, cfg) + if req.CacheHints == nil { + t.Fatal("expected CacheHints") + } + if req.CacheHints.CacheSystem { + t.Error("expected CacheSystem=false when there is no system message") + } + if req.CacheHints.LastCacheableMessageIndex != 0 { + t.Errorf("expected LastCacheableMessageIndex=0, got %d", req.CacheHints.LastCacheableMessageIndex) + } +} + +func TestBuildProviderRequest_CachingEnabled_OnlySystem(t *testing.T) { + cfg := &requestConfig{cacheConfig: &cacheConfig{enabled: true}} + msgs := []Message{SystemMessage("sys")} + req := buildProviderRequest("m", msgs, cfg) + if req.CacheHints == nil { + t.Fatal("expected CacheHints") + } + if !req.CacheHints.CacheSystem { + t.Error("expected CacheSystem=true") + } + if req.CacheHints.LastCacheableMessageIndex != -1 { + t.Errorf("expected LastCacheableMessageIndex=-1 when no non-system messages, got %d", req.CacheHints.LastCacheableMessageIndex) + } +}