[PD Disaggregation] pd + cache_storage support vl model (#6906)

* pd + cache_storage support vl model

* support vl model

* fix test
This commit is contained in:
jc
2026-03-23 15:35:20 +08:00
committed by GitHub
parent 5416da8c6e
commit bb881c2c0a
3 changed files with 51 additions and 6 deletions
@@ -860,8 +860,17 @@ class PrefixCacheManager:
prefix_block_key = [] if match_block_node.hash_value is None else [match_block_node.hash_value]
cur_token_idx = match_token_num
no_match_block_keys = []
mm_idx = 0
while cur_token_idx <= input_token_num - block_size:
cur_block_token_ids = input_token_ids[cur_token_idx : cur_token_idx + block_size]
# Get extra hash keys for multimodal content (images, videos, etc.)
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=task,
start_idx=cur_token_idx,
end_idx=cur_token_idx + block_size,
mm_idx=mm_idx,
)
prefix_block_key.extend(extra_keys)
cur_block_key = get_hash_str(cur_block_token_ids, prefix_block_key)
no_match_block_keys.append(cur_block_key)
cur_token_idx += block_size
@@ -1103,8 +1112,30 @@ class PrefixCacheManager:
def write_cache_to_storage(self, request: Request):
"""
For finished request, write cache to storage.
NOTE: this function does not modify the global params
Write finished request's KV cache to storage backend (P instance with Radix Tree).
This method is called after a request finishes generation. It traverses the Radix
Tree from leaf node to root to collect cache keys, then issues a write-back task
to persist KV cache blocks to the storage backend.
Args:
request: The finished request containing:
- prompt_token_ids: Input token sequence
- output_token_ids: Generated output tokens (used if enable_output_caching)
- block_tables: Mapping of logical to physical block IDs
- request_id: Unique request identifier
Process:
1. Get token_ids (prompt tokens + output tokens if output caching enabled)
2. Traverse Radix Tree from leaf (req_leaf_map[req_id]) to root, collecting hash keys
3. Reverse keys to get root-to-leaf order
4. Create WriteStorageTask with keys, token_ids, and gpu_block_ids
5. Issue synchronous write-back task to storage backend
Note:
- This function does not modify global params (block_tables, ref counters)
- Only called on the P instance, which maintains the Radix Tree
- For D instance, use write_cache_to_storage_decode() instead
"""
if self.kvcache_storage_backend is None:
return
@@ -1170,12 +1201,22 @@ class PrefixCacheManager:
keys = []
prefix_block_key = [] # Initial is empty list
block_size = self.config.cache_config.block_size
mm_idx = 0 # Multimodal index for tracking position in mm_inputs
for i in range(0, len(token_ids), block_size):
block_token_ids = token_ids[i : i + block_size]
if len(block_token_ids) < block_size:
break # Do not cache incomplete block
# Get extra hash keys for multimodal content (images, videos, etc.)
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=request,
start_idx=i,
end_idx=i + block_size,
mm_idx=mm_idx,
)
prefix_block_key.extend(extra_keys)
# Calculate hash key for current block
key = get_hash_str(block_token_ids, prefix_block_key)
keys.append(key)
+2 -3
View File
@@ -460,13 +460,12 @@ class Request:
"ic_req_data": self.ic_req_data,
}
# During multimodal PD separation, position_ids are required
if isinstance(self.multimodal_inputs, dict):
# Optimize multimodal data transfer during PD separation:
# - V1 mode (ENABLE_V1_KVCACHE_SCHEDULER=1): Only position_ids needed for decode nodes
# - V1 mode (ENABLE_V1_KVCACHE_SCHEDULER=1): position_ids, mm_positions, and mm_hashes are needed for decode nodes
# - V0 mode (ENABLE_V1_KVCACHE_SCHEDULER=0): Full field set required for compatibility
# This filtering significantly reduces serialized data size for large numpy arrays
allowed_keys = {"position_ids"}
allowed_keys = {"position_ids", "mm_positions", "mm_hashes"}
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
allowed_keys.update(["input_ids", "token_type_ids", "images", "image_type_ids", "grid_thw"])
@@ -1294,7 +1294,12 @@ class TestPrefixCacheManagerCoverage(unittest.TestCase):
manager = _create_manager(num_gpu_blocks=6)
manager.kvcache_storage_backend = "memory"
manager.prefix_tree_status_signal = SimpleNamespace(value=np.array([PrefixTreeStatus.NORMAL]))
task = SimpleNamespace(prompt_token_ids=[1, 2, 3, 4, 5, 6], output_token_ids=[], request_id="storage-req")
task = SimpleNamespace(
prompt_token_ids=[1, 2, 3, 4, 5, 6],
output_token_ids=[],
request_id="storage-req",
multimodal_inputs=None,
)
with (
patch.object(manager, "mm_match_block", return_value=([], [], [], manager.radix_tree_root, 0, 0)),