[MTP] refactor MTP pre_process (#6358)

This commit is contained in:
周周周
2026-02-09 10:47:15 +08:00
committed by GitHub
parent 18e79dd660
commit 2b4748de4f
24 changed files with 411 additions and 533 deletions
@@ -63,7 +63,6 @@ elif current_platform.is_maca():
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_get_output_padding_offset,
speculate_get_seq_lens_output,
speculate_limit_thinking_content_length_v1,
speculate_limit_thinking_content_length_v2,
@@ -89,7 +88,6 @@ else:
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_get_output_padding_offset,
speculate_get_seq_lens_output,
speculate_save_output,
speculate_save_output_topk,
@@ -240,10 +238,6 @@ def pre_process(
None,
)
# Remove padding
max_len = input_ids.shape[1]
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
output_padding_offset = None
output_cum_offsets = None
if speculative_decoding:
(
ids_remove_padding,
@@ -251,6 +245,8 @@ def pre_process(
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)
# Compute the number of output tokens for each batch entry
seq_lens_output = speculate_get_seq_lens_output(
seq_lens_this_time,
seq_lens_encoder,
@@ -259,28 +255,23 @@ def pre_process(
if isinstance(seq_lens_output, list):
seq_lens_output = seq_lens_output[0]
output_token_num = paddle.sum(seq_lens_output)
output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32")
output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
output_cum_offsets_tmp,
output_token_num,
useless_input_ids = input_ids
_, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
useless_input_ids,
seq_lens_output,
max_len,
None,
None,
output_token_num.item(),
)
else:
token_num = paddle.sum(seq_lens_this_time)
(
ids_remove_padding,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
return (
ids_remove_padding,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
output_cum_offsets,
output_padding_offset,
cu_seqlens_q_output,
batch_id_per_token_output,
)
@@ -841,8 +832,8 @@ def rebuild_padding(
seq_len_this_time: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
output_padding_offset: Optional[paddle.Tensor] = None,
max_input_length: Optional[int] = None,
batch_id_per_token_output: Optional[paddle.Tensor] = None,
cu_seqlens_q_output: Optional[paddle.Tensor] = None,
first_token_out: Optional[paddle.Tensor] = None,
enable_logprob: Optional[bool] = False,
):
@@ -859,9 +850,9 @@ def rebuild_padding(
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
batch_id_per_token_output,
cu_seqlens_q_output,
first_token_out,
max_input_length,
enable_logprob,
)
elif current_platform.is_dcu():
@@ -873,8 +864,7 @@ def rebuild_padding(
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length,
batch_id_per_token_output,
)
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import rebuild_padding
@@ -885,9 +875,8 @@ def rebuild_padding(
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
batch_id_per_token_output,
first_token_out,
max_input_length,
enable_logprob,
)
elif current_platform.is_gcu():
@@ -899,8 +888,7 @@ def rebuild_padding(
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length,
batch_id_per_token_output,
)
elif current_platform.is_cpu():
from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
@@ -911,8 +899,7 @@ def rebuild_padding(
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length,
batch_id_per_token_output,
)
elif current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import rebuild_padding
@@ -923,9 +910,8 @@ def rebuild_padding(
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
batch_id_per_token_output,
first_token_out,
max_input_length,
enable_logprob,
)
else: