mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix mtp logprob bugs in chunk prefill (#5244)
* fix mtp logprob bugs in chunk prefill * fix * fix
This commit is contained in:
@@ -338,6 +338,7 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
|
||||
self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
|
||||
self.model_inputs["prompt_lens"] = paddle.clone(self.target_model_inputs["prompt_lens"])
|
||||
self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
|
||||
self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
|
||||
self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
|
||||
@@ -766,7 +767,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["step_idx"],
|
||||
)
|
||||
|
||||
def _propose(self, step_use_cudagraph: bool = False):
|
||||
def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
|
||||
"""
|
||||
Main process for MTP inference.
|
||||
Args:
|
||||
@@ -891,7 +892,12 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs,
|
||||
)
|
||||
|
||||
if substep == 0 and sampler_output.logprobs_tensors is not None:
|
||||
if (
|
||||
not is_dummy_run
|
||||
and self.parallel_config.tensor_parallel_rank == 0
|
||||
and substep == 0
|
||||
and sampler_output.logprobs_tensors is not None
|
||||
):
|
||||
real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
|
||||
speculate_save_output_topk(
|
||||
sampler_output.sampled_token_ids,
|
||||
@@ -901,8 +907,11 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["batch_token_num"][:real_bsz],
|
||||
self.model_inputs["cu_batch_token_offset"][:real_bsz],
|
||||
self.model_inputs["not_need_stop"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
self.model_inputs["prompt_lens"],
|
||||
4, # mtype
|
||||
self.local_rank,
|
||||
self.parallel_config.use_ep,
|
||||
)
|
||||
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
@@ -1009,10 +1018,12 @@ class MTPProposer(Proposer):
|
||||
self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
|
||||
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
|
||||
|
||||
def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
|
||||
def _run_impl(
|
||||
self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
|
||||
):
|
||||
"""Execute Draft Model"""
|
||||
self._prepare_inputs(full_hidden_states)
|
||||
self._propose(step_use_cudagraph=step_use_cudagraph)
|
||||
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
|
||||
self._update_status()
|
||||
if self.hybrid_mode:
|
||||
self._extend_draft_token_with_ngram_match()
|
||||
|
||||
Reference in New Issue
Block a user