fix seqlen sync (#4442)

Author: Ayakouji
Date: 2025-10-17 14:37:52 +08:00
Committed by: GitHub
Parent: 720697e265
Commit: a3e0a15495
@@ -106,6 +106,7 @@ class VisionFlashAttention2(nn.Layer):
         self,
         hidden_states: paddle.Tensor,
         cu_seqlens: paddle.Tensor,
+        max_seqlen: int,
         rotary_pos_emb: paddle.Tensor = None,
     ) -> paddle.Tensor:
         """_summary_
@@ -136,8 +137,6 @@ class VisionFlashAttention2(nn.Layer):
         q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
         k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         softmax_scale = self.head_dim**-0.5
         attn_output = (
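The line removed above derived the longest segment length from the cumulative boundaries in `cu_seqlens` inside every attention layer. Because `.max().item()` copies a scalar from device to host, each evaluation is a blocking GPU-to-CPU synchronization, so a stack of N vision blocks paid N syncs per forward pass; after this change the value is computed once upstream and passed in as `max_seqlen`. A minimal sketch of the computation itself, with illustrative boundary values that are not taken from the patch:

```python
import paddle

# cu_seqlens holds cumulative token counts per segment, here three
# segments of lengths 4, 5 and 6 (illustrative values only).
cu_seqlens = paddle.to_tensor([0, 4, 9, 15], dtype="int32")

seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]   # per-segment lengths: [4, 5, 6]
max_seqlen = seq_lens.max().item()            # 6; .item() syncs device -> host once
print(max_seqlen)
```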
@@ -380,7 +379,7 @@ class DFNRopeVisionBlock(nn.Layer):
             tensor_parallel_degree=tensor_parallel_degree,
         )
-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor:
+    def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_pos_emb) -> paddle.Tensor:
         """_summary_
         Args:
@@ -395,6 +394,7 @@ class DFNRopeVisionBlock(nn.Layer):
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
             rotary_pos_emb=rotary_pos_emb,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -574,6 +574,13 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
         cu_seqlens_thw = paddle.repeat_interleave(paddle.tensor([h * w], dtype=paddle.int32), t)
         return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, cu_seqlens_thw)
+    def compute_attn_mask_seqlen(
+        self,
+        cu_seqlens: paddle.Tensor,
+    ) -> int:
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        return max_seqlen
     def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor, num_pad=0) -> paddle.Tensor:
         """_summary_
@@ -604,15 +611,21 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
+                max_seqlen_now = max_seqlen_full
             else:
                 cu_seqlens_now = cu_window_seqlens
+                max_seqlen_now = max_seqlen_window
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
+                max_seqlen=max_seqlen_now,
                 rotary_pos_emb=rotary_pos_emb,
             )
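With these pieces in place, the vision forward pass reduces the boundaries twice per call (once for full attention, once for windowed attention) and threads the resulting Python int through each `DFNRopeVisionBlock` into its attention layer. A condensed sketch of that control flow, using a hypothetical free function in place of the actual model method:

```python
import paddle

def run_vision_blocks(blocks, hidden_states, cu_seqlens, cu_window_seqlens,
                      fullatt_block_indexes, rotary_pos_emb):
    # Two host syncs per forward pass instead of one per block.
    max_seqlen_full = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
    max_seqlen_window = int((cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item())

    for layer_num, blk in enumerate(blocks):
        # Full attention for the configured layers, windowed attention otherwise.
        if layer_num in fullatt_block_indexes:
            cu_seqlens_now, max_seqlen_now = cu_seqlens, max_seqlen_full
        else:
            cu_seqlens_now, max_seqlen_now = cu_window_seqlens, max_seqlen_window
        hidden_states = blk(
            hidden_states,
            cu_seqlens=cu_seqlens_now,
            max_seqlen=max_seqlen_now,
            rotary_pos_emb=rotary_pos_emb,
        )
    return hidden_states
```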