[Feature] Support cute cpp Encoder FA4 (#7016)

* add cute cpp fa4

* Remove comments

* Fix merge errors

* Move sm_version inside the function

* Fix CI errors
Author: mpgemm
Date: 2026-03-30 10:54:56 +08:00 (committed by GitHub)
Parent: 9765fa7313
Commit: 7a20eaebe8
4 changed files with 132 additions and 14 deletions
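For quick reference, a minimal sketch of the flash_attn_v4 call pattern exercised by the new test is shown below. The import path, positional argument order, dtypes, and the int32 cumulative-sequence-length layout are taken directly from the test added in this commit; the concrete sizes are illustrative assumptions (not the test's setUp values), the op requires SM100+ per the test's skip condition, and running it standalone outside the unittest harness is also an assumption.

    # Sketch of the flash_attn_v4 varlen call as exercised by the test below.
    # The import path and positional argument order come from this diff; the
    # sizes (bsz=1, q_len=1024, 8 query heads, 2 KV heads, head_dim=128) are
    # assumed for illustration and are not the test's setUp values.
    import paddle

    from fastdeploy.model_executor.layers.attention.ops import flash_attn_v4

    bsz, q_len, num_head, num_kv_head, head_dim = 1, 1024, 8, 2, 128

    q = paddle.randn([q_len, num_head, head_dim], dtype="bfloat16")
    k = paddle.randn([q_len, num_kv_head, head_dim], dtype="bfloat16")
    v = paddle.randn(k.shape, dtype="bfloat16")

    # Varlen layout: int32 cumulative sequence lengths of length bsz + 1.
    cu_seq_q = (paddle.arange(bsz + 1) * q_len).astype("int32")
    cu_seq_k = (paddle.arange(bsz + 1) * q_len).astype("int32")

    # Per-token mask as built in the test: entry i is i + 1 (causal visible length).
    mask = paddle.arange(q_len).astype("int32") + 1

    # The op writes into a preallocated output tensor, mirroring the test.
    out = paddle.empty(q.shape, dtype="bfloat16")
    flash_attn_v4(q, k, v, cu_seq_q, cu_seq_k, out, mask)

Note that the test's naive reference expands the KV heads to the query-head count with np.repeat, which is the GQA mapping the op is checked against.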
@@ -22,7 +22,10 @@ import paddle
 from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
     flash_attn_func,
 )
-from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q
+from fastdeploy.model_executor.layers.attention.ops import (
+    flash_attn_v4,
+    get_attn_mask_q,
+)
 from fastdeploy.model_executor.ops.gpu import flash_mask_attention
@@ -109,6 +112,65 @@ class TestFlashMaskAttention(unittest.TestCase):
         max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
         self.assertLessEqual(max_diff, 0.05)
+
+    def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k):
+        """Causal attention reference implementation for flash_attn_v4 testing."""
+        bsz = cu_seq_q.shape[0] - 1
+        q_token_sum, num_head, head_dim = q_input.shape
+        k_token_sum, num_kv_head, _ = k_input.shape
+        gqa_group_size = num_head // num_kv_head
+        qk_scale = 1 / np.sqrt(head_dim)
+        out = paddle.zeros([num_head, q_token_sum, head_dim], q_input.dtype)
+        for bi in range(bsz):
+            q = q_input[cu_seq_q[bi] : cu_seq_q[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
+            k = k_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 2, 0]).astype("float32").numpy()
+            v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
+            qk = np.matmul(q, np.repeat(k, gqa_group_size, 0))
+            qk *= qk_scale
+            condition = np.tril(np.ones(qk.shape), q.shape[1] - k.shape[2])
+            mask = np.ones(condition.shape).astype("float32") * -1000000
+            qk = np.where(condition > 0, qk, mask)
+            qk_max = qk.max(axis=-1, keepdims=True)
+            qk -= qk_max
+            qk = np.exp(qk)
+            exp_sum = qk.sum(axis=-1, keepdims=True)
+            exp_sum_inv = 1.0 / exp_sum
+            temp_out = paddle.to_tensor(np.matmul(qk, np.repeat(v, gqa_group_size, 0)))
+            out[:, cu_seq_q[bi] : cu_seq_q[bi + 1], :] = temp_out * exp_sum_inv
+        return out.transpose([1, 0, 2])
+
+    def test_flash_encoder_attn_fwd(self):
+        if self.sm_version < 100:
+            self.skipTest("Flash Encoder Attention V4 requires SM100+.")
+        q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16")
+        k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
+        v_input = paddle.randn(k_input.shape, dtype="bfloat16")
+        mask = paddle.arange(self.q_len).astype("int32") + 1
+        bsz = self.bsz
+        cu_seq_q = paddle.arange(bsz + 1) * self.q_len
+        cu_seq_k = paddle.arange(bsz + 1) * self.q_len
+        cu_seq_q = cu_seq_q.astype("int32")
+        cu_seq_k = cu_seq_k.astype("int32")
+        naive_attn_out = self.causal_attention_naive(q_input, k_input, v_input, cu_seq_q, cu_seq_k)
+        paddle_attn_out = paddle.empty(q_input.shape, dtype="bfloat16")
+        flash_attn_v4(
+            q_input,
+            k_input,
+            v_input,
+            cu_seq_q,
+            cu_seq_k,
+            paddle_attn_out,
+            mask,
+        )
+        max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
+        self.assertLessEqual(max_diff, 0.05)
+
     def test_fa4(
         self,
     ):