mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support cute cpp Encoder FA4 (#7016)
* add cute cpp fa4 * 删掉注释 * 修正合并错误 * sm_version放到函数内 * ci错误
This commit is contained in:
@@ -22,7 +22,10 @@ import paddle
|
||||
from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
|
||||
flash_attn_func,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
flash_attn_v4,
|
||||
get_attn_mask_q,
|
||||
)
|
||||
from fastdeploy.model_executor.ops.gpu import flash_mask_attention
|
||||
|
||||
|
||||
@@ -109,6 +112,65 @@ class TestFlashMaskAttention(unittest.TestCase):
|
||||
max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
|
||||
self.assertLessEqual(max_diff, 0.05)
|
||||
|
||||
def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k):
|
||||
"""Causal attention reference implementation for flash_attn_v4 testing."""
|
||||
bsz = cu_seq_q.shape[0] - 1
|
||||
q_token_sum, num_head, head_dim = q_input.shape
|
||||
k_token_sum, num_kv_head, _ = k_input.shape
|
||||
gqa_group_size = num_head // num_kv_head
|
||||
qk_scale = 1 / np.sqrt(head_dim)
|
||||
out = paddle.zeros([num_head, q_token_sum, head_dim], q_input.dtype)
|
||||
for bi in range(bsz):
|
||||
q = q_input[cu_seq_q[bi] : cu_seq_q[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
|
||||
k = k_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 2, 0]).astype("float32").numpy()
|
||||
v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
|
||||
qk = np.matmul(q, np.repeat(k, gqa_group_size, 0))
|
||||
qk *= qk_scale
|
||||
condition = np.tril(np.ones(qk.shape), q.shape[1] - k.shape[2])
|
||||
mask = np.ones(condition.shape).astype("float32") * -1000000
|
||||
qk = np.where(condition > 0, qk, mask)
|
||||
qk_max = qk.max(axis=-1, keepdims=True)
|
||||
qk -= qk_max
|
||||
qk = np.exp(qk)
|
||||
exp_sum = qk.sum(axis=-1, keepdims=True)
|
||||
exp_sum_inv = 1.0 / exp_sum
|
||||
temp_out = paddle.to_tensor(np.matmul(qk, np.repeat(v, gqa_group_size, 0)))
|
||||
out[:, cu_seq_q[bi] : cu_seq_q[bi + 1], :] = temp_out * exp_sum_inv
|
||||
return out.transpose([1, 0, 2])
|
||||
|
||||
def test_flash_encoder_attn_fwd(self):
|
||||
if self.sm_version < 100:
|
||||
self.skipTest("Flash Encoder Attention V4 requires SM100+.")
|
||||
|
||||
q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16")
|
||||
k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
|
||||
v_input = paddle.randn(k_input.shape, dtype="bfloat16")
|
||||
|
||||
mask = paddle.arange(self.q_len).astype("int32") + 1
|
||||
|
||||
bsz = self.bsz
|
||||
cu_seq_q = paddle.arange(bsz + 1) * self.q_len
|
||||
cu_seq_k = paddle.arange(bsz + 1) * self.q_len
|
||||
cu_seq_q = cu_seq_q.astype("int32")
|
||||
cu_seq_k = cu_seq_k.astype("int32")
|
||||
|
||||
naive_attn_out = self.causal_attention_naive(q_input, k_input, v_input, cu_seq_q, cu_seq_k)
|
||||
|
||||
paddle_attn_out = paddle.empty(q_input.shape, dtype="bfloat16")
|
||||
|
||||
flash_attn_v4(
|
||||
q_input,
|
||||
k_input,
|
||||
v_input,
|
||||
cu_seq_q,
|
||||
cu_seq_k,
|
||||
paddle_attn_out,
|
||||
mask,
|
||||
)
|
||||
|
||||
max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
|
||||
self.assertLessEqual(max_diff, 0.05)
|
||||
|
||||
def test_fa4(
|
||||
self,
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user