mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix cache manager not launched in case of mtp or blockwise fp8 (#5840)
* [BugFix] fix cache manager not launched in case of mtp or blockwise fp8
* [fix] fix mtp cache in mtp.py
* [fix] fix gpu ops import
* [fix] fix mtp layer idx
* [fix] fix xpu model runner mtp cache
* [fix] fix mtp import
This commit is contained in:
@@ -45,6 +45,7 @@ if current_platform.is_xpu():
|
||||
eagle_get_self_hidden_states,
|
||||
mtp_save_first_token,
|
||||
mtp_step_paddle,
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
)
|
||||
from fastdeploy.model_executor.xpu_pre_and_post_process import (
|
||||
@@ -65,6 +66,7 @@ else:
|
||||
speculate_get_logits,
|
||||
speculate_save_output_topk,
|
||||
update_attn_mask_offsets,
|
||||
set_data_ipc,
|
||||
)
|
||||
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
|
||||
|
||||
@@ -210,6 +212,9 @@ class MTPProposer(Proposer):
|
||||
self.num_main_model_layers,
|
||||
self.num_main_model_layers + self.model_config.num_hidden_layers,
|
||||
):
|
||||
logger.info(
|
||||
f"..attaching kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}"
|
||||
)
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
@@ -233,28 +238,50 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.model_inputs["caches"] = cache_kvs_list
|
||||
else:
|
||||
for i in range(self.model_config.num_hidden_layers):
|
||||
for i in range(
|
||||
self.num_main_model_layers,
|
||||
self.num_main_model_layers + self.model_config.num_hidden_layers,
|
||||
):
|
||||
logger.info(f"..creating kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
|
||||
self.cache_kvs[f"key_caches_{i}"] = paddle.full(
|
||||
shape=key_cache_shape,
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"key_caches_{i}"], f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
)
|
||||
|
||||
self.cache_kvs[f"value_caches_{i}"] = paddle.full(
|
||||
shape=value_cache_shape,
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"value_caches_{i}"], f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
)
|
||||
|
||||
if kv_cache_quant_type == "block_wise_fp8":
|
||||
self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
|
||||
shape=kv_cache_scale_shape,
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"key_cache_scales_{i}"],
|
||||
f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
|
||||
self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
|
||||
shape=kv_cache_scale_shape,
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"value_cache_scales_{i}"],
|
||||
f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
|
||||
self.model_inputs["caches"] = list(self.cache_kvs.values())
|
||||
for value in self.cache_kvs.values():
|
||||
del value
|
||||
|
||||
Reference in New Issue
Block a user