[BugFix] fix mm rope (#7274)

This commit is contained in:
xiaoxiaohehe001
2026-04-14 11:36:08 +08:00
committed by GitHub
parent 8f21c9caa6
commit abba29b348
2 changed files with 5 additions and 5 deletions
@@ -458,12 +458,12 @@ class ErnieVlRotaryEmbedding3D:
# Build position_ids_3d: [bsz, max_position, 3]
position_ids_3d = paddle.tile(
paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
paddle.arange(self.max_position, dtype="float32").unsqueeze(0).unsqueeze(-1),
[bsz, 1, 3],
)
for i in range(bsz):
position_ids_cur = position_ids[cumsum_seqlens[i] : cumsum_seqlens[i + 1]]
prefix_max_position_ids = paddle.max(position_ids_cur) + 1
prefix_max_position_ids = paddle.max(position_ids_cur[..., 0]) + 1
dec_pos_ids = paddle.tile(
paddle.arange(max_len_lst[i], dtype="int64").unsqueeze(-1),
[1, 3],
@@ -530,12 +530,12 @@ class QwenVlRotaryEmbedding3D:
bsz = len(cumsum_seqlens) - 1
# position_ids_3d: [bsz, seq_len, 3]
position_ids_3d = paddle.tile(
paddle.arange(self.max_position, dtype="int64").unsqueeze(0).unsqueeze(-1),
paddle.arange(self.max_position, dtype="float32").unsqueeze(0).unsqueeze(-1),
[bsz, 1, 3],
)
for i in range(bsz):
position_ids_cur = position_ids[cumsum_seqlens[i] : cumsum_seqlens[i + 1]]
prefix_max_position_ids = paddle.max(position_ids_cur) + 1
prefix_max_position_ids = paddle.max(position_ids_cur[..., 0]) + 1
dec_pos_ids = paddle.tile(
paddle.arange(max_len_lst[i], dtype="int64").unsqueeze(-1),
[1, 3],