feat(v2): populate CacheHints on provider.Request when caching enabled
buildProviderRequest now computes cache-breakpoint positions automatically when the WithPromptCaching() option is set. It places up to three hints: tools (set when the request carries at least one tool), system (set when the conversation contains a system message), and the index of the last non-system message (-1 when there is none). Providers that don't support caching (OpenAI, Google) ignore the field.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
20
v2/llm.go
20
v2/llm.go
@@ -122,6 +122,26 @@ func buildProviderRequest(model string, messages []Message, cfg *requestConfig)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.cacheConfig != nil && cfg.cacheConfig.enabled {
|
||||||
|
hints := &provider.CacheHints{LastCacheableMessageIndex: -1}
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
hints.CacheTools = true
|
||||||
|
}
|
||||||
|
for _, m := range messages {
|
||||||
|
if m.Role == RoleSystem {
|
||||||
|
hints.CacheSystem = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
if messages[i].Role != RoleSystem {
|
||||||
|
hints.LastCacheableMessageIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req.CacheHints = hints
|
||||||
|
}
|
||||||
|
|
||||||
return req
|
return req
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -154,3 +154,89 @@ func TestWithoutPromptCaching(t *testing.T) {
|
|||||||
t.Error("expected cacheConfig to be nil when option not applied")
|
t.Error("expected cacheConfig to be nil when option not applied")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildProviderRequest_CachingDisabled(t *testing.T) {
|
||||||
|
cfg := &requestConfig{}
|
||||||
|
msgs := []Message{SystemMessage("sys"), UserMessage("hi")}
|
||||||
|
req := buildProviderRequest("m", msgs, cfg)
|
||||||
|
if req.CacheHints != nil {
|
||||||
|
t.Errorf("expected nil CacheHints when caching disabled, got %+v", req.CacheHints)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildProviderRequest_CachingEnabled_AllSections(t *testing.T) {
|
||||||
|
tool := DefineSimple("greet", "greet", func(ctx context.Context) (string, error) { return "ok", nil })
|
||||||
|
tb := NewToolBox(tool)
|
||||||
|
cfg := &requestConfig{
|
||||||
|
tools: tb,
|
||||||
|
cacheConfig: &cacheConfig{enabled: true},
|
||||||
|
}
|
||||||
|
msgs := []Message{
|
||||||
|
SystemMessage("you are helpful"),
|
||||||
|
UserMessage("hello"),
|
||||||
|
AssistantMessage("hi"),
|
||||||
|
UserMessage("thanks"),
|
||||||
|
}
|
||||||
|
req := buildProviderRequest("m", msgs, cfg)
|
||||||
|
if req.CacheHints == nil {
|
||||||
|
t.Fatal("expected CacheHints to be set")
|
||||||
|
}
|
||||||
|
if !req.CacheHints.CacheTools {
|
||||||
|
t.Error("expected CacheTools=true")
|
||||||
|
}
|
||||||
|
if !req.CacheHints.CacheSystem {
|
||||||
|
t.Error("expected CacheSystem=true")
|
||||||
|
}
|
||||||
|
// Last non-system message index = 3 ("thanks")
|
||||||
|
if req.CacheHints.LastCacheableMessageIndex != 3 {
|
||||||
|
t.Errorf("expected LastCacheableMessageIndex=3, got %d", req.CacheHints.LastCacheableMessageIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildProviderRequest_CachingEnabled_NoTools(t *testing.T) {
|
||||||
|
cfg := &requestConfig{cacheConfig: &cacheConfig{enabled: true}}
|
||||||
|
msgs := []Message{SystemMessage("sys"), UserMessage("hi")}
|
||||||
|
req := buildProviderRequest("m", msgs, cfg)
|
||||||
|
if req.CacheHints == nil {
|
||||||
|
t.Fatal("expected CacheHints")
|
||||||
|
}
|
||||||
|
if req.CacheHints.CacheTools {
|
||||||
|
t.Error("expected CacheTools=false when there are no tools")
|
||||||
|
}
|
||||||
|
if !req.CacheHints.CacheSystem {
|
||||||
|
t.Error("expected CacheSystem=true")
|
||||||
|
}
|
||||||
|
if req.CacheHints.LastCacheableMessageIndex != 1 {
|
||||||
|
t.Errorf("expected LastCacheableMessageIndex=1, got %d", req.CacheHints.LastCacheableMessageIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildProviderRequest_CachingEnabled_NoSystem(t *testing.T) {
|
||||||
|
cfg := &requestConfig{cacheConfig: &cacheConfig{enabled: true}}
|
||||||
|
msgs := []Message{UserMessage("hi")}
|
||||||
|
req := buildProviderRequest("m", msgs, cfg)
|
||||||
|
if req.CacheHints == nil {
|
||||||
|
t.Fatal("expected CacheHints")
|
||||||
|
}
|
||||||
|
if req.CacheHints.CacheSystem {
|
||||||
|
t.Error("expected CacheSystem=false when there is no system message")
|
||||||
|
}
|
||||||
|
if req.CacheHints.LastCacheableMessageIndex != 0 {
|
||||||
|
t.Errorf("expected LastCacheableMessageIndex=0, got %d", req.CacheHints.LastCacheableMessageIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildProviderRequest_CachingEnabled_OnlySystem(t *testing.T) {
|
||||||
|
cfg := &requestConfig{cacheConfig: &cacheConfig{enabled: true}}
|
||||||
|
msgs := []Message{SystemMessage("sys")}
|
||||||
|
req := buildProviderRequest("m", msgs, cfg)
|
||||||
|
if req.CacheHints == nil {
|
||||||
|
t.Fatal("expected CacheHints")
|
||||||
|
}
|
||||||
|
if !req.CacheHints.CacheSystem {
|
||||||
|
t.Error("expected CacheSystem=true")
|
||||||
|
}
|
||||||
|
if req.CacheHints.LastCacheableMessageIndex != -1 {
|
||||||
|
t.Errorf("expected LastCacheableMessageIndex=-1 when no non-system messages, got %d", req.CacheHints.LastCacheableMessageIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user