[BugFix] Fix cache manager not being launched when MTP or block-wise FP8 is used (#5840)

* [BugFix] Fix cache manager not being launched when MTP or block-wise FP8 is used

* [fix] fix mtp cache in mtp.py

* [fix] fix gpu ops import

* [fix] fix mtp layer idx

* [fix] fix xpu model runner mtp cache

* [fix] fix mtp import
This commit is contained in:
Yonghua Li
2026-01-04 20:35:37 +08:00
committed by GitHub
parent 55f77e9ab1
commit 5e4e6692a4
3 changed files with 32 additions and 5 deletions
+28 -1
View File
@@ -45,6 +45,7 @@ if current_platform.is_xpu():
eagle_get_self_hidden_states,
mtp_save_first_token,
mtp_step_paddle,
set_data_ipc,
share_external_data,
)
from fastdeploy.model_executor.xpu_pre_and_post_process import (
@@ -65,6 +66,7 @@ else:
speculate_get_logits,
speculate_save_output_topk,
update_attn_mask_offsets,
set_data_ipc,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
@@ -210,6 +212,9 @@ class MTPProposer(Proposer):
self.num_main_model_layers,
self.num_main_model_layers + self.model_config.num_hidden_layers,
):
logger.info(
f"..attaching kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}"
)
key_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
@@ -233,28 +238,50 @@ class MTPProposer(Proposer):
self.model_inputs["caches"] = cache_kvs_list
else:
for i in range(self.model_config.num_hidden_layers):
for i in range(
self.num_main_model_layers,
self.num_main_model_layers + self.model_config.num_hidden_layers,
):
logger.info(f"..creating kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
self.cache_kvs[f"key_caches_{i}"] = paddle.full(
shape=key_cache_shape,
fill_value=0,
dtype=cache_type,
)
set_data_ipc(
self.cache_kvs[f"key_caches_{i}"], f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
)
self.cache_kvs[f"value_caches_{i}"] = paddle.full(
shape=value_cache_shape,
fill_value=0,
dtype=cache_type,
)
set_data_ipc(
self.cache_kvs[f"value_caches_{i}"], f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
)
if kv_cache_quant_type == "block_wise_fp8":
self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
set_data_ipc(
self.cache_kvs[f"key_cache_scales_{i}"],
f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}",
)
self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
set_data_ipc(
self.cache_kvs[f"value_cache_scales_{i}"],
f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}",
)
self.model_inputs["caches"] = list(self.cache_kvs.values())
for value in self.cache_kvs.values():
del value