[Speculative Decoding] Support suffix decoding (#6403)

* support suffix decoding
This commit is contained in:
GoldPancake
2026-02-26 11:42:05 +08:00
committed by GitHub
parent 6d3fede240
commit 2178f2829b
18 changed files with 587 additions and 30 deletions
@@ -119,6 +119,10 @@ class CudaGraphPiecewiseBackend:
if self.fd_config.graph_opt_config.graph_opt_level > 0:
self.cuda_graph_manager = Dy2StCudaGraphManager()
self.speculative_decoding = fd_config.speculative_config.method is not None
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.real_bsz_to_captured_size = fd_config.graph_opt_config.real_bsz_to_captured_size
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
if not entry.captured:
@@ -153,7 +157,10 @@ class CudaGraphPiecewiseBackend:
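# Get real shape (total num tokens). When speculative decoding is enabled and every
# entry of real_bsz_to_captured_size is truthy, this value is overridden below with the
# size looked up from the index of the last non-zero entry in seq_lens_this_time, plus one.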
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0]
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
num_running_requests = seq_lens_this_time.squeeze(axis=-1).nonzero(as_tuple=False)[-1].item() + 1
real_shape = self.real_bsz_to_captured_size[num_running_requests]
exist_prefill = kwargs["forward_meta"].exist_prefill
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
@@ -188,6 +195,9 @@ class CudaGraphPiecewiseBackend:
# Capture a new cuda graph
if entry.cuda_graph is None:
assert (
real_shape == padding_real_shape
), f"real_shape:{real_shape} is not equal to padding_real_shape:{padding_real_shape} when capture new graph."
# Warmup the model
for n in range(entry.num_finished_warmup, self.warm_up_size):
entry.num_finished_warmup += 1