Revert "[Bug Fix] Fix MM mtp incorrect rope emb (#6581)" (#6631)

This reverts commit c5eb6b65e7.
This commit is contained in:
ming1753
2026-03-04 11:23:28 +08:00
committed by GitHub
parent aee97e3aae
commit 02d32eea3b
3 changed files with 19 additions and 56 deletions
+17 -30
View File
@@ -731,7 +731,15 @@ class ProposerInputBatch(InputBatch):
dtype="bfloat16",
)
self.rope_emb = paddle.clone(self.target_model_input_batch["rope_emb"])
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
partial_rotary_factor=self.model_config.partial_rotary_factor,
)
# self.caches = self.cache_kvs
# Inherit generation hyperparameters from the main model for consistency
@@ -951,35 +959,14 @@ class ProposerInputBatch(InputBatch):
fill_paddle_tensor(self, "target_hidden_states", 0)
# Reset rope embedding by recreating with default position_ids
if self.enable_mm:
head_dim = self.model_config.head_dim
if "qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type:
rope_head_dim = head_dim
else:
rope_head_dim = head_dim // 2
self.rope_emb = paddle.full(
shape=[
self.scheduler_config.max_num_seqs,
2,
1,
self.model_config.max_model_len,
1,
rope_head_dim,
],
fill_value=0,
dtype="float32",
)
self.image_features = None
self.image_features_list = None
else:
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
base=self.model_config.rope_theta,
model_config=self.model_config,
partial_rotary_factor=self.model_config.partial_rotary_factor,
)
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
partial_rotary_factor=self.model_config.partial_rotary_factor,
)
# Reset generation hyperparameters from the main model
self.top_p = self.target_model_input_batch["top_p"]