mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding][MTP] Update extract_mtp_weight script and optimize config (#5183)
* update extract_mtp_model
* modify config usage
This commit is contained in:
@@ -355,7 +355,13 @@ class MTPProposer(Proposer):
|
||||
self.target_model_inputs["decoder_tile_ids_per_batch"]
|
||||
)
|
||||
self.model_inputs["target_hidden_states"] = paddle.full(
|
||||
[self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
|
||||
[
|
||||
self.fd_config.scheduler_config.max_num_batched_tokens
|
||||
+ self.fd_config.scheduler_config.max_extra_num_batched_tokens,
|
||||
self.model_config.hidden_size,
|
||||
],
|
||||
0,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
|
||||
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
|
||||
|
||||
Reference in New Issue
Block a user