mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support mtp overlap schedule (#7001)
This commit is contained in:
@@ -155,12 +155,15 @@ class CudaGraphPiecewiseBackend:
|
||||
|
||||
def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
|
||||
# Get real shape (total num tokens)
|
||||
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
|
||||
real_shape = ids_remove_padding.shape[0]
|
||||
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
|
||||
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
|
||||
num_running_requests = int((seq_lens_this_time.flatten() > 0).sum().item())
|
||||
real_bsz = kwargs["forward_meta"].real_bsz
|
||||
num_running_requests = real_bsz if real_bsz > 0 else int((seq_lens_this_time.flatten() > 0).sum().item())
|
||||
num_running_requests = max(1, num_running_requests)
|
||||
real_shape = self.real_bsz_to_captured_size[num_running_requests]
|
||||
else:
|
||||
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
|
||||
real_shape = ids_remove_padding.shape[0]
|
||||
exist_prefill = kwargs["forward_meta"].exist_prefill
|
||||
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
|
||||
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
|
||||
|
||||
Reference in New Issue
Block a user