mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
seq_lens related tensor shape -> [max_num_seqs] (#6535)
This commit is contained in:
@@ -137,11 +137,11 @@ class InputBatch:
|
||||
|
||||
self.min_dec_len = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
|
||||
self.max_dec_len = paddle.full([max_num_seqs, 1], self.model_config.max_model_len, dtype="int64")
|
||||
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.seq_lens_this_time_buffer = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
if self.enable_expert_parallel:
|
||||
self.seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.seq_lens_this_time = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.seq_lens_encoder = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.seq_lens_decoder = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.step_seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
self.prompt_lens = paddle.full([max_num_seqs, 1], 0, dtype="int64")
|
||||
|
||||
Reference in New Issue
Block a user