mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support mtp overlap schedule (#7001)
This commit is contained in:
@@ -316,6 +316,15 @@ class InputBatch:
|
||||
dtype="float32",
|
||||
)
|
||||
self.cu_batch_token_offset = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32")
|
||||
# For mtp overlap
|
||||
self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
|
||||
self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
|
||||
self.accept_tokens_cpu = paddle.full(
|
||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||
fill_value=0,
|
||||
dtype="int64",
|
||||
).pin_memory()
|
||||
self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
|
||||
if self.enable_mm:
|
||||
head_dim = self.model_config.head_dim
|
||||
if (
|
||||
@@ -435,6 +444,10 @@ class InputBatch:
|
||||
swap_data(self.step_seq_lens_this_time, i1, i2)
|
||||
swap_data(self.draft_logits, i1, i2)
|
||||
swap_data(self.cu_batch_token_offset, i1, i2)
|
||||
swap_data(self.seq_lens_decoder_cpu, i1, i2)
|
||||
swap_data(self.prompt_lens_cpu, i1, i2)
|
||||
swap_data(self.accept_tokens_cpu, i1, i2)
|
||||
swap_data(self.accept_num_cpu, i1, i2)
|
||||
|
||||
if self.enable_mm:
|
||||
if self.image_features_list is not None:
|
||||
@@ -623,6 +636,15 @@ class InputBatch:
|
||||
fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
|
||||
fill_paddle_tensor(self, "draft_logits", -1)
|
||||
fill_paddle_tensor(self, "cu_batch_token_offset", 0)
|
||||
# for mtp overlap
|
||||
self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
|
||||
self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
|
||||
self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
|
||||
self.accept_tokens_cpu = paddle.full(
|
||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||
fill_value=0,
|
||||
dtype="int64",
|
||||
).pin_memory()
|
||||
|
||||
# Reset multimodal related tensors
|
||||
if self.enable_mm:
|
||||
@@ -697,6 +719,7 @@ class ProposerInputBatch(InputBatch):
|
||||
self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"])
|
||||
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
|
||||
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
|
||||
self.not_need_stop_device = paddle.to_tensor([False], dtype="bool")
|
||||
if current_platform.is_cuda():
|
||||
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
|
||||
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
|
||||
@@ -1085,6 +1108,7 @@ def _recover_tensor(recover_tensor, index_to_batch_id_list):
|
||||
"""
|
||||
sort_len = len(index_to_batch_id_list)
|
||||
if isinstance(recover_tensor.place, paddle.CUDAPinnedPlace):
|
||||
recover_tensor = recover_tensor.cpu()
|
||||
recover_res_tensor = paddle.empty_like(recover_tensor, device="cpu")
|
||||
else:
|
||||
recover_res_tensor = paddle.empty_like(recover_tensor)
|
||||
|
||||
Reference in New Issue
Block a user