[Speculative Decoding] Reformat input preprocess for spec decode (#6501)

* add speculate_pre_process kernel

* reduce one slice

* make d2h async && fix mtp bug for new pre_process

* fix

* add unit test

* fix: code style formatting

* fix

* fix: thread race in speculate_preprocess && rename d2h event
This commit is contained in:
huicongyao
2026-03-03 10:22:07 +08:00
committed by GitHub
parent 33d6d2403c
commit 0f718baaf2
6 changed files with 619 additions and 25 deletions
@@ -61,7 +61,6 @@ elif current_platform.is_maca():
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_get_seq_lens_output,
speculate_limit_thinking_content_length,
speculate_save_output,
speculate_save_output_topk,
@@ -85,7 +84,7 @@ else:
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_get_seq_lens_output,
speculate_pre_process,
speculate_save_output,
speculate_save_output_topk,
speculate_set_value_by_flags_and_idx,
@@ -152,6 +151,7 @@ def pre_process(
cu_seqlens_k,
None,
None,
None,
)
# Remove padding
if speculative_decoding:
@@ -160,27 +160,12 @@ def pre_process(
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)
# compute each batch's output token num
seq_lens_output = speculate_get_seq_lens_output(
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_q_output,
batch_id_per_token_output,
real_output_token_num,
) = speculate_pre_process(
token_num_cpu, input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, seq_lens_decoder
)
if isinstance(seq_lens_output, list):
seq_lens_output = seq_lens_output[0]
output_token_num = paddle.sum(seq_lens_output)
useless_input_ids = input_ids
_, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
useless_input_ids,
seq_lens_output,
None,
None,
output_token_num.item(),
)
return (
ids_remove_padding,
batch_id_per_token,
@@ -188,6 +173,7 @@ def pre_process(
cu_seqlens_k,
cu_seqlens_q_output,
batch_id_per_token_output,
real_output_token_num,
)