[FDConfig]Remove max_model_len in FDConfig (#4350)

* modify max_model_len

* fix unittest

* fix unittest

---------

Co-authored-by: root <root@yqlcc01-sys-rpm12rzmwjd.yqlcc01.baidu.com>
This commit is contained in:
YuanRisheng
2025-10-11 14:04:17 +08:00
committed by GitHub
parent 365601ea5a
commit a2ec2c4152
36 changed files with 127 additions and 121 deletions
+8 -8
View File
@@ -680,17 +680,17 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs = {}
self.share_inputs["pre_ids"] = paddle.full(
- [max_num_seqs, self.parallel_config.max_model_len],
+ [max_num_seqs, self.model_config.max_model_len],
-1,
dtype="int64",
)
self.share_inputs["input_ids"] = paddle.full(
- [max_num_seqs, self.parallel_config.max_model_len],
+ [max_num_seqs, self.model_config.max_model_len],
self.model_config.pad_token_id,
dtype="int64",
)
self.share_inputs["prompt_ids"] = paddle.full(
- [max_num_seqs, self.parallel_config.max_model_len],
+ [max_num_seqs, self.model_config.max_model_len],
self.model_config.pad_token_id,
dtype="int64",
)
@@ -755,7 +755,7 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32")
# Initialize rotary position embedding
- tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
+ tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
# TODO(gongshaotian): move to models
if not self.enable_mm:
@@ -768,7 +768,7 @@ class XPUModelRunner(ModelRunnerBase):
# Set block tables
pre_max_block_num = (
- self.parallel_config.max_model_len + self.cache_config.block_size - 1
+ self.model_config.max_model_len + self.cache_config.block_size - 1
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")
@@ -805,7 +805,7 @@ class XPUModelRunner(ModelRunnerBase):
max_num_seqs,
2,
1,
- self.parallel_config.max_model_len,
+ self.model_config.max_model_len,
1,
head_dim // 2,
],
@@ -960,7 +960,7 @@ class XPUModelRunner(ModelRunnerBase):
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
"""Set dummy prefill inputs to share_inputs"""
- full_length = min(num_tokens // batch_size, self.parallel_config.max_model_len - 10)
+ full_length = min(num_tokens // batch_size, self.model_config.max_model_len - 10)
input_length = int(full_length - 512)
block_num = (
input_length + self.cache_config.block_size - 1
@@ -1344,7 +1344,7 @@ class XPUModelRunner(ModelRunnerBase):
rotary_dim=self.model_config.head_dim,
partial_rotary_factor=1.0,
base=self.model_config.rope_theta,
- max_position=self.parallel_config.max_model_len,
+ max_position=self.model_config.max_model_len,
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
model_type=self.model_config.model_type,
)