seq_lens related tensor shape -> [max_num_seqs] (#6535)

This commit is contained in:
周周周
2026-03-02 11:18:30 +08:00
committed by GitHub
parent 16a2a323eb
commit d957ccd46d
5 changed files with 14 additions and 12 deletions
+4 -4
View File
@@ -137,11 +137,11 @@ class InputBatch:
self.min_dec_len = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
self.max_dec_len = paddle.full([max_num_seqs, 1], self.model_config.max_model_len, dtype="int64")
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs], 0, dtype="int32")
if self.enable_expert_parallel:
self.seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.seq_lens_this_time = paddle.full([max_num_seqs], 0, dtype="int32")
self.seq_lens_encoder = paddle.full([max_num_seqs], 0, dtype="int32")
self.seq_lens_decoder = paddle.full([max_num_seqs], 0, dtype="int32")
self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.step_seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.prompt_lens = paddle.full([max_num_seqs, 1], 0, dtype="int64")