mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding][MTP] Support static CacheKV C8 quantization and optimize memory usage (#5155)
* support static cachekv c8 quantization in mtp mode * optimize memory allocation
This commit is contained in:
@@ -104,6 +104,7 @@ class MTPProposer(Proposer):
|
||||
self.model_config.num_hidden_layers = 1
|
||||
self.model_config.model = self.speculative_config.model
|
||||
self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
|
||||
self.model_config.prefix_layer_name = "mtp_block"
|
||||
if self.speculative_config.quantization != "":
|
||||
self.model_config.quantization = self.speculative_config.quantization
|
||||
self.model_config.start_layer_index = self.num_main_model_layers
|
||||
@@ -354,7 +355,7 @@ class MTPProposer(Proposer):
|
||||
self.target_model_inputs["decoder_tile_ids_per_batch"]
|
||||
)
|
||||
self.model_inputs["target_hidden_states"] = paddle.full(
|
||||
[self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
|
||||
[self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
|
||||
)
|
||||
|
||||
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
|
||||
|
||||
Reference in New Issue
Block a user