Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 08:21:53 +08:00)
[Feature] Support mtp overlap schedule (#7001)
This commit is contained in:
@@ -146,10 +146,10 @@ std::vector<paddle::Tensor> SpeculatePreProcess(
|
||||
const int bsz = seq_len.shape()[0];
|
||||
const int max_seq_len = input_ids_shape[1];
|
||||
const int token_num_data = cpu_token_num;
|
||||
- auto ids_remove_padding = paddle::empty(
|
||||
- {token_num_data}, paddle::DataType::INT64, input_ids.place());
|
||||
- auto batch_id_per_token = paddle::empty(
|
||||
- {token_num_data}, paddle::DataType::INT32, input_ids.place());
|
||||
+ auto ids_remove_padding = paddle::full(
|
||||
+ {token_num_data}, 2, paddle::DataType::INT64, input_ids.place());
|
||||
+ auto batch_id_per_token = paddle::full(
|
||||
+ {token_num_data}, -1, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_q =
|
||||
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
@@ -170,9 +170,10 @@ std::vector<paddle::Tensor> SpeculatePreProcess(
|
||||
auto cu_seq_lens_q_output =
|
||||
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
|
||||
auto batch_id_per_token_output =
|
||||
- paddle::empty({bsz * max_draft_tokens_per_batch},
|
||||
- paddle::DataType::INT32,
|
||||
- input_ids.place());
|
||||
+ paddle::full({bsz * max_draft_tokens_per_batch},
|
||||
+ -1,
|
||||
+ paddle::DataType::INT32,
|
||||
+ input_ids.place());
|
||||
auto real_output_token_num =
|
||||
paddle::empty({1}, paddle::DataType::INT32, input_ids.place());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user