[Feature] Support mtp overlap schedule (#7001)

This commit is contained in:
sunxin
2026-04-01 14:24:26 +08:00
committed by GitHub
parent c6f0c5c3a6
commit c29e86fc9d
23 changed files with 215 additions and 138 deletions
@@ -155,12 +155,15 @@ class CudaGraphPiecewiseBackend:
def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
# Get real shape (total num tokens)
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0]
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
num_running_requests = int((seq_lens_this_time.flatten() > 0).sum().item())
real_bsz = kwargs["forward_meta"].real_bsz
num_running_requests = real_bsz if real_bsz > 0 else int((seq_lens_this_time.flatten() > 0).sum().item())
num_running_requests = max(1, num_running_requests)
real_shape = self.real_bsz_to_captured_size[num_running_requests]
else:
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0]
exist_prefill = kwargs["forward_meta"].exist_prefill
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st