[Optimization] Support FA2/FA3/FA4 with attn_mask_q (#6354)

* support FA4 sm100

* flash attn backend support mask

* flash attn backend run flashmask correct

* add test for flash_attn_backend and flash_attn_func

* check

* add test for fa4

* requirements.txt add fa4 whl

* check test on sm100

* fix CI conflict

* add enable_torch_proxy for flash_mask

* lazy import fa4

* check

* fix tests import

* check test_load_mpt import
chen
2026-02-05 14:39:00 +08:00
committed by GitHub
parent 72edd394d9
commit 29a313a402
22 changed files with 999 additions and 101 deletions
+8
@@ -1116,6 +1116,12 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
                                    int64_t think_end_id,
                                    int64_t line_break_id);
 
+std::vector<paddle::Tensor> get_attn_mask_q(
+    const paddle::Tensor& cu_seqlens_q,
+    const paddle::Tensor& cu_seqlens_k,
+    const paddle::optional<paddle::Tensor>& attn_mask_kv,
+    const int kv_token_num);
+
 PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("get_expert_token_num",
         &GetExpertTokenNum,
@@ -1722,6 +1728,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         &ReasoningPhaseTokenConstraint,
         "reasoning_phase_token_constraint function");
 
+  m.def("get_attn_mask_q", &get_attn_mask_q, "get_attn_mask_q function");
+
   m.def("get_stop", &GetStop, "get_stop function");
   m.def("set_stop", &SetStop, "set_stop function");
+137
@@ -0,0 +1,137 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
__global__ void get_attn_mask_q_kernel(
int* __restrict__ startend_row_indices_ptr,
const int* attn_mask_kv_ptr,
const int* cu_seqlens_q,
const int* cu_seqlens_k,
const int kv_token_num,
const int max_batch_size) {
constexpr int VecSize = 4;
const uint32_t tid = threadIdx.x, bid = blockIdx.x;
int startend_row_vec[4];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (uint32_t cu_seqlens_k_idx = bid * blockDim.x + tid;
cu_seqlens_k_idx < kv_token_num;
cu_seqlens_k_idx += blockDim.x * gridDim.x) {
uint32_t batch_id = 0;
for (int i = 0; i < max_batch_size; ++i) {
if (cu_seqlens_k_idx >= cu_seqlens_k[i] &&
cu_seqlens_k_idx < cu_seqlens_k[i + 1]) {
batch_id = i;
break;
}
}
const uint32_t this_batch_q_start = cu_seqlens_q[batch_id];
const uint32_t this_batch_q_end = cu_seqlens_q[batch_id + 1];
const uint32_t this_batch_q_len = this_batch_q_end - this_batch_q_start;
const uint32_t kv_start = cu_seqlens_k[batch_id];
const uint32_t kv_end = cu_seqlens_k[batch_id + 1];
const uint32_t kv_len = kv_end - kv_start;
const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;
startend_row_vec[0] = this_batch_q_end;
startend_row_vec[1] = cu_seqlens_q[max_batch_size];
startend_row_vec[2] = 0;
startend_row_vec[3] = this_batch_q_end;
for (int this_batch_q_idx = this_batch_q_start;
this_batch_q_idx < this_batch_q_end;
++this_batch_q_idx) {
// const int append_mask_k_start = attn_mask_kv_ptr ?
// attn_mask_kv_ptr[this_batch_q_idx * 2 + 0] : 0;
const int append_mask_k_end =
attn_mask_kv_ptr ? attn_mask_kv_ptr[this_batch_q_idx * 2 + 1] - 1
: this_batch_q_idx - this_batch_q_start + kv_len -
(this_batch_q_len);
if (cache_k_idx <= append_mask_k_end) {
startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx);
// can break out of the loop early here
break;
}
}
reinterpret_cast<int4*>(startend_row_indices_ptr +
cu_seqlens_k_idx * 4)[0] =
reinterpret_cast<int4*>(startend_row_vec)[0];
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
std::vector<paddle::Tensor> get_attn_mask_q(
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cu_seqlens_k,
const paddle::optional<paddle::Tensor>& attn_mask_kv,
const int kv_token_num) {
paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
{1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place());
const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
constexpr int block_size = 512;
int grid_size = div_up(kv_token_num, block_size);
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
get_attn_mask_q_kernel<<<grid_size, block_size, 0, cu_seqlens_k.stream()>>>(
attn_mask_startend_row_indices.data<int>(),
attn_mask_kv ? attn_mask_kv.get().data<int>() : nullptr,
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
kv_token_num,
max_batch_size);
#else
launchWithPdlWhenEnabled(
get_attn_mask_q_kernel,
grid_size,
block_size,
0,
cu_seqlens_k.stream(),
attn_mask_startend_row_indices.data<int>(),
attn_mask_kv ? attn_mask_kv.get().data<int>() : nullptr,
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
kv_token_num,
max_batch_size);
#endif
return {attn_mask_startend_row_indices};
}
std::vector<paddle::DataType> GetAttnMaskQInferDtype(
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& cu_seqlens_k_dtype,
const paddle::optional<paddle::DataType>& attn_mask_kv_dtype) {
return {paddle::DataType::INT32};
}
std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& cu_seqlens_k_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
const int kv_token_num) {
return {{1, 1, kv_token_num, 4}};
}
PD_BUILD_STATIC_OP(get_attn_mask_q)
.Inputs({"cu_seqlens_q",
"cu_seqlens_k",
paddle::Optional("attn_mask_offsets")})
.Outputs({"attn_mask_q"})
.Attrs({"kv_token_num: int"})
.SetKernelFn(PD_KERNEL(get_attn_mask_q))
.SetInferShapeFn(PD_INFER_SHAPE(GetAttnMaskQInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetAttnMaskQInferDtype));
+1
@@ -313,6 +313,7 @@ elif paddle.is_compiled_with_cuda():
             "gpu_ops/fused_neox_rope_embedding.cu",
             "gpu_ops/gelu_tanh.cu",
             "gpu_ops/reasoning_phase_token_constraint.cu",
+            "gpu_ops/get_attn_mask_q.cu",
         ]
 
 # pd_disaggregation
@@ -21,12 +21,18 @@ from typing import TYPE_CHECKING, List, Optional
 
 import paddle
 from paddle.nn.functional.flash_attention import flash_attn_unpadded
+from paddleformers.utils.log import logger
 
 try:
     from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
 except:
     flash_attention_v3_varlen = None
 
+try:
+    from paddle.nn.functional.flash_attention import flashmask_attention
+except:
+    flashmask_attention = None
+
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
@@ -35,6 +41,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 )
 from fastdeploy.model_executor.layers.attention.ops import (
     append_attention,
+    get_attn_mask_q,
     get_block_shape_and_split_kv_block,
     gqa_rope_write_cache,
     init_kv_signal_per_query,
@@ -43,12 +50,16 @@ from fastdeploy.model_executor.layers.attention.ops import (
     pre_cache_len_concat,
 )
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.model_executor.utils import get_sm_version
 
 if TYPE_CHECKING:
     from fastdeploy.model_executor.forward_meta import ForwardMeta
 
 from fastdeploy.platforms import current_platform
 
+paddle.compat.enable_torch_proxy(scope={"flash_mask"})
+flashmask_attention_v4 = None
+
 if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
 else:
@@ -56,6 +67,123 @@ else:
 
 import os
 
+FLASH_ATTN_VERSION = None
+
+
+def init_flash_attn_version(fa_version: int = None):
+    """
+    Resolve FLASH_ATTN_VERSION, preferring the newest version the platform supports.
+    """
+    if current_platform.is_cuda():
+        global FLASH_ATTN_VERSION
+        if fa_version is not None:
+            FLASH_ATTN_VERSION = fa_version
+            logger.info(f"Force use Flash Attention V{fa_version}.")
+            return
+        sm_version = get_sm_version()
+        if sm_version >= 100:
+            try:
+                # Lazy import: FA4 is only available once the flash_mask wheel is installed.
+                from flash_mask.cute.interface import flashmask_attention as fa4
+
+                global flashmask_attention_v4
+                flashmask_attention_v4 = fa4
+                FLASH_ATTN_VERSION = 4
+                logger.info("The current platform supports Flash Attention V4.")
+            except ImportError:
+                pass
+        if FLASH_ATTN_VERSION is None:
+            if sm_version >= 89 and any(num >= 89 for num in paddle.version.cuda_archs()):
+                FLASH_ATTN_VERSION = 3
+                logger.info("The current platform supports Flash Attention V3.")
+            else:
+                FLASH_ATTN_VERSION = 2
+                logger.info("The current platform only supports Flash Attention V2.")
+    else:
+        logger.info("Only support CUDA version flash attention.")
+
+
+def flash_attn_func(
+    q: paddle.Tensor,
+    k: paddle.Tensor,
+    v: paddle.Tensor,
+    cu_seqlens_q: Optional[paddle.Tensor] = None,
+    cu_seqlens_k: Optional[paddle.Tensor] = None,
+    max_seqlen_q: Optional[paddle.Tensor] = None,
+    max_seqlen_k: Optional[paddle.Tensor] = None,
+    attn_mask_q: Optional[paddle.Tensor] = None,
+    causal: bool = True,
+    num_heads: int = None,
+    kv_num_heads: int = None,
+    head_dim: int = 128,
+    version: Optional[int] = None,
+):
+    if version is None:
+        if FLASH_ATTN_VERSION is None:
+            init_flash_attn_version()
+        version = FLASH_ATTN_VERSION
+    if version == 4:
+        assert (
+            flashmask_attention_v4 is not None
+        ), "Cannot import flashmask_attention from flash_mask.cute.interface, please install it first"
+        assert attn_mask_q is not None, "FA4 requires attn_mask_q"
+        assert num_heads is not None
+        assert kv_num_heads is not None
+        original_flash_attn_version = paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])[
+            "FLAGS_flash_attn_version"
+        ]
+        with paddle.no_grad():
+            try:
+                paddle.set_flags({"FLAGS_flash_attn_version": 4})
+                out = flashmask_attention_v4(
+                    q.reshape([1, -1, num_heads, head_dim]),
+                    k.reshape([1, -1, kv_num_heads, head_dim]),
+                    v.reshape([1, -1, kv_num_heads, head_dim]),
+                    startend_row_indices=attn_mask_q,
+                    causal=False,
+                    return_softmax_lse=True,
+                    training=True,
+                )
+            finally:
+                paddle.set_flags({"FLAGS_flash_attn_version": original_flash_attn_version})
+        return out
+    if attn_mask_q is not None:
+        assert flashmask_attention is not None
+        out = flashmask_attention(
+            q.reshape([1, -1, num_heads, head_dim]),
+            k.reshape([1, -1, kv_num_heads, head_dim]),
+            v.reshape([1, -1, kv_num_heads, head_dim]),
+            startend_row_indices=attn_mask_q,
+            causal=False,
+        )
+    else:
+        if version == 3:
+            out = flash_attention_v3_varlen(
+                q,
+                k,
+                v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                causal=causal,
+            )
+        else:
+            out = flash_attn_unpadded(
+                q,
+                k,
+                v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                causal=causal,
+                scale=head_dim**-0.5,
+                training=False,
+            )
+    return out
+
+
 @dataclass
 class FlashAttentionMetadata(AttentionMetadata):
@@ -79,6 +207,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     max_len_tensor_cpu_decoder: paddle.Tensor = None
 
+    attn_mask_q: paddle.Tensor = None
+
 
 class FlashAttentionBackend(AttentionBackend):
     """
@@ -87,7 +217,6 @@ class FlashAttentionBackend(AttentionBackend):
 
     __infer_dynamic_dims_fields__ = ["attention_metadata"]
     attention_metadata: FlashAttentionMetadata
-    flash_attn_func: callable = None
 
     def __init__(
         self,
@@ -127,22 +256,9 @@ class FlashAttentionBackend(AttentionBackend):
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
 
-        if self.flash_attn_func is None:
-            prop = paddle.device.cuda.get_device_properties()
-            cc = prop.major * 10 + prop.minor
-            is_current_sm_supported = cc >= 90
-            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
-            if is_current_sm_supported and is_paddle_supported:
-                self.flash_attn_func = flash_attention_v3_varlen
-                print("The current platform supports Flash Attention V3.")
-                self.flash_attn_kwargs = {}
-            else:
-                self.flash_attn_func = flash_attn_unpadded
-                self.flash_attn_kwargs = {"scale": self.head_dim**-0.5, "training": False}
-                print(
-                    "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
-                )
-        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
+        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
+            fd_config.model_config, "use_3d_rope", False
+        )
 
         # Note(ZKK): here must be consistent with append_attn_backend.py
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024))
         self.zero_seq_enc_lens_for_decode = paddle.zeros(
@@ -255,6 +371,13 @@ class FlashAttentionBackend(AttentionBackend):
             forward_meta.max_len_tensor_cpu[2],
             self.block_size,
         )
+        if forward_meta.attn_mask_offsets is not None:
+            metadata.attn_mask_q = get_attn_mask_q(
+                cu_seqlens_q=forward_meta.cu_seqlens_q,
+                cu_seqlens_k=metadata.cu_seqlens_k,
+                attn_mask_kv=forward_meta.attn_mask_offsets,
+                kv_token_num=metadata.kv_token_num_cpu[0].item(),
+            )
 
         use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
@@ -294,16 +417,19 @@ class FlashAttentionBackend(AttentionBackend):
                 self.rope_3d,
             )
 
-            res_encoder = self.flash_attn_func(
+            res_encoder = flash_attn_func(
                 q,
                 k,
                 v,
-                forward_meta.cu_seqlens_q,
+                forward_meta.cu_seqlens_q[: metadata.cu_seqlens_k.shape[0]],
                 metadata.cu_seqlens_k,
                 max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
                 max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
+                attn_mask_q=metadata.attn_mask_q,
                 causal=self.causal,
-                **self.flash_attn_kwargs,
+                num_heads=self.num_heads,
+                kv_num_heads=self.kv_num_heads,
+                head_dim=self.head_dim,
             )[0].reshape([-1, self.attn_outputsize_tp])
 
             res_decoder = append_attention(
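For orientation: a minimal call sketch for the flash_attn_func dispatcher added above,
using toy shapes (all values here are illustrative); leaving version unset lets
init_flash_attn_version pick FA2/FA3/FA4 from the detected SM and installed wheels:

import paddle

from fastdeploy.model_executor.layers.attention.flash_attn_backend import flash_attn_func

bsz, seq_len, num_heads, kv_num_heads, head_dim = 2, 512, 8, 4, 128  # toy sizes
token_num = bsz * seq_len
q = paddle.rand((token_num, num_heads, head_dim), dtype="float32").cast("bfloat16")
k = paddle.rand((token_num, kv_num_heads, head_dim), dtype="float32").cast("bfloat16")
v = paddle.rand((token_num, kv_num_heads, head_dim), dtype="float32").cast("bfloat16")
cu_seqlens = paddle.arange(0, token_num + seq_len, seq_len, dtype="int32")

out = flash_attn_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seq_len,
    max_seqlen_k=seq_len,
    causal=True,
    num_heads=num_heads,
    kv_num_heads=kv_num_heads,
    head_dim=head_dim,
)
attn_out = out[0]  # the backend reshapes this to [-1, attn_outputsize_tp]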
@@ -16,6 +16,7 @@
 from .append_attention import append_attention, append_attention_with_output
 from .flash_mask_attention import flash_mask_attention
+from .get_attn_mask_q import get_attn_mask_q
 from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
 from .gqa_rope_write_cache import gqa_rope_write_cache
 from .init_kv_signal_per_query import init_kv_signal_per_query
@@ -33,4 +34,5 @@ __all__ = [
     "pre_cache_len_concat",
     "init_kv_signal_per_query",
     "flash_mask_attention",
+    "get_attn_mask_q",
 ]
@@ -0,0 +1,49 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
get_attn_mask_q as get_attn_mask_q_cuda,
)
def get_attn_mask_q(
cu_seqlens_q: paddle.Tensor,
cu_seqlens_k: paddle.Tensor,
attn_mask_kv: Optional[paddle.Tensor] = None,
kv_token_num: int = 0,
):
"""
get_attn_mask_q
"""
if current_platform.is_cuda():
out = get_attn_mask_q_cuda(
cu_seqlens_q,
cu_seqlens_k,
attn_mask_kv,
kv_token_num,
)
else:
raise NotImplementedError
return out
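For orientation: a minimal call sketch mirroring the unit tests later in this PR;
attn_mask_kv carries one flattened [start, end) KV-offset pair per query token, and
the result feeds flash_attn_func(..., attn_mask_q=...) (toy sizes, illustrative only):

import paddle

from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q

bsz, q_len, kv_len = 2, 4, 6  # toy sizes
cu_seqlens_q = paddle.arange(bsz + 1, dtype="int32") * q_len
cu_seqlens_k = paddle.arange(bsz + 1, dtype="int32") * kv_len
# every query may attend to its batch's full KV range: [0, kv_len) per query token
attn_mask_kv = paddle.zeros([bsz * q_len, 2], dtype="int32")
attn_mask_kv[:, 1] = kv_len
attn_mask_kv = attn_mask_kv.reshape([-1])

attn_mask_q = get_attn_mask_q(
    cu_seqlens_q,
    cu_seqlens_k,
    attn_mask_kv,
    kv_token_num=bsz * kv_len,
)
# attn_mask_q: int32 tensor of shape [1, 1, bsz * kv_len, 4]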
+9
@@ -19,6 +19,7 @@ import re
 from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass, field
+from functools import cache
 from typing import Any, List, Optional, Union
 
 import paddle
@@ -547,3 +548,11 @@ def rename_offline_ckpt_suffix_to_fd_suffix(
             return loaded_weight_name
 
         return fn
+
+
+@cache
+def get_sm_version():
+    if paddle.cuda.is_available():
+        prop = paddle.device.cuda.get_device_properties()
+        return prop.major * 10 + prop.minor
+    return 0
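For orientation: get_sm_version memoizes the compute capability of the current GPU;
the backend maps it onto a Flash Attention version using the thresholds from
init_flash_attn_version (an illustrative sketch, not part of the diff):

from fastdeploy.model_executor.utils import get_sm_version

sm = get_sm_version()  # e.g. 90 on Hopper, 100 on Blackwell; 0 when CUDA is unavailable
if sm >= 100:
    print("eligible for Flash Attention V4 (when the flash_mask wheel is installed)")
elif sm >= 89:
    print("eligible for Flash Attention V3")
else:
    print("falls back to Flash Attention V2")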
+1
@@ -47,3 +47,4 @@ aistudio_sdk
 p2pstore
 py-cpuinfo
 flashinfer-python-paddle
+flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl
+294 -67
@@ -66,55 +66,14 @@ class TestAttentionPerformance(unittest.TestCase):
         print("Setting up test environment...")
         paddle.set_device("gpu")
         paddle.set_default_dtype("bfloat16")
+        prop = paddle.device.cuda.get_device_properties()
+        self.sm_version = prop.major * 10 + prop.minor
         init_distributed_environment()
-        self.model_dir = self.create_model_config_json()
-        tp_size = paddle.distributed.get_world_size()
-        self.fd_config = self.create_fd_config_from_model_path(self.model_dir, tensor_parallel_size=tp_size)
-        self.fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(tp_size))
-
-        # Initialize Attention Layer
-        attn_cls = get_attention_backend()
-        self.attn_backend = attn_cls(
-            self.fd_config,
-            kv_num_heads=self.fd_config.model_config.num_key_value_heads // tp_size,
-            num_heads=self.fd_config.model_config.num_attention_heads // tp_size,
-            head_dim=self.fd_config.model_config.head_dim,
-            encoder_block_shape_q=64,
-            decoder_block_shape_q=16,
-        )
-        num_layers = self.fd_config.model_config.num_hidden_layers
-        self.attention_layer = [None] * num_layers
-        for i in range(num_layers):
-            self.attention_layer[i] = Ernie4_5_Attention(self.fd_config, layer_id=i, prefix="test_layer")
-            state_dict = self.create_random_attention_state_dict(self.fd_config, prefix="test_layer")
-            self.attention_layer[i].load_state_dict(state_dict)
-
-        def attn_forward(forward_meta, hidden_states):
-            for i in range(num_layers):
-                hidden_states = self.attention_layer[i](forward_meta, hidden_states)
-            return hidden_states
-
-        self.attn_forward = attn_forward
-        self.cache_quant_type_str = getattr(self.attention_layer[0].attn, "cache_quant_type_str", "none")
-        print("===== Initialization Complete =====")
-
-    def tearDown(self):
-        """
-        Clean up the environment after each test.
-        """
-        print("\nTearing down test environment...")
-        if os.path.exists(self.model_dir):
-            shutil.rmtree(self.model_dir)
-            print(f"Successfully removed temporary directory: {self.model_dir}")
+        self.model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
+        self.create_model_config_json(self.model_dir)
 
     # region Helper Functions
-    def create_model_config_json(self) -> str:
+    def create_model_config_json(self, model_dir) -> str:
         """
         Creates a temporary directory and writes the model configuration to a 'config.json' file.
         """
@@ -129,14 +88,13 @@ class TestAttentionPerformance(unittest.TestCase):
             "num_key_value_heads": 4,
             "num_hidden_layers": 2,
         }
-        model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
         config_path = os.path.join(model_dir, "config.json")
         with open(config_path, "w") as f:
             json.dump(config_dict, f, indent=4)
         print(f"Successfully created config.json at: {config_path}")
         return model_dir
 
-    def create_fd_config_from_model_path(self, model_path, tensor_parallel_size=1):
+    def create_fd_config_from_model_path(self, model_path, tensor_parallel_size=1, quantization: dict = None):
         """Creates a complete FDConfig from a model path."""
         model_args = {"model": model_path, "dtype": "bfloat16"}
         model_config = ModelConfig(model_args)
@@ -156,10 +114,9 @@ class TestAttentionPerformance(unittest.TestCase):
             scheduler_config=SchedulerConfig({}),
             load_config=LoadConfig({}),
             quant_config=MixQuantConfig(
-                dense_quant_type="block_wise_fp8",
-                moe_quant_type="block_wise_fp8",
-                kv_cache_quant_type="float8_e4m3fn",
-                # kv_cache_quant_type=None,
+                dense_quant_type=quantization.get("dense_quant_type", None),
+                moe_quant_type=quantization.get("moe_quant_type", None),
+                kv_cache_quant_type=quantization.get("kv_cache_quant_type", None),
             ),
             graph_opt_config=GraphOptimizationConfig({}),
             commit_config=CommitConfig(),
@@ -207,14 +164,21 @@ class TestAttentionPerformance(unittest.TestCase):
         fd_config: FDConfig,
         attn_backend: AttentionBackend,
         cache_quant_type_str: str = "none",
+        has_attn_mask: bool = False,
     ) -> ForwardMeta:
         """
         Creates a high-fidelity ForwardMeta object.
         """
+        attn_mask_offsets = None
         if mode == ForwardMode.EXTEND:
             seq_lens_encoder = paddle.full([batch_size], seq_len, dtype="int32")
             seq_lens_decoder = paddle.zeros([batch_size], dtype="int32")
             seq_lens_this_time = seq_lens_encoder
+            if has_attn_mask:
+                attn_mask_offsets_numpy = np.zeros([batch_size, seq_len, 2], dtype=np.int32)
+                for i in range(seq_len):
+                    attn_mask_offsets_numpy[:, i, 1] = i + 1
+                attn_mask_offsets = paddle.to_tensor(attn_mask_offsets_numpy.reshape([-1, 2]))
         elif mode == ForwardMode.DECODE:
             seq_lens_encoder = paddle.zeros([batch_size], dtype="int32")
             seq_lens_decoder = paddle.full([batch_size], seq_len, dtype="int32")
@@ -292,31 +256,81 @@ class TestAttentionPerformance(unittest.TestCase):
             attn_backend=attn_backend,
             forward_mode=ForwardMode.MIXED,
             attn_mask=None,
-            attn_mask_offsets=None,
+            attn_mask_offsets=attn_mask_offsets,
             **attn_backend_buffers,
         )
-        hidden_states = paddle.randn([token_num, self.fd_config.model_config.hidden_size], dtype="bfloat16")
+        hidden_states = paddle.randn([token_num, fd_config.model_config.hidden_size], dtype="bfloat16")
         return forward_meta, hidden_states
 
-    def test_decode_performance_with_prefill(self):
+    def tearDown(self):
+        """
+        Clean up the environment after each test.
+        """
+        print("\nTearing down test environment...")
+        if os.path.exists(self.model_dir):
+            shutil.rmtree(self.model_dir)
+            print(f"Successfully removed temporary directory: {self.model_dir}")
+
+    def attn_forward(self, attention_layer, forward_meta, hidden_states):
+        for i in range(len(attention_layer)):
+            hidden_states = attention_layer[i](forward_meta, hidden_states)
+        return hidden_states
+
+    def test_append_attn_backend_decode_performance_with_prefill(self):
         # Test parameters
         test_steps = 100
         prefill_batch_size = 1
         prefill_seq_len = 4096 * 2
+        model_dir = self.model_dir
+        tp_size = paddle.distributed.get_world_size()
+        quantization = {
+            "dense_quant_type": "block_wise_fp8",
+            "moe_quant_type": "block_wise_fp8",
+            "kv_cache_quant_type": "float8_e4m3fn",
+        }
+        fd_config = self.create_fd_config_from_model_path(
+            model_dir, tensor_parallel_size=tp_size, quantization=quantization
+        )
+        fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(tp_size))
+
+        # Initialize Attention Layer
+        os.environ["FD_ATTENTION_BACKEND"] = "APPEND_ATTN"
+        attn_cls = get_attention_backend()
+        attn_backend = attn_cls(
+            fd_config,
+            kv_num_heads=fd_config.model_config.num_key_value_heads // tp_size,
+            num_heads=fd_config.model_config.num_attention_heads // tp_size,
+            head_dim=fd_config.model_config.head_dim,
+            encoder_block_shape_q=64,
+            decoder_block_shape_q=16,
+        )
+        num_layers = fd_config.model_config.num_hidden_layers
+        attention_layer = [None] * num_layers
+        for i in range(num_layers):
+            attention_layer[i] = Ernie4_5_Attention(fd_config, layer_id=i, prefix="test_layer")
+            state_dict = self.create_random_attention_state_dict(fd_config, prefix="test_layer")
+            attention_layer[i].load_state_dict(state_dict)
+        cache_quant_type_str = getattr(attention_layer[0].attn, "cache_quant_type_str", "none")
+        print("===== Initialization Complete =====")
 
         forward_meta, prefill_hidden_states = self.create_forward_meta(
             batch_size=prefill_batch_size,
             seq_len=prefill_seq_len,
             mode=ForwardMode.EXTEND,
-            fd_config=self.fd_config,
-            attn_backend=self.attn_backend,
-            cache_quant_type_str=self.cache_quant_type_str,
+            fd_config=fd_config,
+            attn_backend=attn_backend,
+            cache_quant_type_str=cache_quant_type_str,
         )
-        self.attn_backend.init_attention_metadata(forward_meta)
-        self.attn_forward(forward_meta, prefill_hidden_states)
+        attn_backend.init_attention_metadata(forward_meta)
+        self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
         paddle.device.synchronize()
@@ -333,7 +347,7 @@ class TestAttentionPerformance(unittest.TestCase):
 
         for i in range(test_steps):
             start_events[i].record()
-            self.attn_forward(forward_meta, prefill_hidden_states)
+            self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
             end_events[i].record()
         paddle.device.synchronize()
@@ -357,22 +371,22 @@ class TestAttentionPerformance(unittest.TestCase):
             batch_size=decode_batch_size,
             seq_len=10 * 1024,
             mode=ForwardMode.DECODE,
-            fd_config=self.fd_config,
-            attn_backend=self.attn_backend,
-            cache_quant_type_str=self.cache_quant_type_str,
+            fd_config=fd_config,
+            attn_backend=attn_backend,
+            cache_quant_type_str=cache_quant_type_str,
         )
-        self.attn_backend.init_attention_metadata(forward_meta)
+        attn_backend.init_attention_metadata(forward_meta)
         paddle.device.synchronize()
 
         # Must warm up once first, because the preprocessing is deferred to the first layer!
-        self.attn_forward(forward_meta, hidden_states)
+        self.attn_forward(attention_layer, forward_meta, hidden_states)
 
         attn_cuda_graphs = graphs.CUDAGraph()
         attn_cuda_graphs.capture_begin()
-        self.attn_forward(forward_meta, hidden_states)
+        self.attn_forward(attention_layer, forward_meta, hidden_states)
         attn_cuda_graphs.capture_end()
@@ -393,6 +407,219 @@ class TestAttentionPerformance(unittest.TestCase):
         # p.stop()
 
+    def test_flash_attn_v3(self):
+        if self.sm_version < 89 or self.sm_version >= 100:
+            self.skipTest("Flash Attention V3 requires SM89+ but less than SM100.")
+        # Test parameters
+        test_steps = 100
+        prefill_batch_size = 1
+        prefill_seq_len = 4096 * 2
+        model_dir = self.model_dir
+        tp_size = paddle.distributed.get_world_size()
+        quantization = {
+            "dense_quant_type": "block_wise_fp8",
+            "moe_quant_type": "block_wise_fp8",
+        }
+        fd_config = self.create_fd_config_from_model_path(
+            model_dir, tensor_parallel_size=tp_size, quantization=quantization
+        )
+        fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(tp_size))
+
+        # Initialize Attention Layer
+        os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"
+        paddle.set_flags({"FLAGS_flash_attn_version": 3})
+        attn_cls = get_attention_backend()
+        attn_backend = attn_cls(
+            fd_config,
+            kv_num_heads=fd_config.model_config.num_key_value_heads // tp_size,
+            num_heads=fd_config.model_config.num_attention_heads // tp_size,
+            head_dim=fd_config.model_config.head_dim,
+            encoder_block_shape_q=64,
+            decoder_block_shape_q=16,
+        )
+        num_layers = fd_config.model_config.num_hidden_layers
+        attention_layer = [None] * num_layers
+        for i in range(num_layers):
+            attention_layer[i] = Ernie4_5_Attention(fd_config, layer_id=i, prefix="test_layer")
+            state_dict = self.create_random_attention_state_dict(fd_config, prefix="test_layer")
+            attention_layer[i].load_state_dict(state_dict)
+        cache_quant_type_str = getattr(attention_layer[0].attn, "cache_quant_type_str", "none")
+        print("===== flash_attn_v3 Initialization Complete =====")
+
+        forward_meta, prefill_hidden_states = self.create_forward_meta(
+            batch_size=prefill_batch_size,
+            seq_len=prefill_seq_len,
+            mode=ForwardMode.EXTEND,
+            fd_config=fd_config,
+            attn_backend=attn_backend,
+            cache_quant_type_str=cache_quant_type_str,
+        )
+        attn_backend.init_attention_metadata(forward_meta)
+        self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
+        paddle.device.synchronize()
+
+        start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
+        end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
+        for i in range(test_steps):
+            start_events[i].record()
+            self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
+            end_events[i].record()
+        paddle.device.synchronize()
+
+        times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
+        print(times[-5:])
+
+    def test_flash_attn_v3_with_mask(self):
+        if self.sm_version < 89 or self.sm_version >= 100:
+            self.skipTest("Flash Attention V3 requires SM89+ but less than SM100.")
+        # Test parameters
+        test_steps = 100
+        prefill_batch_size = 1
+        prefill_seq_len = 4096 * 2
+        model_dir = self.model_dir
+        tp_size = paddle.distributed.get_world_size()
+        quantization = {
+            "dense_quant_type": "block_wise_fp8",
+            "moe_quant_type": "block_wise_fp8",
+        }
+        fd_config = self.create_fd_config_from_model_path(
+            model_dir, tensor_parallel_size=tp_size, quantization=quantization
+        )
+        fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(tp_size))
+
+        # Initialize Attention Layer
+        os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"
+        paddle.set_flags({"FLAGS_flash_attn_version": 3})
+        attn_cls = get_attention_backend()
+        attn_backend = attn_cls(
+            fd_config,
+            kv_num_heads=fd_config.model_config.num_key_value_heads // tp_size,
+            num_heads=fd_config.model_config.num_attention_heads // tp_size,
+            head_dim=fd_config.model_config.head_dim,
+            encoder_block_shape_q=64,
+            decoder_block_shape_q=16,
+        )
+        num_layers = fd_config.model_config.num_hidden_layers
+        attention_layer = [None] * num_layers
+        for i in range(num_layers):
+            attention_layer[i] = Ernie4_5_Attention(fd_config, layer_id=i, prefix="test_layer")
+            state_dict = self.create_random_attention_state_dict(fd_config, prefix="test_layer")
+            attention_layer[i].load_state_dict(state_dict)
+        cache_quant_type_str = getattr(attention_layer[0].attn, "cache_quant_type_str", "none")
+        print("===== flash_attn_v3_with_mask Initialization Complete =====")
+
+        forward_meta, prefill_hidden_states = self.create_forward_meta(
+            batch_size=prefill_batch_size,
+            seq_len=prefill_seq_len,
+            mode=ForwardMode.EXTEND,
+            fd_config=fd_config,
+            attn_backend=attn_backend,
+            cache_quant_type_str=cache_quant_type_str,
+            has_attn_mask=True,
+        )
+        attn_backend.init_attention_metadata(forward_meta)
+        self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
+        paddle.device.synchronize()
+
+        start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
+        end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
+        for i in range(test_steps):
+            start_events[i].record()
+            self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
+            end_events[i].record()
+        paddle.device.synchronize()
+
+        times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
+        print(times[-5:])
+
+    def test_flash_attn_v4(self):
+        if self.sm_version < 100:
+            self.skipTest("Flash Attention V4 requires SM100+.")
+        # Test parameters
+        test_steps = 100
+        prefill_batch_size = 1
+        prefill_seq_len = 4096 * 2
+        model_dir = self.model_dir
+        tp_size = paddle.distributed.get_world_size()
+        quantization = {
+            "dense_quant_type": "block_wise_fp8",
+            "moe_quant_type": "block_wise_fp8",
+        }
+        fd_config = self.create_fd_config_from_model_path(
+            model_dir, tensor_parallel_size=tp_size, quantization=quantization
+        )
+        fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(tp_size))
+
+        # Initialize Attention Layer
+        os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"
+        attn_cls = get_attention_backend()
+        attn_backend = attn_cls(
+            fd_config,
+            kv_num_heads=fd_config.model_config.num_key_value_heads // tp_size,
+            num_heads=fd_config.model_config.num_attention_heads // tp_size,
+            head_dim=fd_config.model_config.head_dim,
+            encoder_block_shape_q=64,
+            decoder_block_shape_q=16,
+        )
+        num_layers = fd_config.model_config.num_hidden_layers
+        attention_layer = [None] * num_layers
+        for i in range(num_layers):
+            attention_layer[i] = Ernie4_5_Attention(fd_config, layer_id=i, prefix="test_layer")
+            state_dict = self.create_random_attention_state_dict(fd_config, prefix="test_layer")
+            attention_layer[i].load_state_dict(state_dict)
+        cache_quant_type_str = getattr(attention_layer[0].attn, "cache_quant_type_str", "none")
+        print("===== flash_attn_v4 Initialization Complete =====")
+
+        forward_meta, prefill_hidden_states = self.create_forward_meta(
+            batch_size=prefill_batch_size,
+            seq_len=prefill_seq_len,
+            mode=ForwardMode.EXTEND,
+            fd_config=fd_config,
+            attn_backend=attn_backend,
+            cache_quant_type_str=cache_quant_type_str,
+        )
+        attn_backend.init_attention_metadata(forward_meta)
+        self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
+        paddle.device.synchronize()
+
+        start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
+        end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
+        for i in range(test_steps):
+            start_events[i].record()
+            self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
+            end_events[i].record()
+        paddle.device.synchronize()
+
+        times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
+        print(times[-5:])
+
 
 if __name__ == "__main__":
     unittest.main()
+206
@@ -0,0 +1,206 @@
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
import paddle
from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
flash_attn_func,
)
class TestFlashAttnFunc(unittest.TestCase):
def setUp(self):
"""
Set up the testing environment before each test.
"""
paddle.set_device("gpu")
paddle.set_default_dtype("bfloat16")
prop = paddle.device.cuda.get_device_properties()
self.sm_version = prop.major * 10 + prop.minor
def test_fa3(self):
if self.sm_version < 89 or self.sm_version >= 100:
self.skipTest("Flash Attention V3 requires SM89+ but less than SM100.")
head_dim = 128
num_heads = 12
kv_num_heads = 4
seq_len = 1024
batch_size = 4
token_num = batch_size * seq_len
q = paddle.rand((token_num, num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
k = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
v = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
cu_seqlens_q = paddle.arange(0, token_num + seq_len, seq_len, dtype=paddle.int32)
cu_seqlens_k = paddle.arange(0, token_num + seq_len, seq_len, dtype=paddle.int32)
max_seqlen_q = seq_len
max_seqlen_k = seq_len
attn_mask_q = None
paddle.set_flags({"FLAGS_flash_attn_version": 3})
flash_attn_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
attn_mask_q=attn_mask_q,
causal=True,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_dim=head_dim,
version=3,
)
def test_fa3_with_mask(self):
if self.sm_version < 89 or self.sm_version >= 100:
self.skipTest("Flash Attention V3 requires SM89+ but less than SM100.")
head_dim = 128
num_heads = 12
kv_num_heads = 4
seq_len = 1024
batch_size = 4
token_num = batch_size * seq_len
q = paddle.rand((token_num, num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
k = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
v = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
cu_seqlens_q = paddle.arange(0, token_num + seq_len, seq_len, dtype=paddle.int32)
cu_seqlens_k = paddle.arange(0, token_num + seq_len, seq_len, dtype=paddle.int32)
max_seqlen_q = seq_len
max_seqlen_k = seq_len
attn_mask_q = paddle.zeros([1, 1, token_num, 4], dtype=paddle.int32)
for bid in range(batch_size):
attn_mask_q[:, :, seq_len * bid : seq_len * (bid + 1), :2] = seq_len * (bid + 1)
for kv_token_id in range(token_num):
attn_mask_q[:, :, kv_token_id, 3] = kv_token_id
paddle.set_flags({"FLAGS_flash_attn_version": 3})
flash_attn_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
attn_mask_q=attn_mask_q,
causal=True,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_dim=head_dim,
version=3,
)
def test_fa2(self):
head_dim = 128
num_heads = 12
kv_num_heads = 4
seq_len = 1024
batch_size = 4
token_num = batch_size * seq_len
q = paddle.rand((token_num, num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
k = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
v = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
cu_seqlens_q = paddle.arange(0, token_num + seq_len, seq_len, dtype=paddle.int32)
cu_seqlens_k = paddle.arange(0, token_num + seq_len, seq_len, dtype=paddle.int32)
max_seqlen_q = seq_len
max_seqlen_k = seq_len
attn_mask_q = None
paddle.set_flags({"FLAGS_flash_attn_version": 2})
flash_attn_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
attn_mask_q=attn_mask_q,
causal=True,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_dim=head_dim,
version=2,
)
def test_fa2_with_mask(self):
head_dim = 128
num_heads = 12
kv_num_heads = 4
seq_len = 1024
batch_size = 4
token_num = batch_size * seq_len
q = paddle.rand((token_num, num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
k = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
v = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
cu_seqlens_q = paddle.arange(0, token_num + seq_len, seq_len, dtype=paddle.int32)
cu_seqlens_k = paddle.arange(0, token_num + seq_len, seq_len, dtype=paddle.int32)
max_seqlen_q = seq_len
max_seqlen_k = seq_len
attn_mask_q = paddle.zeros([1, 1, token_num, 4], dtype=paddle.int32)
for bid in range(batch_size):
attn_mask_q[:, :, seq_len * bid : seq_len * (bid + 1), :2] = seq_len * (bid + 1)
for kv_token_id in range(token_num):
attn_mask_q[:, :, kv_token_id, 3] = kv_token_id
paddle.set_flags({"FLAGS_flash_attn_version": 2})
flash_attn_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
attn_mask_q=attn_mask_q,
causal=True,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_dim=head_dim,
version=2,
)
def test_fa4(self):
if self.sm_version < 100:
self.skipTest("Flash Attention V4 requires SM100+.")
head_dim = 128
num_heads = 12
kv_num_heads = 4
seq_len = 1024
batch_size = 4
token_num = batch_size * seq_len
q = paddle.rand((token_num, num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
k = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
v = paddle.rand((token_num, kv_num_heads, head_dim), dtype=paddle.float32).cast("bfloat16")
attn_mask_q = paddle.zeros([1, 1, token_num, 4], dtype=paddle.int32)
for bid in range(batch_size):
attn_mask_q[:, :, seq_len * bid : seq_len * (bid + 1), :2] = seq_len * (bid + 1)
for kv_token_id in range(token_num):
attn_mask_q[:, :, kv_token_id, 3] = kv_token_id
flash_attn_func(
q,
k,
v,
attn_mask_q=attn_mask_q,
version=4,
)
if __name__ == "__main__":
unittest.main()
+1 -1
@@ -20,6 +20,7 @@ import unittest
 os.environ.setdefault("DG_NVCC_OVERRIDE_CPP_STANDARD", "17")
 
 import paddle
+from utils import OpPerformanceTester
 
 from fastdeploy.config import (
     CacheConfig,
@@ -38,7 +39,6 @@ from fastdeploy.model_executor.layers.quantization.weight_only import (
     WINT8Config,
 )
 from fastdeploy.scheduler import SchedulerConfig
-from tests.utils import OpPerformanceTester
 
 paddle.set_default_dtype("bfloat16")
 paddle.seed(1024)
+1 -1
@@ -21,6 +21,7 @@ import unittest
 
 import paddle
 from paddle.distributed import fleet
+from utils import OpPerformanceTester
 
 from fastdeploy.config import (
     CacheConfig,
@@ -35,7 +36,6 @@ from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config
 from fastdeploy.scheduler import SchedulerConfig
 from fastdeploy.worker.worker_process import init_distributed_environment
-from tests.utils import OpPerformanceTester
 
 paddle.set_default_dtype("bfloat16")
+3 -3
@@ -6,6 +6,9 @@ import unittest
 import paddle
 from paddle.distributed import fleet
 
+# from fastdeploy.worker.worker_process import init_distributed_environment
+from utils import OpPerformanceTester
+
 from fastdeploy.config import (
     CacheConfig,
     FDConfig,
@@ -19,9 +22,6 @@ from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config
 from fastdeploy.scheduler import SchedulerConfig
 
-# from fastdeploy.worker.worker_process import init_distributed_environment
-from tests.utils import OpPerformanceTester
-
 paddle.set_default_dtype("bfloat16")
+2 -2
@@ -26,9 +26,9 @@ import paddle.distributed.fleet as fleet
 from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
 from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPForCausalLM
 
-ROOT = Path(__file__).resolve().parents[2]
+ROOT = Path(__file__).resolve().parents[1]
 sys.path.insert(0, str(ROOT))
-from tests.utils import get_default_test_fd_config
+from utils import get_default_test_fd_config
 
 strategy = fleet.DistributedStrategy()
 fleet.init(strategy=strategy)
+1 -1
@@ -24,7 +24,7 @@ if project_root not in sys.path:
 
 os.environ["FD_USE_MACHETE"] = "0"
 
-from tests.model_loader.utils import (
+from model_loader.utils import (
     check_tokens_id_and_text_close,
     form_model_get_output_topp0,
     get_paddle_model_path,
+1 -1
@@ -24,7 +24,7 @@ project_root = os.path.abspath(os.path.join(current_dir, ".."))
 if project_root not in sys.path:
     sys.path.insert(0, project_root)
 
-from tests.model_loader.utils import (
+from model_loader.utils import (
     form_model_get_output_topp0,
     get_torch_model_path,
     run_with_timeout,
+1 -1
@@ -24,7 +24,7 @@ if project_root not in sys.path:
 
 os.environ["FD_USE_MACHETE"] = "0"
 
-from tests.model_loader.utils import (
+from model_loader.utils import (
     calculate_diff_rate,
     form_model_get_output_topp0,
     get_torch_model_path,
+132
@@ -19,6 +19,10 @@ import unittest
 import numpy as np
 import paddle
 
+from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
+    flash_attn_func,
+)
+from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q
 from fastdeploy.model_executor.ops.gpu import flash_mask_attention
 
@@ -31,6 +35,8 @@ class TestFlashMaskAttention(unittest.TestCase):
         self.k_len = 1024
         self.head_dim = 128
         np.random.seed(self.q_len)
+        prop = paddle.device.cuda.get_device_properties()
+        self.sm_version = prop.major * 10 + prop.minor
 
     def naive_attn(self, q_input, k_input, v_input, mask):
@@ -101,6 +107,132 @@ class TestFlashMaskAttention(unittest.TestCase):
         max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
         self.assertLessEqual(max_diff, 0.05)
 
+    def test_fa4(self):
+        if self.sm_version < 100:
+            self.skipTest("Flash Attention V4 requires SM100+.")
+        q_input = paddle.randn([self.q_len, self.num_head * self.head_dim], dtype="bfloat16")
+        k_input = paddle.randn([self.q_len + self.k_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
+        v_input = paddle.randn(k_input.shape, dtype="bfloat16")
+        mask_start = paddle.zeros([self.q_len], dtype="int32")
+        mask_end = paddle.zeros([self.q_len], dtype="int32") + self.q_len + self.k_len
+        mask = paddle.stack([mask_start, mask_end], axis=-1).reshape([-1])
+        naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
+
+        bsz = self.bsz
+        cu_seq_q = paddle.arange(bsz + 1) * self.q_len
+        cu_seq_k = paddle.arange(bsz + 1) * (self.q_len + self.k_len)
+        cu_seq_q = cu_seq_q.astype("int32")
+        cu_seq_k = cu_seq_k.astype("int32")
+        attn_mask_q = get_attn_mask_q(
+            cu_seqlens_q=cu_seq_q,
+            cu_seqlens_k=cu_seq_k,
+            attn_mask_kv=mask,
+            kv_token_num=self.q_len + self.k_len,
+        )
+        paddle_attn_out = flash_attn_func(
+            q_input,
+            k_input,
+            v_input,
+            attn_mask_q=attn_mask_q,
+            num_heads=self.num_head,
+            kv_num_heads=self.num_kv_head,
+            head_dim=self.head_dim,
+            version=4,
+        )[0].reshape([self.q_len, self.num_head * self.head_dim])
+        max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
+        self.assertLessEqual(max_diff, 0.05)
+
+    def test_fa3_with_mask(self):
+        if self.sm_version < 89 or self.sm_version >= 100:
+            self.skipTest("Flash Attention V3 requires SM89+ but less than SM100.")
+        q_input = paddle.randn([self.q_len, self.num_head * self.head_dim], dtype="bfloat16")
+        k_input = paddle.randn([self.q_len + self.k_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
+        v_input = paddle.randn(k_input.shape, dtype="bfloat16")
+        mask_start = paddle.zeros([self.q_len], dtype="int32")
+        mask_end = paddle.zeros([self.q_len], dtype="int32") + self.q_len + self.k_len
+        mask = paddle.stack([mask_start, mask_end], axis=-1).reshape([-1])
+        naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
+
+        bsz = self.bsz
+        cu_seq_q = paddle.arange(bsz + 1) * self.q_len
+        cu_seq_k = paddle.arange(bsz + 1) * (self.q_len + self.k_len)
+        cu_seq_q = cu_seq_q.astype("int32")
+        cu_seq_k = cu_seq_k.astype("int32")
+        attn_mask_q = get_attn_mask_q(
+            cu_seqlens_q=cu_seq_q,
+            cu_seqlens_k=cu_seq_k,
+            attn_mask_kv=mask,
+            kv_token_num=self.q_len + self.k_len,
+        )
+        paddle.set_flags({"FLAGS_flash_attn_version": 3})
+        paddle_attn_out = flash_attn_func(
+            q_input,
+            k_input,
+            v_input,
+            attn_mask_q=attn_mask_q,
+            num_heads=self.num_head,
+            kv_num_heads=self.num_kv_head,
+            head_dim=self.head_dim,
+            version=3,
+        )[0].reshape([self.q_len, self.num_head * self.head_dim])
+        max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
+        self.assertLessEqual(max_diff, 0.05)
+
+    def test_fa2_with_mask(self):
+        q_input = paddle.randn([self.q_len, self.num_head * self.head_dim], dtype="bfloat16")
+        k_input = paddle.randn([self.q_len + self.k_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
+        v_input = paddle.randn(k_input.shape, dtype="bfloat16")
+        mask_start = paddle.zeros([self.q_len], dtype="int32")
+        mask_end = paddle.zeros([self.q_len], dtype="int32") + self.q_len + self.k_len
+        mask = paddle.stack([mask_start, mask_end], axis=-1).reshape([-1])
+        naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
+
+        bsz = self.bsz
+        cu_seq_q = paddle.arange(bsz + 1) * self.q_len
+        cu_seq_k = paddle.arange(bsz + 1) * (self.q_len + self.k_len)
+        cu_seq_q = cu_seq_q.astype("int32")
+        cu_seq_k = cu_seq_k.astype("int32")
+        attn_mask_q = get_attn_mask_q(
+            cu_seqlens_q=cu_seq_q,
+            cu_seqlens_k=cu_seq_k,
+            attn_mask_kv=mask,
+            kv_token_num=self.q_len + self.k_len,
+        )
+        paddle.set_flags({"FLAGS_flash_attn_version": 2})
+        paddle_attn_out = flash_attn_func(
+            q_input,
+            k_input,
+            v_input,
+            attn_mask_q=attn_mask_q,
+            num_heads=self.num_head,
+            kv_num_heads=self.num_kv_head,
+            head_dim=self.head_dim,
+            version=2,
+        )[0].reshape([self.q_len, self.num_head * self.head_dim])
+        max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
+        self.assertLessEqual(max_diff, 0.05)
+
 
 if __name__ == "__main__":
     unittest.main()
+1 -1
@@ -18,9 +18,9 @@ import unittest
 
 import numpy as np
 import paddle
+from utils import OpPerformanceTester
 
 from fastdeploy.model_executor.ops.triton_ops import qk_rmsnorm_fused
-from tests.utils import OpPerformanceTester
 
 paddle.set_default_dtype("bfloat16")
 paddle.seed(99)
+1 -1
@@ -36,7 +36,7 @@ project_root = os.path.abspath(os.path.join(current_dir, ".."))
 if project_root not in sys.path:
     sys.path.insert(0, project_root)
 
-from tests.model_loader.utils import get_torch_model_path
+from model_loader.utils import get_torch_model_path
 
 test_model_configs = {
     "Qwen3-0.6B": {
+1 -1
@@ -26,7 +26,7 @@ from fastdeploy.model_executor.layers.quantization.kv_cache import (
 )
 
 sys.path.append("../")
-from tests.utils import get_default_test_fd_config
+from utils import get_default_test_fd_config
 
 
 class MockLayer(nn.Layer):
+1 -1
@@ -18,11 +18,11 @@ import unittest
 from unittest.mock import Mock, patch
 
 import paddle
+from utils import FakeModelConfig, get_default_test_fd_config
 
 from fastdeploy.config import SpeculativeConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.spec_decode.mtp import MTPProposer
-from tests.utils import FakeModelConfig, get_default_test_fd_config
 
 
 class TestMTPProposer(unittest.TestCase):