Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Speculative Decoding] Reformat input preprocess for spec decode (#6501)
* add speculate_pre_process kernel
* reduce one slice
* make d2h async && fix mtp bug for new pre_process
* fix
* add unit test
* fix: code style formatting
* fix
* fix: thread race in speculate_preprocess && rename d2h event
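The core of the change is the usual asynchronous device-to-host handoff: copy a scalar into a pinned host buffer without blocking, record a CUDA event right after the copy, and synchronize on that event only when the value is actually consumed on the CPU. A minimal sketch of the pattern, assuming a CUDA build of Paddle; the variable names here are illustrative, not taken from the repo:

    import paddle

    # Pinned (page-locked) host buffer: asynchronous device-to-host copies
    # generally need a pinned destination; with pageable memory the copy is
    # effectively synchronous.
    token_num_host = paddle.empty([1], dtype="int32").pin_memory()
    copy_done = paddle.device.cuda.Event()

    # Stand-in for the token count produced on the GPU by pre_process.
    token_num_gpu = paddle.to_tensor([128], dtype="int32")

    # Enqueue the copy without blocking and mark its completion point.
    token_num_host.copy_(token_num_gpu, False)
    copy_done.record()

    # ... more GPU work can be enqueued here ...

    # Wait only when the CPU actually needs the value.
    copy_done.synchronize()
    real_num = int(token_num_host)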
@@ -148,6 +148,10 @@ class GPUModelRunner(ModelRunnerBase):
         self.cache_kvs_map: dict = {}
         self.exist_prefill_flag = False

+        if self.speculative_decoding:
+            self._real_output_token_num_host = paddle.empty([1], dtype="int32").pin_memory()
+            self.output_token_num_event = paddle.device.cuda.Event()
+
         # VL model config:
         if self.enable_mm:
             if "ernie" in self.fd_config.model_config.model_type:
@@ -1129,6 +1133,7 @@ class GPUModelRunner(ModelRunnerBase):
             cu_seqlens_k,
             cu_seqlens_q_output,
             batch_id_per_token_output,
+            real_output_token_num,
         ) = pre_process(
             token_num,
             self.share_inputs["input_ids"],
@@ -1150,6 +1155,9 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
         self.share_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)

+        self._real_output_token_num_host.copy_(real_output_token_num, False)
+        self.output_token_num_event.record()
+
         # Initialize forward meta data
         self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run)

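Recording the event immediately after the non-blocking copy is what lets a later consumer wait for exactly this copy rather than draining the whole stream. A hedged sketch of the producer side as a helper (stage_real_output_token_num is a hypothetical name, not a function in the repo):

    def stage_real_output_token_num(real_output_token_num, host_buf, event):
        # host_buf: pinned int32 tensor of shape [1] (as created in __init__ above);
        # event: paddle.device.cuda.Event. Both are assumptions mirroring the diff.
        host_buf.copy_(real_output_token_num, False)  # non-blocking D2H copy
        event.record()                                # completion marker for later synchronize()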
@@ -1750,13 +1758,19 @@ class GPUModelRunner(ModelRunnerBase):
             self._dummy_pooler_run(model_output, model_output)
             break
         else:
+            if self.speculative_decoding:
+                self.output_token_num_event.synchronize()
+                real_num = int(self._real_output_token_num_host)
+                real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_num]
+            else:
+                real_batch_id_per_token_output = None
             hidden_states = rebuild_padding(
                 model_output,
                 self.share_inputs["cu_seqlens_q"],
                 self.share_inputs["seq_lens_this_time"],
                 self.share_inputs["seq_lens_decoder"],
                 self.share_inputs["seq_lens_encoder"],
-                (self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
+                real_batch_id_per_token_output,
                 (self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
             )
             self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts)
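On the consumer side the event is synchronized only at the point where the CPU needs the count, and the count is then used to trim the per-token batch-id mapping before it reaches rebuild_padding. A minimal sketch of that wait-and-slice step (wait_and_slice is a hypothetical helper; the tensor names follow the diff):

    def wait_and_slice(event, host_buf, batch_id_per_token_output):
        event.synchronize()                       # wait for the recorded D2H copy only
        real_num = int(host_buf)                  # read the pinned host scalar
        return batch_id_per_token_output[:real_num]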
@@ -2059,6 +2073,11 @@ class GPUModelRunner(ModelRunnerBase):
                 + self.share_inputs["is_block_step_cpu"].numpy().sum().item()
             )

+            if self.speculative_decoding:
+                self.output_token_num_event.synchronize()
+                real_num = int(self._real_output_token_num_host)
+                real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_num]
+
             prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
             if self.is_pooling_model:
                 pooler_output = self._pool(model_output, num_running_requests)
@@ -2117,7 +2136,7 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["seq_lens_this_time"],
                 self.share_inputs["seq_lens_decoder"],
                 self.share_inputs["seq_lens_encoder"],
-                (self.share_inputs["batch_id_per_token_output"] if self.speculative_decoding else None),
+                (real_batch_id_per_token_output if self.speculative_decoding else None),
                 (self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
             )
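The share_inputs buffers appear to be preallocated to a maximum token count, so slicing to real_num presumably keeps stale tail entries out of the rebuild_padding call. A toy illustration of the effect (the sizes and values are made up):

    import paddle

    max_tokens = 8
    batch_id_per_token_output = paddle.zeros([max_tokens], dtype="int32")
    # Suppose this step produced only 5 output tokens; entries 5..7 are leftovers.
    batch_id_per_token_output[:5] = paddle.to_tensor([0, 0, 1, 1, 2], dtype="int32")

    real_num = 5
    valid_mapping = batch_id_per_token_output[:real_num]  # only the live entries go on
    print(valid_mapping.numpy())  # [0 0 1 1 2]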