fix mtp prefix_cache dy-c8 bug (#5390)

This commit is contained in:
kevin
2025-12-05 19:03:19 +08:00
committed by GitHub
parent c9d7f9e7c3
commit db936ab3e4
+12
View File
@@ -218,6 +218,18 @@ class MTPProposer(Proposer):
value_cache = share_external_data(value_cache, val_cache_name, value_cache_shape)
cache_kvs_list.append(value_cache)
if kv_cache_quant_type == "block_wise_fp8":
scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
key_scale_cache = share_external_data(key_scale_cache, scale_key_cache_name, kv_cache_scale_shape)
cache_kvs_list.append(key_scale_cache)
value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
value_scale_cache = share_external_data(
value_scale_cache, scale_val_cache_name, kv_cache_scale_shape
)
cache_kvs_list.append(value_scale_cache)
self.model_inputs["caches"] = cache_kvs_list
else:
for i in range(self.model_config.num_hidden_layers):