mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[FDConfig] Remove max_model_len in FDConfig (#4350)

* modify max_model_len
* fix unittest
* fix unittest

---------

Co-authored-by: root <root@yqlcc01-sys-rpm12rzmwjd.yqlcc01.baidu.com>
This commit is contained in:
@@ -680,17 +680,17 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs = {}
|
||||
|
||||
self.share_inputs["pre_ids"] = paddle.full(
|
||||
[max_num_seqs, self.parallel_config.max_model_len],
|
||||
[max_num_seqs, self.model_config.max_model_len],
|
||||
-1,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["input_ids"] = paddle.full(
|
||||
[max_num_seqs, self.parallel_config.max_model_len],
|
||||
[max_num_seqs, self.model_config.max_model_len],
|
||||
self.model_config.pad_token_id,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["prompt_ids"] = paddle.full(
|
||||
[max_num_seqs, self.parallel_config.max_model_len],
|
||||
[max_num_seqs, self.model_config.max_model_len],
|
||||
self.model_config.pad_token_id,
|
||||
dtype="int64",
|
||||
)
|
||||
@@ -755,7 +755,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32")
|
||||
|
||||
# Initialize rotary position embedding
|
||||
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
|
||||
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
|
||||
|
||||
# TODO(gongshaotian): move to models
|
||||
if not self.enable_mm:
|
||||
@@ -768,7 +768,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
|
||||
# Set block tables
|
||||
pre_max_block_num = (
|
||||
self.parallel_config.max_model_len + self.cache_config.block_size - 1
|
||||
self.model_config.max_model_len + self.cache_config.block_size - 1
|
||||
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
|
||||
self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")
|
||||
|
||||
@@ -805,7 +805,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
max_num_seqs,
|
||||
2,
|
||||
1,
|
||||
self.parallel_config.max_model_len,
|
||||
self.model_config.max_model_len,
|
||||
1,
|
||||
head_dim // 2,
|
||||
],
|
||||
@@ -960,7 +960,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
|
||||
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
|
||||
"""Set dummy prefill inputs to share_inputs"""
|
||||
full_length = min(num_tokens // batch_size, self.parallel_config.max_model_len - 10)
|
||||
full_length = min(num_tokens // batch_size, self.model_config.max_model_len - 10)
|
||||
input_length = int(full_length - 512)
|
||||
block_num = (
|
||||
input_length + self.cache_config.block_size - 1
|
||||
@@ -1344,7 +1344,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
rotary_dim=self.model_config.head_dim,
|
||||
partial_rotary_factor=1.0,
|
||||
base=self.model_config.rope_theta,
|
||||
max_position=self.parallel_config.max_model_len,
|
||||
max_position=self.model_config.max_model_len,
|
||||
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
|
||||
model_type=self.model_config.model_type,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user