[MTP] refactor MTP pre_process (#6358)

This commit is contained in:
周周周
2026-02-09 10:47:15 +08:00
committed by GitHub
parent 18e79dd660
commit 2b4748de4f
24 changed files with 411 additions and 533 deletions
@@ -175,8 +175,8 @@ def apply_speculative_penalty_multi_scores(
min_dec_lens: paddle.Tensor,
eos_token_ids: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
output_padding_offset: paddle.Tensor,
output_cum_offsets: paddle.Tensor,
batch_id_per_token_output: paddle.Tensor,
cu_seqlens_q_output: paddle.Tensor,
max_len: int,
):
"""
@@ -200,8 +200,8 @@ def apply_speculative_penalty_multi_scores(
min_dec_lens,
eos_token_ids,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
batch_id_per_token_output,
cu_seqlens_q_output,
max_len,
)
elif current_platform.is_xpu():
@@ -221,8 +221,8 @@ def apply_speculative_penalty_multi_scores(
min_dec_lens,
eos_token_ids,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
batch_id_per_token_output,
cu_seqlens_q_output,
max_len,
)
@@ -242,8 +242,8 @@ def reasoning_phase_token_constraint(
step_idx: paddle.Tensor,
reasoning_allowed_tokens: paddle.Tensor,
reasoning_status: paddle.Tensor,
output_padding_offset: paddle.Tensor,
output_cum_offsets: paddle.Tensor,
batch_id_per_token_output: paddle.Tensor,
cu_seqlens_q_output: paddle.Tensor,
enable_thinking: paddle.Tensor,
think_end_id: int,
line_break_id: int,
@@ -263,8 +263,8 @@ def reasoning_phase_token_constraint(
step_idx,
reasoning_allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
batch_id_per_token_output,
cu_seqlens_q_output,
enable_thinking,
think_end_id,
line_break_id,