[Feature] Support mtp overlap schedule (#7001)

This commit is contained in:
sunxin
2026-04-01 14:24:26 +08:00
committed by GitHub
parent c6f0c5c3a6
commit c29e86fc9d
23 changed files with 215 additions and 138 deletions
@@ -146,10 +146,10 @@ std::vector<paddle::Tensor> SpeculatePreProcess(
const int bsz = seq_len.shape()[0];
const int max_seq_len = input_ids_shape[1];
const int token_num_data = cpu_token_num;
auto ids_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto ids_remove_padding = paddle::full(
{token_num_data}, 2, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::full(
{token_num_data}, -1, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
@@ -170,9 +170,10 @@ std::vector<paddle::Tensor> SpeculatePreProcess(
auto cu_seq_lens_q_output =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto batch_id_per_token_output =
paddle::empty({bsz * max_draft_tokens_per_batch},
paddle::DataType::INT32,
input_ids.place());
paddle::full({bsz * max_draft_tokens_per_batch},
-1,
paddle::DataType::INT32,
input_ids.place());
auto real_output_token_num =
paddle::empty({1}, paddle::DataType::INT32, input_ids.place());