[Metax][Fix] fix issues based on #6259 (#6338)

Author: MingkunZhang
Date: 2026-02-04 15:21:35 +08:00
Committed by: GitHub
Parent: 90db0bdd0d
Commit: e109fb9a0e
2 changed files with 161 additions and 323 deletions
+11 -4
@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger
 from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope
 from fastdeploy.model_executor.logits_processor import build_logits_processors
+from fastdeploy.platforms import current_platform


 class InputBatch:
@@ -134,23 +135,29 @@ class InputBatch:
         self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         if self.enable_expert_parallel:
             self.seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-            self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
         self.seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.step_seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.prompt_lens = paddle.full([max_num_seqs, 1], 0, dtype="int64")
         self.step_idx = paddle.full([max_num_seqs, 1], 0, dtype="int64")
-        self.not_need_stop = paddle.full([1], False, dtype="bool").pin_memory()
+        if current_platform.is_maca():
+            self.not_need_stop = paddle.full([1], False, dtype="bool").cpu()
+            self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").cpu()
+            self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
+            self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").cpu()
+        else:
+            self.not_need_stop = paddle.full([1], False, dtype="bool").pin_memory()
+            self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory()
+            self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
+            self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory()
         self.not_need_stop_device = paddle.full([1], False, dtype="bool")
-        self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory()
         self.stop_flags = paddle.full([max_num_seqs, 1], True, dtype="bool")
         self.bad_tokens = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64")
         self.bad_tokens_len = paddle.full([max_num_seqs], 1, dtype="int64")
         self.next_tokens = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.is_block_step = paddle.full([max_num_seqs], False, dtype="bool")
-        self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory()
         self.is_chunk_step = paddle.full([max_num_seqs], False, dtype="bool").cpu()
         self.encoder_block_lens = paddle.full([max_num_seqs], 0, dtype="int32")
         self.step_block_list = paddle.full([max_num_seqs], -1, dtype="int32")
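
For reference, the branch repeats the same allocate-then-place pattern for each of the four host-side buffers (not_need_stop, sampled_token_ids, seq_lens_this_time_cpu, is_block_step_cpu). Below is a minimal sketch of that pattern, assuming only the current_platform API shown in the diff; the helper name full_on_host is hypothetical and not part of the commit.

import paddle

from fastdeploy.platforms import current_platform


def full_on_host(shape, fill_value, dtype):
    """Allocate a host-side tensor: pinned memory where the platform
    supports it, plain pageable CPU memory on MACA (Metax), which this
    commit steers away from pin_memory()."""
    t = paddle.full(shape, fill_value, dtype=dtype)
    if current_platform.is_maca():
        return t.cpu()  # MACA path: pageable host memory
    return t.pin_memory()  # default path: pinned host memory


# Usage mirroring the buffers the commit branches on, e.g.:
#   self.not_need_stop = full_on_host([1], False, "bool")
#   self.sampled_token_ids = full_on_host([max_num_seqs, 1], -1, "int64")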