mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix][KVCache] Fix mm hash boundary comparison in get_block_hash_extra_keys (#6929)
* [BugFix][KVCache] Fix mm hash boundary comparison in get_block_hash_extra_keys Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * [BugFix][KVCache] Fix test_get_block_hash_extra_keys_boundary_cases assertions ## Motivation 测试用例 `test_get_block_hash_extra_keys_boundary_cases` 中,Block [4,8) 的 调用错误地传入了 `mm_idx=1`,跳过了 img0[2,5);但 img0 覆盖 token 4,token 4 属于 block [4,8),应被包含在 hash_keys 中。此外,所有 assertEqual 只校验了 hash_keys,未校验返回的 mm_idx 游标。 ## Modifications - `test_get_block_hash_extra_keys_boundary_cases`: - 改为链式调用,用上一次返回的 mm_idx 作为下一次入参,模拟真实调用循环 - Block [4,8) 入参从 `mm_idx=1` 改为沿用上次返回的 `mm_idx=0`,期望值从 `[]` 改为 `["hash-0"]` - 所有断言改为 `assertEqual((mm_idx, hash_keys), (...))` 同时校验游标 - `test_get_block_hash_extra_keys_no_overlap_at_boundaries`: - Case B 入参从 `mm_idx=1` 改为 `mm_idx=0`(从头遍历,img-a 走 continue) - 所有断言增加 mm_idx 校验 - `test_get_block_hash_extra_keys_image_crosses_block_boundary`: - 所有断言增加 mm_idx 校验 - `test_get_block_hash_extra_keys_no_mm_inputs`: - 断言增加 mm_idx 校验 - `test_get_block_hash_extra_keys_handles_multimodal_segments`: - call2、call3 断言增加 mm_idx 校验 ## Usage or Command ```bash python -m pytest tests/cache_manager/test_prefix_cache_manager.py::TestPrefixCacheManagerCoverage -v -k "get_block_hash_extra_keys" ``` Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: chengyanfu <chengyanfu@baidu.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1624,7 +1624,7 @@ class PrefixCacheManager:
|
||||
mm_inputs["mm_hashes"]
|
||||
), f"mm_idx {mm_idx} out of range {len(mm_inputs['mm_hashes'])}"
|
||||
|
||||
if mm_inputs["mm_positions"][-1].offset + mm_inputs["mm_positions"][-1].length < start_idx:
|
||||
if mm_inputs["mm_positions"][-1].offset + mm_inputs["mm_positions"][-1].length <= start_idx:
|
||||
# non images in current block
|
||||
return mm_idx, hash_keys
|
||||
|
||||
@@ -1632,7 +1632,7 @@ class PrefixCacheManager:
|
||||
image_offset = mm_inputs["mm_positions"][img_idx].offset
|
||||
image_length = mm_inputs["mm_positions"][img_idx].length
|
||||
|
||||
if image_offset + image_length < start_idx:
|
||||
if image_offset + image_length <= start_idx:
|
||||
# image before block
|
||||
continue
|
||||
elif image_offset >= end_idx:
|
||||
|
||||
@@ -1226,6 +1226,111 @@ class TestPrefixCacheManagerCoverage(unittest.TestCase):
|
||||
is_sync=False,
|
||||
)
|
||||
|
||||
def test_get_block_hash_extra_keys_boundary_cases(self):
|
||||
"""
|
||||
覆盖 image 与 block 的各种位置关系,验证 hash_keys 的正确性。
|
||||
|
||||
数据布局 (block_size=4):
|
||||
tokens: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
|
||||
img0: [====] [2,5) hash-0
|
||||
img1: [========] [8,12) hash-1
|
||||
img2: [==] [14,16) hash-2
|
||||
blocks: [====][====][====][====]
|
||||
[0,4) [4,8) [8,12)[12,16)
|
||||
"""
|
||||
manager = _create_manager()
|
||||
request = SimpleNamespace(
|
||||
multimodal_inputs={
|
||||
"mm_positions": [
|
||||
SimpleNamespace(offset=2, length=3), # [2,5)
|
||||
SimpleNamespace(offset=8, length=4), # [8,12)
|
||||
SimpleNamespace(offset=14, length=2), # [14,16)
|
||||
],
|
||||
"mm_hashes": ["hash-0", "hash-1", "hash-2"],
|
||||
},
|
||||
num_total_tokens=16,
|
||||
)
|
||||
|
||||
# 模拟真实调用循环:每次用上一次返回的 mm_idx 作为下一次入参
|
||||
mm_idx = 0
|
||||
|
||||
# ---- Block [0,4): img0[2,5) 跨越右边界,返回 mm_idx=0(img0 未完全消费)----
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=0, end_idx=4, mm_idx=mm_idx)
|
||||
self.assertEqual((mm_idx, hash_keys), (0, ["hash-0"]))
|
||||
|
||||
# ---- Block [4,8): 沿用返回的 mm_idx=0,img0 的 tail token 4 在本 block 内 ----
|
||||
# img0[2,5): 5 > start_idx=4,不走 continue;5<=end_idx=8,走 else → 包含 hash-0
|
||||
# img1[8,12): image_offset=8 >= end_idx=8 → 结束,返回 mm_idx=1
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=mm_idx)
|
||||
self.assertEqual((mm_idx, hash_keys), (1, ["hash-0"]))
|
||||
|
||||
# ---- Block [8,12): 沿用返回的 mm_idx=1,img1 恰好填满整个 block ----
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=8, end_idx=12, mm_idx=mm_idx)
|
||||
self.assertEqual((mm_idx, hash_keys), (2, ["hash-1"]))
|
||||
|
||||
# ---- Block [12,16): 沿用返回的 mm_idx,img2 完全在 block 内部 [14,16) ⊂ [12,16) ----
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=12, end_idx=16, mm_idx=mm_idx)
|
||||
self.assertEqual((mm_idx, hash_keys), (2, ["hash-2"]))
|
||||
|
||||
def test_get_block_hash_extra_keys_no_overlap_at_boundaries(self):
|
||||
"""image 与 block 恰好相接时不应有重叠。"""
|
||||
manager = _create_manager()
|
||||
|
||||
# image 恰好在 block 之前: img[0,4), block[4,8)
|
||||
request = SimpleNamespace(
|
||||
multimodal_inputs={
|
||||
"mm_positions": [SimpleNamespace(offset=0, length=4)],
|
||||
"mm_hashes": ["hash-a"],
|
||||
},
|
||||
num_total_tokens=8,
|
||||
)
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=0)
|
||||
self.assertEqual((mm_idx, hash_keys), (0, []))
|
||||
|
||||
# image 恰好在 block 之后: img[8,10), block[4,8)
|
||||
request = SimpleNamespace(
|
||||
multimodal_inputs={
|
||||
"mm_positions": [SimpleNamespace(offset=0, length=4), SimpleNamespace(offset=8, length=2)],
|
||||
"mm_hashes": ["hash-a", "hash-b"],
|
||||
},
|
||||
num_total_tokens=12,
|
||||
)
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=0)
|
||||
self.assertEqual((mm_idx, hash_keys), (1, []))
|
||||
|
||||
def test_get_block_hash_extra_keys_image_crosses_block_boundary(self):
|
||||
"""image 跨越 block 边界时 hash 应被包含。"""
|
||||
manager = _create_manager()
|
||||
|
||||
# image 跨越 block 右边界: img[6,10), block[4,8)
|
||||
request = SimpleNamespace(
|
||||
multimodal_inputs={
|
||||
"mm_positions": [SimpleNamespace(offset=6, length=4)],
|
||||
"mm_hashes": ["hash-cross"],
|
||||
},
|
||||
num_total_tokens=12,
|
||||
)
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=0)
|
||||
self.assertEqual((mm_idx, hash_keys), (0, ["hash-cross"]))
|
||||
|
||||
# image 跨越整个 block: img[3,9), block[4,8)
|
||||
request = SimpleNamespace(
|
||||
multimodal_inputs={
|
||||
"mm_positions": [SimpleNamespace(offset=3, length=6)],
|
||||
"mm_hashes": ["hash-span"],
|
||||
},
|
||||
num_total_tokens=12,
|
||||
)
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=4, end_idx=8, mm_idx=0)
|
||||
self.assertEqual((mm_idx, hash_keys), (0, ["hash-span"]))
|
||||
|
||||
def test_get_block_hash_extra_keys_no_mm_inputs(self):
|
||||
"""无多模态输入时应返回空。"""
|
||||
manager = _create_manager()
|
||||
request = SimpleNamespace(multimodal_inputs=None, num_total_tokens=12)
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=0, end_idx=4, mm_idx=0)
|
||||
self.assertEqual((mm_idx, hash_keys), (0, []))
|
||||
|
||||
def test_get_block_hash_extra_keys_handles_multimodal_segments(self):
|
||||
manager = _create_manager()
|
||||
request = SimpleNamespace(
|
||||
@@ -1239,10 +1344,10 @@ class TestPrefixCacheManagerCoverage(unittest.TestCase):
|
||||
self.assertEqual((mm_idx, hash_keys), (0, []))
|
||||
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=2, end_idx=6, mm_idx=0)
|
||||
self.assertEqual(hash_keys, ["img-a"])
|
||||
self.assertEqual((mm_idx, hash_keys), (1, ["img-a"]))
|
||||
|
||||
mm_idx, hash_keys = manager.get_block_hash_extra_keys(request, start_idx=7, end_idx=10, mm_idx=1)
|
||||
self.assertEqual(hash_keys, ["img-b"])
|
||||
self.assertEqual((mm_idx, hash_keys), (1, ["img-b"]))
|
||||
|
||||
def test_cache_output_blocks_updates_leaf_and_recycles_redundant_block(self):
|
||||
manager = _create_manager(num_gpu_blocks=6, num_cpu_blocks=1)
|
||||
|
||||
Reference in New Issue
Block a user