[PD Disaggregation] pd + cache_storage support vl model (#6906)

* PD + cache_storage: support VL model

* Support VL model

* Fix test
This commit is contained in:
jc
2026-03-23 15:35:20 +08:00
committed by GitHub
parent 5416da8c6e
commit bb881c2c0a
3 changed files with 51 additions and 6 deletions
@@ -860,8 +860,17 @@ class PrefixCacheManager:
prefix_block_key = [] if match_block_node.hash_value is None else [match_block_node.hash_value]
cur_token_idx = match_token_num
no_match_block_keys = []
mm_idx = 0
while cur_token_idx <= input_token_num - block_size:
cur_block_token_ids = input_token_ids[cur_token_idx : cur_token_idx + block_size]
# Get extra hash keys for multimodal content (images, videos, etc.)
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=task,
start_idx=cur_token_idx,
end_idx=cur_token_idx + block_size,
mm_idx=mm_idx,
)
prefix_block_key.extend(extra_keys)
cur_block_key = get_hash_str(cur_block_token_ids, prefix_block_key)
no_match_block_keys.append(cur_block_key)
cur_token_idx += block_size
@@ -1103,8 +1112,30 @@ class PrefixCacheManager:
def write_cache_to_storage(self, request: Request):
"""
For finished request, write cache to storage.
NOTE: this function does not modify the global params
Write finished request's KV cache to storage backend (P instance with Radix Tree).
This method is called after a request finishes generation. It traverses the Radix
Tree from leaf node to root to collect cache keys, then issues a write-back task
to persist KV cache blocks to the storage backend.
Args:
request: The finished request containing:
- prompt_token_ids: Input token sequence
- output_token_ids: Generated output tokens (used if enable_output_caching)
- block_tables: Mapping of logical to physical block IDs
- request_id: Unique request identifier
Process:
1. Get token_ids (prompt tokens + output tokens if output caching enabled)
2. Traverse Radix Tree from leaf (req_leaf_map[req_id]) to root, collecting hash keys
3. Reverse keys to get root-to-leaf order
4. Create WriteStorageTask with keys, token_ids, and gpu_block_ids
5. Issue synchronous write-back task to storage backend
Note:
- This function does not modify global params (block_tables, ref counters)
- Only called on P instance which maintains the Radix Tree
- For D instance, use write_cache_to_storage_decode() instead
"""
if self.kvcache_storage_backend is None:
return
@@ -1170,12 +1201,22 @@ class PrefixCacheManager:
keys = []
prefix_block_key = [] # Initial is empty list
block_size = self.config.cache_config.block_size
mm_idx = 0 # Multimodal index for tracking position in mm_inputs
for i in range(0, len(token_ids), block_size):
block_token_ids = token_ids[i : i + block_size]
if len(block_token_ids) < block_size:
break # Do not cache incomplete block
# Get extra hash keys for multimodal content (images, videos, etc.)
mm_idx, extra_keys = self.get_block_hash_extra_keys(
request=request,
start_idx=i,
end_idx=i + block_size,
mm_idx=mm_idx,
)
prefix_block_key.extend(extra_keys)
# Calculate hash key for current block
key = get_hash_str(block_token_ids, prefix_block_key)
keys.append(key)