mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 17:49:42 +08:00
[MTP] refactor MTP pre_process (#6358)
This commit is contained in:
@@ -758,8 +758,8 @@ class SpeculativeSampler(nn.Layer):
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["output_padding_offset"],
|
||||
share_inputs["output_cum_offsets"],
|
||||
share_inputs["batch_id_per_token_output"],
|
||||
share_inputs["cu_seqlens_q_output"],
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
@@ -773,8 +773,8 @@ class SpeculativeSampler(nn.Layer):
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["reasoning_allowed_tokens"],
|
||||
share_inputs["reasoning_status"],
|
||||
share_inputs["output_padding_offset"],
|
||||
share_inputs["output_cum_offsets"],
|
||||
share_inputs["batch_id_per_token_output"],
|
||||
share_inputs["cu_seqlens_q_output"],
|
||||
share_inputs["enable_thinking"],
|
||||
self.think_end_id,
|
||||
self.line_break_id,
|
||||
@@ -794,7 +794,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
||||
probs,
|
||||
sampling_metadata.top_p,
|
||||
share_inputs["output_padding_offset"],
|
||||
share_inputs["batch_id_per_token_output"],
|
||||
self.speculative_max_candidate_len,
|
||||
max_model_len,
|
||||
)
|
||||
@@ -816,7 +816,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
share_inputs["max_dec_len"],
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["output_cum_offsets"],
|
||||
share_inputs["cu_seqlens_q_output"],
|
||||
actual_candidate_len,
|
||||
share_inputs["actual_draft_token_num"],
|
||||
sampling_metadata.top_p,
|
||||
@@ -1135,8 +1135,8 @@ class MTPSampler(nn.Layer):
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["output_padding_offset"],
|
||||
share_inputs["output_cum_offsets"],
|
||||
share_inputs["batch_id_per_token_output"],
|
||||
share_inputs["cu_seqlens_q_output"],
|
||||
max_model_len,
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
Reference in New Issue
Block a user