mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
fix mtp prefix_cache dy-c8 bug (#5390)
This commit is contained in:
@@ -218,6 +218,18 @@ class MTPProposer(Proposer):
|
||||
value_cache = share_external_data(value_cache, val_cache_name, value_cache_shape)
|
||||
cache_kvs_list.append(value_cache)
|
||||
|
||||
if kv_cache_quant_type == "block_wise_fp8":
|
||||
scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
|
||||
scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
|
||||
key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
|
||||
key_scale_cache = share_external_data(key_scale_cache, scale_key_cache_name, kv_cache_scale_shape)
|
||||
cache_kvs_list.append(key_scale_cache)
|
||||
value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
|
||||
value_scale_cache = share_external_data(
|
||||
value_scale_cache, scale_val_cache_name, kv_cache_scale_shape
|
||||
)
|
||||
cache_kvs_list.append(value_scale_cache)
|
||||
|
||||
self.model_inputs["caches"] = cache_kvs_list
|
||||
else:
|
||||
for i in range(self.model_config.num_hidden_layers):
|
||||
|
||||
Reference in New Issue
Block a user