[Speculative Decoding][MTP] Support static CacheKV C8 quantization and optimize memory usage (#5155)

* support static cachekv c8 quantization in mtp mode

* optimize memory allocation
This commit is contained in: (branch/tag list not captured in this extract)
freeliuzc
2025-11-21 15:10:13 +08:00
committed by GitHub
parent 3c36283d7d
commit 2d1dade5e2
6 changed files with 350 additions and 295 deletions
+2 -1
View File
@@ -104,6 +104,7 @@ class MTPProposer(Proposer):
self.model_config.num_hidden_layers = 1
self.model_config.model = self.speculative_config.model
self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
self.model_config.prefix_layer_name = "mtp_block"
if self.speculative_config.quantization != "":
self.model_config.quantization = self.speculative_config.quantization
self.model_config.start_layer_index = self.num_main_model_layers
@@ -354,7 +355,7 @@ class MTPProposer(Proposer):
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["target_hidden_states"] = paddle.full(
[self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
[self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
)
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))