mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Reformat input preprocess for spec decode (#6501)
* add speculate_pre_process kernel * reduce one slice * make d2h async && fix mtp bug for new pre_process * fix * add unittest * fix: code style formatting * fix * fix: thread race in speculate_preprocess && rename d2h event
This commit is contained in:
@@ -61,7 +61,6 @@ elif current_platform.is_maca():
|
||||
save_output,
|
||||
save_output_topk,
|
||||
set_stop_value_multi_ends,
|
||||
speculate_get_seq_lens_output,
|
||||
speculate_limit_thinking_content_length,
|
||||
speculate_save_output,
|
||||
speculate_save_output_topk,
|
||||
@@ -85,7 +84,7 @@ else:
|
||||
save_output,
|
||||
save_output_topk,
|
||||
set_stop_value_multi_ends,
|
||||
speculate_get_seq_lens_output,
|
||||
speculate_pre_process,
|
||||
speculate_save_output,
|
||||
speculate_save_output_topk,
|
||||
speculate_set_value_by_flags_and_idx,
|
||||
@@ -152,6 +151,7 @@ def pre_process(
|
||||
cu_seqlens_k,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
# Remove padding
|
||||
if speculative_decoding:
|
||||
@@ -160,27 +160,12 @@ def pre_process(
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)
|
||||
|
||||
# compute each batch's output token num
|
||||
seq_lens_output = speculate_get_seq_lens_output(
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
cu_seqlens_q_output,
|
||||
batch_id_per_token_output,
|
||||
real_output_token_num,
|
||||
) = speculate_pre_process(
|
||||
token_num_cpu, input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, seq_lens_decoder
|
||||
)
|
||||
if isinstance(seq_lens_output, list):
|
||||
seq_lens_output = seq_lens_output[0]
|
||||
output_token_num = paddle.sum(seq_lens_output)
|
||||
|
||||
useless_input_ids = input_ids
|
||||
_, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
|
||||
useless_input_ids,
|
||||
seq_lens_output,
|
||||
None,
|
||||
None,
|
||||
output_token_num.item(),
|
||||
)
|
||||
|
||||
return (
|
||||
ids_remove_padding,
|
||||
batch_id_per_token,
|
||||
@@ -188,6 +173,7 @@ def pre_process(
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q_output,
|
||||
batch_id_per_token_output,
|
||||
real_output_token_num,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user