mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[MTP] refactor MTP pre_process (#6358)
This commit is contained in:
@@ -63,7 +63,6 @@ elif current_platform.is_maca():
|
||||
save_output,
|
||||
save_output_topk,
|
||||
set_stop_value_multi_ends,
|
||||
speculate_get_output_padding_offset,
|
||||
speculate_get_seq_lens_output,
|
||||
speculate_limit_thinking_content_length_v1,
|
||||
speculate_limit_thinking_content_length_v2,
|
||||
@@ -89,7 +88,6 @@ else:
|
||||
save_output,
|
||||
save_output_topk,
|
||||
set_stop_value_multi_ends,
|
||||
speculate_get_output_padding_offset,
|
||||
speculate_get_seq_lens_output,
|
||||
speculate_save_output,
|
||||
speculate_save_output_topk,
|
||||
@@ -240,10 +238,6 @@ def pre_process(
|
||||
None,
|
||||
)
|
||||
# Remove padding
|
||||
max_len = input_ids.shape[1]
|
||||
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
|
||||
output_padding_offset = None
|
||||
output_cum_offsets = None
|
||||
if speculative_decoding:
|
||||
(
|
||||
ids_remove_padding,
|
||||
@@ -251,6 +245,8 @@ def pre_process(
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)
|
||||
|
||||
# compute each batch's output token num
|
||||
seq_lens_output = speculate_get_seq_lens_output(
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
@@ -259,28 +255,23 @@ def pre_process(
|
||||
if isinstance(seq_lens_output, list):
|
||||
seq_lens_output = seq_lens_output[0]
|
||||
output_token_num = paddle.sum(seq_lens_output)
|
||||
output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32")
|
||||
output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
|
||||
output_cum_offsets_tmp,
|
||||
output_token_num,
|
||||
|
||||
useless_input_ids = input_ids
|
||||
_, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
|
||||
useless_input_ids,
|
||||
seq_lens_output,
|
||||
max_len,
|
||||
None,
|
||||
None,
|
||||
output_token_num.item(),
|
||||
)
|
||||
else:
|
||||
token_num = paddle.sum(seq_lens_this_time)
|
||||
(
|
||||
ids_remove_padding,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
|
||||
|
||||
return (
|
||||
ids_remove_padding,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
output_cum_offsets,
|
||||
output_padding_offset,
|
||||
cu_seqlens_q_output,
|
||||
batch_id_per_token_output,
|
||||
)
|
||||
|
||||
|
||||
@@ -841,8 +832,8 @@ def rebuild_padding(
|
||||
seq_len_this_time: paddle.Tensor,
|
||||
seq_lens_decoder: paddle.Tensor,
|
||||
seq_lens_encoder: paddle.Tensor,
|
||||
output_padding_offset: Optional[paddle.Tensor] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
batch_id_per_token_output: Optional[paddle.Tensor] = None,
|
||||
cu_seqlens_q_output: Optional[paddle.Tensor] = None,
|
||||
first_token_out: Optional[paddle.Tensor] = None,
|
||||
enable_logprob: Optional[bool] = False,
|
||||
):
|
||||
@@ -859,9 +850,9 @@ def rebuild_padding(
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob,
|
||||
)
|
||||
elif current_platform.is_dcu():
|
||||
@@ -873,8 +864,7 @@ def rebuild_padding(
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
max_input_length,
|
||||
batch_id_per_token_output,
|
||||
)
|
||||
elif current_platform.is_iluvatar():
|
||||
from fastdeploy.model_executor.ops.iluvatar import rebuild_padding
|
||||
@@ -885,9 +875,8 @@ def rebuild_padding(
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
batch_id_per_token_output,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob,
|
||||
)
|
||||
elif current_platform.is_gcu():
|
||||
@@ -899,8 +888,7 @@ def rebuild_padding(
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
max_input_length,
|
||||
batch_id_per_token_output,
|
||||
)
|
||||
elif current_platform.is_cpu():
|
||||
from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
|
||||
@@ -911,8 +899,7 @@ def rebuild_padding(
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
max_input_length,
|
||||
batch_id_per_token_output,
|
||||
)
|
||||
elif current_platform.is_maca():
|
||||
from fastdeploy.model_executor.ops.gpu import rebuild_padding
|
||||
@@ -923,9 +910,8 @@ def rebuild_padding(
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
batch_id_per_token_output,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user