[BugFix] fix seq_lens_this_time init (#6670)

This commit is contained in:
sunxin
2026-03-05 17:07:26 +08:00
committed by GitHub
parent fa1906bd6f
commit a79b82ce68
+2 -4
View File
@@ -138,8 +138,7 @@ class InputBatch:
self.min_dec_len = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
self.max_dec_len = paddle.full([max_num_seqs, 1], self.model_config.max_model_len, dtype="int64")
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs], 0, dtype="int32")
if self.enable_expert_parallel:
self.seq_lens_this_time = paddle.full([max_num_seqs], 0, dtype="int32")
self.seq_lens_this_time = paddle.full([max_num_seqs], 0, dtype="int32")
self.seq_lens_encoder = paddle.full([max_num_seqs], 0, dtype="int32")
self.seq_lens_decoder = paddle.full([max_num_seqs], 0, dtype="int32")
self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -534,8 +533,7 @@ class InputBatch:
# Reset sequence length related buffers
fill_paddle_tensor(self, "seq_lens_this_time_buffer", 0)
if self.enable_expert_parallel:
fill_paddle_tensor(self, "seq_lens_this_time", 0)
fill_paddle_tensor(self, "seq_lens_this_time", 0)
fill_paddle_tensor(self, "seq_lens_encoder", 0)
fill_paddle_tensor(self, "seq_lens_decoder", 0)
fill_paddle_tensor(self, "step_seq_lens_encoder", 0)