[BugFix][KVCache] Fix mm hash boundary comparison in get_block_hash_extra_keys (#6929)

* [BugFix][KVCache] Fix mm hash boundary comparison in get_block_hash_extra_keys

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] Fix test_get_block_hash_extra_keys_boundary_cases assertions

## Motivation

测试用例 `test_get_block_hash_extra_keys_boundary_cases` 中,Block [4,8) 的
调用错误地传入了 `mm_idx=1`,跳过了 img0[2,5);但 img0 覆盖 token 4,token 4
属于 block [4,8),应被包含在 hash_keys 中。此外,所有 assertEqual 只校验了
hash_keys,未校验返回的 mm_idx 游标。

## Modifications

- `test_get_block_hash_extra_keys_boundary_cases`:
  - 改为链式调用,用上一次返回的 mm_idx 作为下一次入参,模拟真实调用循环
  - Block [4,8) 入参从 `mm_idx=1` 改为沿用上次返回的 `mm_idx=0`,期望值从 `[]` 改为 `["hash-0"]`
  - 所有断言改为 `assertEqual((mm_idx, hash_keys), (...))` 同时校验游标
- `test_get_block_hash_extra_keys_no_overlap_at_boundaries`:
  - Case B 入参从 `mm_idx=1` 改为 `mm_idx=0`(从头遍历,img-a 走 continue)
  - 所有断言增加 mm_idx 校验
- `test_get_block_hash_extra_keys_image_crosses_block_boundary`:
  - 所有断言增加 mm_idx 校验
- `test_get_block_hash_extra_keys_no_mm_inputs`:
  - 断言增加 mm_idx 校验
- `test_get_block_hash_extra_keys_handles_multimodal_segments`:
  - call2、call3 断言增加 mm_idx 校验

## Usage or Command

```bash
python -m pytest tests/cache_manager/test_prefix_cache_manager.py::TestPrefixCacheManagerCoverage -v -k "get_block_hash_extra_keys"
```

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: chengyanfu <chengyanfu@baidu.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
kevin
2026-03-30 17:13:31 +08:00
committed by GitHub
parent 76cf5e9496
commit 18062c55bb
2 changed files with 109 additions and 4 deletions
@@ -1624,7 +1624,7 @@ class PrefixCacheManager:
mm_inputs["mm_hashes"]
), f"mm_idx {mm_idx} out of range {len(mm_inputs['mm_hashes'])}"
if mm_inputs["mm_positions"][-1].offset + mm_inputs["mm_positions"][-1].length < start_idx:
if mm_inputs["mm_positions"][-1].offset + mm_inputs["mm_positions"][-1].length <= start_idx:
# non images in current block
return mm_idx, hash_keys
@@ -1632,7 +1632,7 @@ class PrefixCacheManager:
image_offset = mm_inputs["mm_positions"][img_idx].offset
image_length = mm_inputs["mm_positions"][img_idx].length
if image_offset + image_length < start_idx:
if image_offset + image_length <= start_idx:
# image before block
continue
elif image_offset >= end_idx:
@@ -1226,6 +1226,111 @@ class TestPrefixCacheManagerCoverage(unittest.TestCase):
is_sync=False,
)
def test_get_block_hash_extra_keys_boundary_cases(self):
"""
覆盖 image 与 block 的各种位置关系,验证 hash_keys 的正确性。
数据布局 (block_size=4):
tokens: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
img0: [====] [2,5) hash-0
img1: [========] [8,12) hash-1
img2: [==] [14,16) hash-2
blocks: [====][====][====][====]
[0,4) [4,8) [8,12)[12,16)
"""
manager = _create_manager()
request = SimpleNamespace(
multimodal_inputs={
"mm_positions": [
SimpleNamespace(offset=2, length=3), # [2,5)
SimpleNamespace(offset=8, length=4), # [8,12)
SimpleNamespace(offset=14, length=2), # [14,16)
],
"mm_hashes": ["hash-0", "hash-1", "hash-2"],
},
num_total_tokens=16,
)
# 模拟真实调用循环:每次用上一次返回的 mm_idx 作为下一次入参
mm_idx = 0
# ---- Block [0,4): img0[2,5) 跨越右边界,返回 mm_idx=0img0 未完全消费)----
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=0, end_idx=4, mm_idx=mm_idx)
self.assertEqual((mm_idx, hash_keys), (0, ["hash-0"]))
# ---- Block [4,8): 沿用返回的 mm_idx=0img0 的 tail token 4 在本 block 内 ----
# img0[2,5): 5 > start_idx=4,不走 continue5<=end_idx=8,走 else → 包含 hash-0
# img1[8,12): image_offset=8 >= end_idx=8 → 结束,返回 mm_idx=1
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=mm_idx)
self.assertEqual((mm_idx, hash_keys), (1, ["hash-0"]))
# ---- Block [8,12): 沿用返回的 mm_idx=1img1 恰好填满整个 block ----
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=8, end_idx=12, mm_idx=mm_idx)
self.assertEqual((mm_idx, hash_keys), (2, ["hash-1"]))
# ---- Block [12,16): 沿用返回的 mm_idximg2 完全在 block 内部 [14,16) ⊂ [12,16) ----
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=12, end_idx=16, mm_idx=mm_idx)
self.assertEqual((mm_idx, hash_keys), (2, ["hash-2"]))
def test_get_block_hash_extra_keys_no_overlap_at_boundaries(self):
"""image 与 block 恰好相接时不应有重叠。"""
manager = _create_manager()
# image 恰好在 block 之前: img[0,4), block[4,8)
request = SimpleNamespace(
multimodal_inputs={
"mm_positions": [SimpleNamespace(offset=0, length=4)],
"mm_hashes": ["hash-a"],
},
num_total_tokens=8,
)
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=0)
self.assertEqual((mm_idx, hash_keys), (0, []))
# image 恰好在 block 之后: img[8,10), block[4,8)
request = SimpleNamespace(
multimodal_inputs={
"mm_positions": [SimpleNamespace(offset=0, length=4), SimpleNamespace(offset=8, length=2)],
"mm_hashes": ["hash-a", "hash-b"],
},
num_total_tokens=12,
)
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=0)
self.assertEqual((mm_idx, hash_keys), (1, []))
def test_get_block_hash_extra_keys_image_crosses_block_boundary(self):
"""image 跨越 block 边界时 hash 应被包含。"""
manager = _create_manager()
# image 跨越 block 右边界: img[6,10), block[4,8)
request = SimpleNamespace(
multimodal_inputs={
"mm_positions": [SimpleNamespace(offset=6, length=4)],
"mm_hashes": ["hash-cross"],
},
num_total_tokens=12,
)
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=0)
self.assertEqual((mm_idx, hash_keys), (0, ["hash-cross"]))
# image 跨越整个 block: img[3,9), block[4,8)
request = SimpleNamespace(
multimodal_inputs={
"mm_positions": [SimpleNamespace(offset=3, length=6)],
"mm_hashes": ["hash-span"],
},
num_total_tokens=12,
)
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=0)
self.assertEqual((mm_idx, hash_keys), (0, ["hash-span"]))
def test_get_block_hash_extra_keys_no_mm_inputs(self):
"""无多模态输入时应返回空。"""
manager = _create_manager()
request = SimpleNamespace(multimodal_inputs=None, num_total_tokens=12)
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=0, end_idx=4, mm_idx=0)
self.assertEqual((mm_idx, hash_keys), (0, []))
def test_get_block_hash_extra_keys_handles_multimodal_segments(self):
manager = _create_manager()
request = SimpleNamespace(
@@ -1239,10 +1344,10 @@ class TestPrefixCacheManagerCoverage(unittest.TestCase):
self.assertEqual((mm_idx, hash_keys), (0, []))
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=2, end_idx=6, mm_idx=0)
self.assertEqual(hash_keys, ["img-a"])
self.assertEqual((mm_idx, hash_keys), (1, ["img-a"]))
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=7, end_idx=10, mm_idx=1)
self.assertEqual(hash_keys, ["img-b"])
self.assertEqual((mm_idx, hash_keys), (1, ["img-b"]))
def test_cache_output_blocks_updates_leaf_and_recycles_redundant_block(self):
manager = _create_manager(num_gpu_blocks=6, num_cpu_blocks=1)