diff --git a/media/media.go b/media/media.go index 8cce4c4..467affa 100644 --- a/media/media.go +++ b/media/media.go @@ -58,21 +58,21 @@ func Normalize(req llm.Request, caps llm.Capabilities) (llm.Request, error) { if !caps.SupportsImages() { return llm.Request{}, fmt.Errorf("media: %w: target does not accept image input (request carries %d image(s))", llm.ErrUnsupported, total) } - // Overflow: keep the most-recent MaxImagesPerReq images and replace each - // older one with a short text placeholder, rather than refusing the whole - // request. A hard refuse exhausts a failover chain whose targets share the - // same cap — e.g. an agent loop that accumulates a preview image per - // iteration past the cap makes EVERY target reject and the run dies. The - // placeholder preserves each message's turn structure and tells the model an - // earlier image was elided; the most recent images (the relevant ones in an - // iterative run) are retained. The per-model threshold stays configurable via - // Capabilities.MaxImagesPerReq (0 still means "no image support"). + // Over-cap images are elided in the same copy-on-write pass below: the + // OLDEST excess are replaced with a placeholder and the most-recent + // MaxImagesPerReq kept (see the package doc for why we elide rather than + // refuse). toElide is how many of the first images, front-to-back, to drop. + toElide := 0 if total > caps.MaxImagesPerReq { - req = dropOldestImages(req, total-caps.MaxImagesPerReq) + toElide = total - caps.MaxImagesPerReq } + // Single copy-on-write pass: for each image, the first toElide become a text + // placeholder; the rest are size-normalized against caps. The Messages slice + // and an affected message's Parts slice are copied at most once. out := req copiedMessages := false + seen := 0 for mi := range req.Messages { copiedParts := false for pi, part := range req.Messages[mi].Parts { @@ -80,13 +80,22 @@ func Normalize(req llm.Request, caps llm.Capabilities) (llm.Request, error) { if !ok { continue } - norm, changed, err := normalizeImage(ip, caps) - if err != nil { - return llm.Request{}, fmt.Errorf("media: message %d, part %d: %w", mi, pi, err) - } - if !changed { - continue + seen++ + + var replacement llm.Part + if seen <= toElide { + replacement = llm.Text(imageOverflowPlaceholder) + } else { + norm, changed, err := normalizeImage(ip, caps) + if err != nil { + return llm.Request{}, fmt.Errorf("media: message %d, part %d: %w", mi, pi, err) + } + if !changed { + continue + } + replacement = norm } + if !copiedMessages { out.Messages = make([]llm.Message, len(req.Messages)) copy(out.Messages, req.Messages) @@ -98,59 +107,17 @@ func Normalize(req llm.Request, caps llm.Capabilities) (llm.Request, error) { out.Messages[mi].Parts = parts copiedParts = true } - out.Messages[mi].Parts[pi] = norm + out.Messages[mi].Parts[pi] = replacement } } return out, nil } -// imageOverflowPlaceholder replaces an image dropped to fit a target's +// imageOverflowPlaceholder replaces an image elided to fit a target's // per-request image cap. It keeps the message turn intact and tells the model -// an earlier image was elided rather than silently changing the conversation. +// an earlier image was omitted rather than silently changing the conversation. const imageOverflowPlaceholder = "[earlier image omitted to fit this model's per-request image limit]" -// dropOldestImages replaces the n oldest image parts (front-to-back across the -// message history) with imageOverflowPlaceholder text, keeping the most-recent -// images and preserving every message's turn structure. Copy-on-write: the -// input request is never mutated. n <= 0 returns req unchanged. -func dropOldestImages(req llm.Request, n int) llm.Request { - if n <= 0 { - return req - } - out := req - out.Messages = make([]llm.Message, len(req.Messages)) - copy(out.Messages, req.Messages) - dropped := 0 - for mi := range out.Messages { - if dropped >= n { - break - } - if !hasImagePart(out.Messages[mi].Parts) { - continue - } - parts := make([]llm.Part, 0, len(out.Messages[mi].Parts)) - for _, p := range out.Messages[mi].Parts { - if _, ok := p.(llm.ImagePart); ok && dropped < n { - dropped++ - parts = append(parts, llm.Text(imageOverflowPlaceholder)) - continue - } - parts = append(parts, p) - } - out.Messages[mi].Parts = parts - } - return out -} - -func hasImagePart(parts []llm.Part) bool { - for _, p := range parts { - if _, ok := p.(llm.ImagePart); ok { - return true - } - } - return false -} - // Info reports an image part's sniffed format ("jpeg", "png", "gif", or // "webp") and pixel dimensions. It is a cheap metadata read — the pixels are // never decoded. webp is recognized by signature but not decodable with the diff --git a/media/media_test.go b/media/media_test.go index e0129dd..8b37319 100644 --- a/media/media_test.go +++ b/media/media_test.go @@ -149,10 +149,10 @@ func TestNormalizeImagesUnsupported(t *testing.T) { } } -func TestNormalizeTooManyImages_DropsOldest(t *testing.T) { - // 3 distinguishable images across 2 messages; cap = 2. Overflow no longer +func TestNormalizeOverCount(t *testing.T) { + // 3 distinguishable images across 2 messages; cap = 2. Over-count no longer // errors — the OLDEST image is replaced with a placeholder and the most-recent - // two (the relevant ones in an iterative run) are kept. + // two (the relevant ones in an iterative run) are kept, in order. a := llm.Image("image/png", encPNG(t, gradient(2, 2))).(llm.ImagePart) b := llm.Image("image/png", encPNG(t, gradient(4, 4))).(llm.ImagePart) c := llm.Image("image/png", encPNG(t, gradient(8, 8))).(llm.ImagePart) @@ -163,7 +163,7 @@ func TestNormalizeTooManyImages_DropsOldest(t *testing.T) { caps := llm.Capabilities{MaxImagesPerReq: 2, MaxImageDimension: 64, MaxImageBytes: 1 << 20, AllowedImageMIME: []string{"image/png"}} out, err := Normalize(req, caps) if err != nil { - t.Fatalf("drop-oldest overflow should not error: %v", err) + t.Fatalf("over-count should not error: %v", err) } var imgs []llm.ImagePart placeholders := 0 @@ -179,20 +179,18 @@ func TestNormalizeTooManyImages_DropsOldest(t *testing.T) { } } } - if len(imgs) != 2 { - t.Fatalf("kept %d images, want 2 (the cap)", len(imgs)) + // The exact survivors are the most-recent two, in order: b then c (a elided). + if len(imgs) != 2 || !bytes.Equal(imgs[0].Data, b.Data) || !bytes.Equal(imgs[1].Data, c.Data) { + t.Fatalf("kept %d images; want exactly [b, c] (the most-recent two)", len(imgs)) } if placeholders != 1 { - t.Errorf("placeholders = %d, want 1 for the dropped oldest image", placeholders) + t.Errorf("placeholders = %d, want 1 for the elided oldest image", placeholders) } - for _, im := range imgs { - if bytes.Equal(im.Data, a.Data) { - t.Errorf("oldest image was kept; the most-recent two should survive") - } - } - // The input request must be untouched (copy-on-write). - if len(req.Messages[0].Parts) != 2 { - t.Errorf("input request was mutated: %+v", req.Messages[0].Parts) + // Input request untouched (copy-on-write): the first part is still image a, + // not a placeholder — a len check alone wouldn't catch in-place substitution. + first, ok := req.Messages[0].Parts[0].(llm.ImagePart) + if !ok || !bytes.Equal(first.Data, a.Data) { + t.Errorf("input request was mutated; first part = %+v", req.Messages[0].Parts[0]) } }