mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
support multi-step draft-model with cudagraph (#5886)
This commit is contained in:
@@ -708,7 +708,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["not_need_stop"][0] = True
|
||||
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
|
||||
|
||||
def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
|
||||
def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
|
||||
"""
|
||||
Initialize forward meta and attention meta data
|
||||
"""
|
||||
@@ -744,7 +744,12 @@ class MTPProposer(Proposer):
|
||||
for attn_backend in self.attn_backends:
|
||||
attn_backend.init_attention_metadata(self.forward_meta)
|
||||
|
||||
self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
|
||||
# Notes(liuzichang):
|
||||
# 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
|
||||
# 2. In multi-step execution, only the first step should be captured.
|
||||
self.forward_meta.step_use_cudagraph = (
|
||||
step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
|
||||
)
|
||||
|
||||
def _initialize_forward_meta_xpu(self):
|
||||
|
||||
@@ -922,7 +927,9 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
|
||||
|
||||
# Initialize forward meta data
|
||||
self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
|
||||
self._initialize_forward_meta(
|
||||
step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
|
||||
)
|
||||
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
|
||||
|
||||
# Padding inputs for cuda graph
|
||||
@@ -947,9 +954,10 @@ class MTPProposer(Proposer):
|
||||
top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
|
||||
share_inputs=self.model_inputs,
|
||||
)
|
||||
|
||||
# Note(liuzichang):
|
||||
# paddle.clone would raise error 700 in cudaGraph mode
|
||||
if self.num_model_steps > 1:
|
||||
self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
|
||||
self.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False)
|
||||
|
||||
model_output = self.model(
|
||||
ids_remove_padding=self.model_inputs["ids_remove_padding"],
|
||||
|
||||
Reference in New Issue
Block a user