[Speculative Decoding]Reformat input preprocess for spec decode (#6501)

* add speculate_pre_process kernel

* reduce one slice

* make d2h async && fix mtp bug for new pre_process

* fix

* add unitest

* fix: code stype formatting

* fix

* fix: thread race in speculate_preprocess && rename d2h event
This commit is contained in:
huicongyao
2026-03-03 10:22:07 +08:00
committed by GitHub
parent 33d6d2403c
commit 0f718baaf2
6 changed files with 619 additions and 25 deletions
+12
View File
@@ -751,6 +751,14 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
std::vector<paddle::Tensor> SpeculatePreProcess(
const int64_t cpu_token_num,
const paddle::Tensor& input_ids,
const paddle::Tensor& seq_len,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
void SpecTokenPenaltyMultiScores(
const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
@@ -1604,6 +1612,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&SpeculateGetSeqLensOutput,
"speculate_get_seq_lens_output function");
m.def("speculate_pre_process",
&SpeculatePreProcess,
"speculate_pre_process function");
m.def("speculate_get_token_penalty_multi_scores",
&SpecTokenPenaltyMultiScores,
"speculate_get_token_penalty_multi_scores function");