mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Support suffix decoding (#6403)
* support suffix decoding
This commit is contained in:
@@ -119,6 +119,10 @@ class CudaGraphPiecewiseBackend:
|
||||
if self.fd_config.graph_opt_config.graph_opt_level > 0:
|
||||
self.cuda_graph_manager = Dy2StCudaGraphManager()
|
||||
|
||||
self.speculative_decoding = fd_config.speculative_config.method is not None
|
||||
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
|
||||
self.real_bsz_to_captured_size = fd_config.graph_opt_config.real_bsz_to_captured_size
|
||||
|
||||
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
|
||||
|
||||
if not entry.captured:
|
||||
@@ -153,7 +157,10 @@ class CudaGraphPiecewiseBackend:
|
||||
# Get real shape (total num tokens)
|
||||
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
|
||||
real_shape = ids_remove_padding.shape[0]
|
||||
|
||||
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
|
||||
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
|
||||
num_running_requests = seq_lens_this_time.squeeze(axis=-1).nonzero(as_tuple=False)[-1].item() + 1
|
||||
real_shape = self.real_bsz_to_captured_size[num_running_requests]
|
||||
exist_prefill = kwargs["forward_meta"].exist_prefill
|
||||
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
|
||||
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
|
||||
@@ -188,6 +195,9 @@ class CudaGraphPiecewiseBackend:
|
||||
|
||||
# Capture a new cuda graph
|
||||
if entry.cuda_graph is None:
|
||||
assert (
|
||||
real_shape == padding_real_shape
|
||||
), f"real_shape:{real_shape} is not equal to padding_real_shape:{padding_real_shape} when capture new graph."
|
||||
# Warmup the model
|
||||
for n in range(entry.num_finished_warmup, self.warm_up_size):
|
||||
entry.num_finished_warmup += 1
|
||||
|
||||
Reference in New Issue
Block a user