mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
fix seqlen sync (#4442)
This commit is contained in:
@@ -106,6 +106,7 @@ class VisionFlashAttention2(nn.Layer):
|
||||
self,
|
||||
hidden_states: paddle.Tensor,
|
||||
cu_seqlens: paddle.Tensor,
|
||||
max_seqlen: int,
|
||||
rotary_pos_emb: paddle.Tensor = None,
|
||||
) -> paddle.Tensor:
|
||||
"""_summary_
|
||||
@@ -136,8 +137,6 @@ class VisionFlashAttention2(nn.Layer):
|
||||
q = apply_rotary_pos_emb_vision(q.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
|
||||
k = apply_rotary_pos_emb_vision(k.unsqueeze(axis=0), rotary_pos_emb).squeeze(axis=0)
|
||||
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
softmax_scale = self.head_dim**-0.5
|
||||
|
||||
attn_output = (
|
||||
@@ -380,7 +379,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
||||
tensor_parallel_degree=tensor_parallel_degree,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> paddle.Tensor:
|
||||
def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_pos_emb) -> paddle.Tensor:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
@@ -395,6 +394,7 @@ class DFNRopeVisionBlock(nn.Layer):
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states),
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@@ -574,6 +574,13 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||
cu_seqlens_thw = paddle.repeat_interleave(paddle.tensor([h * w], dtype=paddle.int32), t)
|
||||
return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, cu_seqlens_thw)
|
||||
|
||||
def compute_attn_mask_seqlen(
    self,
    cu_seqlens: paddle.Tensor,
) -> int:
    """Return the largest gap between consecutive entries of ``cu_seqlens``.

    Args:
        cu_seqlens: 1-D tensor of boundary offsets. Presumably these are
            cumulative sequence lengths starting at 0, so each adjacent
            difference is one sequence's length -- confirm against callers.

    Returns:
        The maximum adjacent difference as a Python scalar, or ``0`` when
        ``cu_seqlens`` holds fewer than two boundaries (no segment exists).
    """
    # With fewer than two boundaries the difference below is empty, and
    # taking ``max()`` of an empty tensor raises; report "no sequence".
    if cu_seqlens.shape[0] < 2:
        return 0

    segment_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    # ``.item()`` converts the 0-D result into a host-side Python scalar.
    # NOTE(review): this is likely a device-to-host sync point, so callers
    # should compute it once and reuse it rather than per layer -- confirm.
    return segment_lengths.max().item()
|
||||
|
||||
def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor, num_pad=0) -> paddle.Tensor:
|
||||
"""_summary_
|
||||
|
||||
@@ -604,15 +611,21 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||
max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
|
||||
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
if layer_num in self.fullatt_block_indexes:
|
||||
cu_seqlens_now = cu_seqlens
|
||||
max_seqlen_now = max_seqlen_full
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
max_seqlen_now = max_seqlen_window
|
||||
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
max_seqlen=max_seqlen_now,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user