support multi-step draft-model with cudagraph (#5886)

This commit is contained in:
freeliuzc
2026-01-06 11:16:21 +08:00
committed by GitHub
parent 7a0744f05a
commit ca574119e5
2 changed files with 15 additions and 46 deletions
+13 -5
View File
@@ -708,7 +708,7 @@ class MTPProposer(Proposer):
self.model_inputs["not_need_stop"][0] = True
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
"""
Initialize forward meta and attention meta data
"""
@@ -744,7 +744,12 @@ class MTPProposer(Proposer):
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
# Notes(liuzichang):
# 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
# 2. In multi-step execution, only the first step should be captured.
self.forward_meta.step_use_cudagraph = (
step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
)
def _initialize_forward_meta_xpu(self):
@@ -922,7 +927,9 @@ class MTPProposer(Proposer):
self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
# Initialize forward meta data
self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
self._initialize_forward_meta(
step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
)
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
# Padding inputs for cuda graph
@@ -947,9 +954,10 @@ class MTPProposer(Proposer):
top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
share_inputs=self.model_inputs,
)
# Note(liuzichang):
# paddle.clone raises CUDA error 700 under CUDA Graph mode, so reuse the
# preallocated buffer via copy_ instead of cloning a new tensor each step.
if self.num_model_steps > 1:
self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
self.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False)
model_output = self.model(
ids_remove_padding=self.model_inputs["ids_remove_padding"],