[Speculative Decoding][MTP]Update extract_mtp_weight script and optimize config (#5183)

* update extract_mtp_model

* modify config usage
This commit is contained in:
freeliuzc
2025-11-25 14:09:03 +08:00
committed by GitHub
parent edf0d09257
commit 5c8c2d47eb
4 changed files with 43 additions and 6 deletions
+7 -1
View File
@@ -355,7 +355,13 @@ class MTPProposer(Proposer):
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["target_hidden_states"] = paddle.full(
[self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
[
self.fd_config.scheduler_config.max_num_batched_tokens
+ self.fd_config.scheduler_config.max_extra_num_batched_tokens,
self.model_config.hidden_size,
],
0,
dtype="bfloat16",
)
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))