[Bug Fix] Fix incorrect rope embedding for multimodal (MM) MTP (#6581)

Author: ming1753
Date: 2026-03-03 19:28:59 +08:00 (committed by GitHub)
Parent: 4ff3f4212f
Commit: c5eb6b65e7
3 changed files with 56 additions and 19 deletions
+30 -17 (this file)
@@ -731,15 +731,7 @@ class ProposerInputBatch(InputBatch):
dtype="bfloat16",
)
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
partial_rotary_factor=self.model_config.partial_rotary_factor,
)
self.rope_emb = paddle.clone(self.target_model_input_batch["rope_emb"])
# self.caches = self.cache_kvs
# Inherit generation hyperparameters from the main model for consistency
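
This first hunk replaces the proposer's hand-built rotary embedding with a copy of the target batch's. Below is a minimal, self-contained sketch (not the FastDeploy API; all shapes and names are illustrative) of why cloning the target buffer is the safer choice for MM models: the target batch already holds rope embeddings computed from the multimodal position ids, which a plain paddle.arange over token positions cannot reproduce, and paddle.clone keeps the proposer's copy independent of the target's.

# Minimal sketch, assuming a hypothetical target-model buffer layout;
# this is not the FastDeploy code, only the clone-vs-recompute idea.
import paddle

max_model_len, rope_head_dim = 8, 4

# Stand-in for target_model_input_batch["rope_emb"]: for MM models this
# already encodes vision-aware positions that arange() would discard.
target_rope_emb = paddle.rand(
    [2, 1, max_model_len, 1, rope_head_dim], dtype="float32"
)

# The fix: copy the already-correct embeddings instead of rebuilding them.
proposer_rope_emb = paddle.clone(target_rope_emb)

# clone() is a deep copy, so in-place edits on the proposer side
# do not leak back into the target model's buffer.
proposer_rope_emb[0, 0, 0, 0, 0] = -1.0
assert float(target_rope_emb[0, 0, 0, 0, 0]) != -1.0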
@@ -959,14 +951,35 @@ class ProposerInputBatch(InputBatch):
         fill_paddle_tensor(self, "target_hidden_states", 0)
         # Reset rope embedding by recreating with default position_ids
-        tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
-        self.rope_emb = get_rope(
-            rotary_dim=self.model_config.head_dim,
-            position_ids=tmp_position_ids,
-            base=self.model_config.rope_theta,
-            model_config=self.model_config,
-            partial_rotary_factor=self.model_config.partial_rotary_factor,
-        )
+        if self.enable_mm:
+            head_dim = self.model_config.head_dim
+            if "qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type:
+                rope_head_dim = head_dim
+            else:
+                rope_head_dim = head_dim // 2
+            self.rope_emb = paddle.full(
+                shape=[
+                    self.scheduler_config.max_num_seqs,
+                    2,
+                    1,
+                    self.model_config.max_model_len,
+                    1,
+                    rope_head_dim,
+                ],
+                fill_value=0,
+                dtype="float32",
+            )
+            self.image_features = None
+            self.image_features_list = None
+        else:
+            self.rope_emb = get_rope(
+                rotary_dim=self.model_config.head_dim,
+                position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
+                base=self.model_config.rope_theta,
+                model_config=self.model_config,
+                partial_rotary_factor=self.model_config.partial_rotary_factor,
+            )
         # Reset generation hyperparameters from the main model
         self.top_p = self.target_model_input_batch["top_p"]
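
This second hunk makes the reset path branch on enable_mm: multimodal batches get a zero-filled rope_emb placeholder (overwritten per request later) whose last dimension depends on the model family, while text-only batches keep the original get_rope recreation. A minimal sketch of the shape logic follows; the model_type substrings are taken from the hunk above, everything else (values, the cos/sin reading of the size-2 axis) is assumed for illustration.

# Minimal sketch of the MM reset-path shape rule, with illustrative values.
import paddle

model_type = "qwen2_5_vl"          # hypothetical model_type string
head_dim = 128
max_num_seqs, max_model_len = 4, 1024

# Qwen-style and PaddleOCR MM models keep the full head_dim for rope;
# other MM models use half-dim rotary embeddings.
if "qwen" in model_type or "paddleocr" in model_type:
    rope_head_dim = head_dim
else:
    rope_head_dim = head_dim // 2

# Zero-filled placeholder; the size-2 axis presumably holds the
# cos/sin pair, and real per-request embeddings are written in later.
rope_emb = paddle.full(
    shape=[max_num_seqs, 2, 1, max_model_len, 1, rope_head_dim],
    fill_value=0,
    dtype="float32",
)
print(rope_emb.shape)  # [4, 2, 1, 1024, 1, 128]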