mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[BugFix] fix seq_lens_this_time init (#6670)
This commit is contained in:
@@ -138,8 +138,7 @@ class InputBatch:
|
||||
self.min_dec_len = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
|
||||
self.max_dec_len = paddle.full([max_num_seqs, 1], self.model_config.max_model_len, dtype="int64")
|
||||
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
if self.enable_expert_parallel:
|
||||
self.seq_lens_this_time = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.seq_lens_this_time = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.seq_lens_encoder = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.seq_lens_decoder = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
@@ -534,8 +533,7 @@ class InputBatch:
|
||||
|
||||
# Reset sequence length related buffers
|
||||
fill_paddle_tensor(self, "seq_lens_this_time_buffer", 0)
|
||||
if self.enable_expert_parallel:
|
||||
fill_paddle_tensor(self, "seq_lens_this_time", 0)
|
||||
fill_paddle_tensor(self, "seq_lens_this_time", 0)
|
||||
fill_paddle_tensor(self, "seq_lens_encoder", 0)
|
||||
fill_paddle_tensor(self, "seq_lens_decoder", 0)
|
||||
fill_paddle_tensor(self, "step_seq_lens_encoder", 0)
|
||||
|
||||
Reference in New Issue
Block a user