mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix clear_parameters in draft cudagraph (#7035)
This commit is contained in:
@@ -902,6 +902,7 @@ class ProposerInputBatch(InputBatch):
|
||||
self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"])
|
||||
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
|
||||
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
|
||||
self.index_to_batch_id = {}
|
||||
if current_platform.is_cuda():
|
||||
if "token_ids_all" in self.target_model_input_batch:
|
||||
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
|
||||
@@ -922,8 +923,6 @@ class ProposerInputBatch(InputBatch):
|
||||
self.token_ids_all = None
|
||||
else:
|
||||
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
|
||||
self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"])
|
||||
self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"])
|
||||
self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
|
||||
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
|
||||
self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"])
|
||||
|
||||
Reference in New Issue
Block a user