mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Speculative Decoding]Reformat input preprocess for spec decode (#6501)
* add speculate_pre_process kernel * reduce one slice * make d2h async && fix mtp bug for new pre_process * fix * add unitest * fix: code stype formatting * fix * fix: thread race in speculate_preprocess && rename d2h event
This commit is contained in:
@@ -751,6 +751,14 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
std::vector<paddle::Tensor> SpeculatePreProcess(
|
||||
const int64_t cpu_token_num,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& seq_len,
|
||||
const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
void SpecTokenPenaltyMultiScores(
|
||||
const paddle::Tensor& token_ids_all,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
@@ -1604,6 +1612,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
&SpeculateGetSeqLensOutput,
|
||||
"speculate_get_seq_lens_output function");
|
||||
|
||||
m.def("speculate_pre_process",
|
||||
&SpeculatePreProcess,
|
||||
"speculate_pre_process function");
|
||||
|
||||
m.def("speculate_get_token_penalty_multi_scores",
|
||||
&SpecTokenPenaltyMultiScores,
|
||||
"speculate_get_token_penalty_multi_scores function");
|
||||
|
||||
Reference in New Issue
Block a user