mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] Support FA2/FA3/FA4 with attn_mask_q (#6354)
* support FA4 sm100 * flash attn backend support mask * flash attn backend run flashmask correct * add test for flash_attn_backend and flash_attn_func * check * add test for fa4 * requirements.txt add fa4 whl * check test on sm100 * fix CI conflict * add enable_torch_proxy for flash_mask * lazy import fa4 * check * fix tests import * check test_load_mpt import
This commit is contained in:
@@ -1116,6 +1116,12 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
|||||||
int64_t think_end_id,
|
int64_t think_end_id,
|
||||||
int64_t line_break_id);
|
int64_t line_break_id);
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> get_attn_mask_q(
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& cu_seqlens_k,
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask_kv,
|
||||||
|
const int kv_token_num);
|
||||||
|
|
||||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||||
m.def("get_expert_token_num",
|
m.def("get_expert_token_num",
|
||||||
&GetExpertTokenNum,
|
&GetExpertTokenNum,
|
||||||
@@ -1722,6 +1728,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
&ReasoningPhaseTokenConstraint,
|
&ReasoningPhaseTokenConstraint,
|
||||||
"reasoning_phase_token_constraint function");
|
"reasoning_phase_token_constraint function");
|
||||||
|
|
||||||
|
m.def("get_attn_mask_q", &get_attn_mask_q, "get_attn_mask_q function");
|
||||||
|
|
||||||
m.def("get_stop", &GetStop, "get_stop function");
|
m.def("get_stop", &GetStop, "get_stop function");
|
||||||
|
|
||||||
m.def("set_stop", &SetStop, "set_stop function");
|
m.def("set_stop", &SetStop, "set_stop function");
|
||||||
|
|||||||
@@ -0,0 +1,137 @@
|
|||||||
|
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
|
#include "paddle/extension.h"
|
||||||
|
|
||||||
|
// Fills, for every key/value token, a row of four int32 values in
// `startend_row_indices_ptr`:
//   [0] end of the owning batch's query range (cu_seqlens_q[batch + 1])
//   [1] cu_seqlens_q[max_batch_size]
//   [2] 0
//   [3] index of the first query row of the owning batch whose mask end covers
//       this key, or the end of the batch's query range if no row does.
// NOTE(review): the four values presumably map to FlashMask's
// (LTS, LTE, UTS, UTE) start/end row indices -- confirm against the consumer.
//
// Parameters:
//   startend_row_indices_ptr  output, kv_token_num * 4 ints, written as one
//                             int4 per key token (must be 16-byte aligned).
//   attn_mask_kv_ptr          optional; when non-null, element [q * 2 + 1] is
//                             read as the exclusive key end for query row q.
//                             When null, a causal end is derived from the
//                             query/key lengths of the batch.
//   cu_seqlens_q/cu_seqlens_k cumulative query/key lengths; entries
//                             [0, max_batch_size] are read.
//   kv_token_num              number of key tokens (rows to produce).
//   max_batch_size            number of batches described by cu_seqlens_*.
//
// All index arithmetic is done in signed integers on purpose: the mask end can
// be negative (e.g. an end offset of 0 yields -1), and comparing it against an
// unsigned key index would silently turn it into a huge positive value.
__global__ void get_attn_mask_q_kernel(
    int* __restrict__ startend_row_indices_ptr,
    const int* attn_mask_kv_ptr,
    const int* cu_seqlens_q,
    const int* cu_seqlens_k,
    const int kv_token_num,
    const int max_batch_size) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
#endif
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t global_k_idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       global_k_idx < kv_token_num;
       global_k_idx += stride) {
    // Find the batch whose key range contains this token.
    int batch_id = 0;
    for (int i = 0; i < max_batch_size; ++i) {
      if (global_k_idx >= cu_seqlens_k[i] &&
          global_k_idx < cu_seqlens_k[i + 1]) {
        batch_id = i;
        break;
      }
    }

    const int q_start = cu_seqlens_q[batch_id];
    const int q_end = cu_seqlens_q[batch_id + 1];
    const int q_len = q_end - q_start;
    const int kv_start = cu_seqlens_k[batch_id];
    const int kv_len = cu_seqlens_k[batch_id + 1] - kv_start;
    // Position of this key inside its own batch.
    const int local_k_idx = static_cast<int>(global_k_idx) - kv_start;

    // Scan the batch's query rows in order; the first row whose (inclusive)
    // mask end reaches this key wins, so the loop can stop there.
    int first_visible_q = q_end;
    for (int q_idx = q_start; q_idx < q_end; ++q_idx) {
      // const int append_mask_k_start = attn_mask_kv_ptr ?
      //     attn_mask_kv_ptr[q_idx * 2 + 0] : 0;
      const int mask_k_end = attn_mask_kv_ptr
                                 ? attn_mask_kv_ptr[q_idx * 2 + 1] - 1
                                 : (q_idx - q_start) + kv_len - q_len;
      if (local_k_idx <= mask_k_end) {
        first_visible_q = q_idx;
        break;
      }
    }

    // Build the row as a real int4 so the 16-byte store never reads from an
    // under-aligned local array.
    const int4 row =
        make_int4(q_end, cu_seqlens_q[max_batch_size], 0, first_visible_q);
    reinterpret_cast<int4*>(startend_row_indices_ptr)[global_k_idx] = row;
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}
|
||||||
|
|
||||||
|
// Host entry point of the `get_attn_mask_q` custom op.
//
// Allocates an INT32 tensor of shape {1, 1, kv_token_num, 4} on the same place
// as `cu_seqlens_k` and launches `get_attn_mask_q_kernel` on
// `cu_seqlens_k.stream()` to fill it.
//
// Parameters:
//   cu_seqlens_q  cumulative query lengths (int data is read).
//   cu_seqlens_k  cumulative key lengths; its first dimension minus one is
//                 used as the batch count.
//   attn_mask_kv  optional per-query mask offsets; a null pointer is passed to
//                 the kernel when absent.
//   kv_token_num  number of key tokens, i.e. rows of the output.
//
// Returns a one-element vector holding the output tensor.
std::vector<paddle::Tensor> get_attn_mask_q(
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& cu_seqlens_k,
    const paddle::optional<paddle::Tensor>& attn_mask_kv,
    const int kv_token_num) {
  paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
      {1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place());

  // Nothing to compute. Launching with a grid size of 0 would be an invalid
  // launch configuration, so return the (empty) tensor directly.
  if (kv_token_num <= 0) {
    return {attn_mask_startend_row_indices};
  }

  const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
  constexpr int block_size = 512;
  const int grid_size = div_up(kv_token_num, block_size);
  const int* attn_mask_kv_ptr =
      attn_mask_kv ? attn_mask_kv.get().data<int>() : nullptr;

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  get_attn_mask_q_kernel<<<grid_size, block_size, 0, cu_seqlens_k.stream()>>>(
      attn_mask_startend_row_indices.data<int>(),
      attn_mask_kv_ptr,
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
      kv_token_num,
      max_batch_size);
#else
  launchWithPdlWhenEnabled(get_attn_mask_q_kernel,
                           grid_size,
                           block_size,
                           0,
                           cu_seqlens_k.stream(),
                           attn_mask_startend_row_indices.data<int>(),
                           attn_mask_kv_ptr,
                           cu_seqlens_q.data<int>(),
                           cu_seqlens_k.data<int>(),
                           kv_token_num,
                           max_batch_size);
#endif
  return {attn_mask_startend_row_indices};
}
|
||||||
|
|
||||||
|
std::vector<paddle::DataType> GetAttnMaskQInferDtype(
|
||||||
|
const paddle::DataType& cu_seqlens_q_dtype,
|
||||||
|
const paddle::DataType& cu_seqlens_k_dtype,
|
||||||
|
const paddle::optional<paddle::DataType>& attn_mask_kv_dtype) {
|
||||||
|
return {paddle::DataType::INT32};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
|
||||||
|
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||||
|
const std::vector<int64_t>& cu_seqlens_k_shape,
|
||||||
|
const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
|
||||||
|
const int kv_token_num) {
|
||||||
|
return {{1, 1, kv_token_num, 4}};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers the `get_attn_mask_q` static custom op.
// Inputs:  cu_seqlens_q, cu_seqlens_k and an optional `attn_mask_offsets`
//          tensor (received by the kernel function as `attn_mask_kv`).
// Output:  attn_mask_q.
// Attr:    kv_token_num (int), forwarded to the kernel and shape inference.
PD_BUILD_STATIC_OP(get_attn_mask_q)
|
||||||
|
.Inputs({"cu_seqlens_q",
|
||||||
|
"cu_seqlens_k",
|
||||||
|
paddle::Optional("attn_mask_offsets")})
|
||||||
|
.Outputs({"attn_mask_q"})
|
||||||
|
.Attrs({"kv_token_num: int"})
|
||||||
|
.SetKernelFn(PD_KERNEL(get_attn_mask_q))
|
||||||
|
.SetInferShapeFn(PD_INFER_SHAPE(GetAttnMaskQInferShape))
|
||||||
|
.SetInferDtypeFn(PD_INFER_DTYPE(GetAttnMaskQInferDtype));
|
||||||
@@ -313,6 +313,7 @@ elif paddle.is_compiled_with_cuda():
|
|||||||
"gpu_ops/fused_neox_rope_embedding.cu",
|
"gpu_ops/fused_neox_rope_embedding.cu",
|
||||||
"gpu_ops/gelu_tanh.cu",
|
"gpu_ops/gelu_tanh.cu",
|
||||||
"gpu_ops/reasoning_phase_token_constraint.cu",
|
"gpu_ops/reasoning_phase_token_constraint.cu",
|
||||||
|
"gpu_ops/get_attn_mask_q.cu",
|
||||||
]
|
]
|
||||||
|
|
||||||
# pd_disaggregation
|
# pd_disaggregation
|
||||||
|
|||||||
@@ -21,12 +21,18 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from paddle.nn.functional.flash_attention import flash_attn_unpadded
|
from paddle.nn.functional.flash_attention import flash_attn_unpadded
|
||||||
|
from paddleformers.utils.log import logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
||||||
except:
|
except:
|
||||||
flash_attention_v3_varlen = None
|
flash_attention_v3_varlen = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from paddle.nn.functional.flash_attention import flashmask_attention
|
||||||
|
except:
|
||||||
|
flashmask_attention = None
|
||||||
|
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||||
@@ -35,6 +41,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
|||||||
)
|
)
|
||||||
from fastdeploy.model_executor.layers.attention.ops import (
|
from fastdeploy.model_executor.layers.attention.ops import (
|
||||||
append_attention,
|
append_attention,
|
||||||
|
get_attn_mask_q,
|
||||||
get_block_shape_and_split_kv_block,
|
get_block_shape_and_split_kv_block,
|
||||||
gqa_rope_write_cache,
|
gqa_rope_write_cache,
|
||||||
init_kv_signal_per_query,
|
init_kv_signal_per_query,
|
||||||
@@ -43,12 +50,16 @@ from fastdeploy.model_executor.layers.attention.ops import (
|
|||||||
pre_cache_len_concat,
|
pre_cache_len_concat,
|
||||||
)
|
)
|
||||||
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
||||||
|
from fastdeploy.model_executor.utils import get_sm_version
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
|
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
|
paddle.compat.enable_torch_proxy(scope={"flash_mask"})
|
||||||
|
flashmask_attention_v4 = None
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
|
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
|
||||||
else:
|
else:
|
||||||
@@ -56,6 +67,123 @@ else:
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
FLASH_ATTN_VERSION = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_flash_attn_version(fa_version: int = None):
    """
    Select the Flash Attention version and store it in ``FLASH_ATTN_VERSION``.

    Args:
        fa_version: When given, this version is forced. Forcing version 4 also
            attempts the lazy import of the FA4 kernel so that
            ``flash_attn_func`` can actually use it. When ``None``, the version
            is chosen from the SM version of the current device.

    Behavior on CUDA when ``fa_version`` is ``None``:
        * SM >= 100 and ``flash_mask.cute.interface`` importable -> version 4,
          and ``flashmask_attention_v4`` is bound to the imported function.
        * otherwise, if no version has been chosen yet: version 3 when the SM
          version is >= 89 and Paddle reports a CUDA arch >= 89, else version 2.

    On non-CUDA platforms nothing is selected and a message is logged.
    """
    global FLASH_ATTN_VERSION, flashmask_attention_v4

    def _import_fa4():
        # FA4 is an optional dependency; it is imported lazily and a missing
        # package simply means "not available".
        try:
            from flash_mask.cute.interface import flashmask_attention as fa4
        except ImportError:
            return None
        return fa4

    if not current_platform.is_cuda():
        logger.info("Only support CUDA version flash attention.")
        return

    if fa_version is not None:
        if fa_version == 4 and flashmask_attention_v4 is None:
            # Without this, a forced V4 would leave the kernel unbound and
            # flash_attn_func would fail even though the package is installed.
            flashmask_attention_v4 = _import_fa4()
        FLASH_ATTN_VERSION = fa_version
        logger.info(f"Force use Flash Attention V{fa_version}.")
        return

    sm_version = get_sm_version()
    if sm_version >= 100:
        fa4 = _import_fa4()
        if fa4 is not None:
            flashmask_attention_v4 = fa4
            FLASH_ATTN_VERSION = 4
            logger.info("The current platform supports Flash Attention V4.")
        else:
            logger.info("flash_mask is not installed, Flash Attention V4 is unavailable.")

    if FLASH_ATTN_VERSION is None:
        if sm_version >= 89 and any(num >= 89 for num in paddle.version.cuda_archs()):
            FLASH_ATTN_VERSION = 3
            logger.info("The current platform supports Flash Attention V3.")
        else:
            FLASH_ATTN_VERSION = 2
            logger.info("The current platform only support Flash Attention V2.")
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_func(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: Optional[paddle.Tensor] = None,
    cu_seqlens_k: Optional[paddle.Tensor] = None,
    max_seqlen_q: Optional[paddle.Tensor] = None,
    max_seqlen_k: Optional[paddle.Tensor] = None,
    attn_mask_q: Optional[paddle.Tensor] = None,
    causal: bool = True,
    num_heads: int = None,
    kv_num_heads: int = None,
    head_dim: int = 128,
    version: Optional[int] = None,
):
    """
    Dispatch an attention call to the FA2 / FA3 / FA4 implementation.

    Args:
        q, k, v: Query / key / value tensors. On the masked paths they are
            reshaped to ``[1, -1, heads, head_dim]`` before the call.
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k: Variable-length
            metadata, forwarded only on the unmasked FA2 / FA3 paths.
        attn_mask_q: Optional start/end row indices. Required for version 4;
            when given with another version, ``flashmask_attention`` is used.
        causal: Forwarded on the unmasked paths; masked paths pass
            ``causal=False`` and rely on ``attn_mask_q``.
        num_heads, kv_num_heads: Head counts, required whenever ``attn_mask_q``
            is used (they drive the reshape).
        head_dim: Head dimension; also determines the FA2 softmax scale.
        version: Explicit version; ``None`` uses ``FLASH_ATTN_VERSION``,
            initializing it on first use.

    Returns:
        Whatever the selected backend returns, unchanged.

    Raises:
        AssertionError: If the selected backend could not be imported, or if
            arguments required by the selected path are missing.
    """
    if version is None:
        if FLASH_ATTN_VERSION is None:
            init_flash_attn_version()
        version = FLASH_ATTN_VERSION

    if version == 4:
        assert (
            flashmask_attention_v4 is not None
        ), "Cannot import flashmask_attention from flash_mask.cute.interface, please install it first"
        assert attn_mask_q is not None, "FA4 requires attn_mask_q"
        assert num_heads is not None
        assert kv_num_heads is not None
        flag_name = "FLAGS_flash_attn_version"
        original_flash_attn_version = paddle.base.framework.get_flags([flag_name])[flag_name]
        with paddle.no_grad():
            try:
                # The flag is switched only for the duration of the call and is
                # always restored, even if the kernel raises.
                paddle.set_flags({flag_name: 4})
                out = flashmask_attention_v4(
                    q.reshape([1, -1, num_heads, head_dim]),
                    k.reshape([1, -1, kv_num_heads, head_dim]),
                    v.reshape([1, -1, kv_num_heads, head_dim]),
                    startend_row_indices=attn_mask_q,
                    causal=False,
                    return_softmax_lse=True,
                    training=True,
                )
            finally:
                paddle.set_flags({flag_name: original_flash_attn_version})
        return out

    if attn_mask_q is not None:
        assert flashmask_attention is not None, "flashmask_attention is not available in this Paddle build"
        # The reshape below would otherwise fail with an opaque error.
        assert num_heads is not None, "num_heads is required when attn_mask_q is given"
        assert kv_num_heads is not None, "kv_num_heads is required when attn_mask_q is given"
        return flashmask_attention(
            q.reshape([1, -1, num_heads, head_dim]),
            k.reshape([1, -1, kv_num_heads, head_dim]),
            v.reshape([1, -1, kv_num_heads, head_dim]),
            startend_row_indices=attn_mask_q,
            causal=False,
        )

    if version == 3:
        # The module-level import is optional; fail clearly instead of with
        # "'NoneType' object is not callable".
        assert flash_attention_v3_varlen is not None, "flash_attention_v3_varlen is not available in this Paddle build"
        return flash_attention_v3_varlen(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            causal=causal,
        )

    return flash_attn_unpadded(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        causal=causal,
        scale=head_dim**-0.5,
        training=False,
    )
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlashAttentionMetadata(AttentionMetadata):
|
class FlashAttentionMetadata(AttentionMetadata):
|
||||||
@@ -79,6 +207,8 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
|
|
||||||
max_len_tensor_cpu_decoder: paddle.Tensor = None
|
max_len_tensor_cpu_decoder: paddle.Tensor = None
|
||||||
|
|
||||||
|
attn_mask_q: paddle.Tensor = None
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionBackend(AttentionBackend):
|
class FlashAttentionBackend(AttentionBackend):
|
||||||
"""
|
"""
|
||||||
@@ -87,7 +217,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
__infer_dynamic_dims_fields__ = ["attention_metadata"]
|
__infer_dynamic_dims_fields__ = ["attention_metadata"]
|
||||||
attention_metadata: FlashAttentionMetadata
|
attention_metadata: FlashAttentionMetadata
|
||||||
flash_attn_func: callable = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -127,22 +256,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
self.rank, self.device_id = init_rank_and_device_id(fd_config)
|
self.rank, self.device_id = init_rank_and_device_id(fd_config)
|
||||||
|
|
||||||
if self.flash_attn_func is None:
|
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
|
||||||
prop = paddle.device.cuda.get_device_properties()
|
fd_config.model_config, "use_3d_rope", False
|
||||||
cc = prop.major * 10 + prop.minor
|
)
|
||||||
is_current_sm_supported = cc >= 90
|
|
||||||
is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
|
|
||||||
if is_current_sm_supported and is_paddle_supported:
|
|
||||||
self.flash_attn_func = flash_attention_v3_varlen
|
|
||||||
print("The current platform supports Flash Attention V3.")
|
|
||||||
self.flash_attn_kwargs = {}
|
|
||||||
else:
|
|
||||||
self.flash_attn_func = flash_attn_unpadded
|
|
||||||
self.flash_attn_kwargs = {"scale": self.head_dim**-0.5, "training": False}
|
|
||||||
print(
|
|
||||||
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
|
|
||||||
)
|
|
||||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
|
|
||||||
# Note(ZKK): here must be consistent with append_attn_backend.py
|
# Note(ZKK): here must be consistent with append_attn_backend.py
|
||||||
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024))
|
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024))
|
||||||
self.zero_seq_enc_lens_for_decode = paddle.zeros(
|
self.zero_seq_enc_lens_for_decode = paddle.zeros(
|
||||||
@@ -255,6 +371,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
forward_meta.max_len_tensor_cpu[2],
|
forward_meta.max_len_tensor_cpu[2],
|
||||||
self.block_size,
|
self.block_size,
|
||||||
)
|
)
|
||||||
|
if forward_meta.attn_mask_offsets is not None:
|
||||||
|
metadata.attn_mask_q = get_attn_mask_q(
|
||||||
|
cu_seqlens_q=forward_meta.cu_seqlens_q,
|
||||||
|
cu_seqlens_k=metadata.cu_seqlens_k,
|
||||||
|
attn_mask_kv=forward_meta.attn_mask_offsets,
|
||||||
|
kv_token_num=metadata.kv_token_num_cpu[0].item(),
|
||||||
|
)
|
||||||
|
|
||||||
use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
|
use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
|
||||||
|
|
||||||
@@ -294,16 +417,19 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.rope_3d,
|
self.rope_3d,
|
||||||
)
|
)
|
||||||
|
|
||||||
res_encoder = self.flash_attn_func(
|
res_encoder = flash_attn_func(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
forward_meta.cu_seqlens_q,
|
forward_meta.cu_seqlens_q[: metadata.cu_seqlens_k.shape[0]],
|
||||||
metadata.cu_seqlens_k,
|
metadata.cu_seqlens_k,
|
||||||
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
|
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
|
||||||
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
|
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
|
||||||
|
attn_mask_q=metadata.attn_mask_q,
|
||||||
causal=self.causal,
|
causal=self.causal,
|
||||||
**self.flash_attn_kwargs,
|
num_heads=self.num_heads,
|
||||||
|
kv_num_heads=self.kv_num_heads,
|
||||||
|
head_dim=self.head_dim,
|
||||||
)[0].reshape([-1, self.attn_outputsize_tp])
|
)[0].reshape([-1, self.attn_outputsize_tp])
|
||||||
|
|
||||||
res_decoder = append_attention(
|
res_decoder = append_attention(
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
from .append_attention import append_attention, append_attention_with_output
|
from .append_attention import append_attention, append_attention_with_output
|
||||||
from .flash_mask_attention import flash_mask_attention
|
from .flash_mask_attention import flash_mask_attention
|
||||||
|
from .get_attn_mask_q import get_attn_mask_q
|
||||||
from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
|
from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
|
||||||
from .gqa_rope_write_cache import gqa_rope_write_cache
|
from .gqa_rope_write_cache import gqa_rope_write_cache
|
||||||
from .init_kv_signal_per_query import init_kv_signal_per_query
|
from .init_kv_signal_per_query import init_kv_signal_per_query
|
||||||
@@ -33,4 +34,5 @@ __all__ = [
|
|||||||
"pre_cache_len_concat",
|
"pre_cache_len_concat",
|
||||||
"init_kv_signal_per_query",
|
"init_kv_signal_per_query",
|
||||||
"flash_mask_attention",
|
"flash_mask_attention",
|
||||||
|
"get_attn_mask_q",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
|
get_attn_mask_q as get_attn_mask_q_cuda,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_attn_mask_q(
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_k: paddle.Tensor,
    attn_mask_kv: Optional[paddle.Tensor] = None,
    kv_token_num: int = 0,
):
    """
    Platform dispatcher for the ``get_attn_mask_q`` custom op.

    Forwards all four arguments positionally to the CUDA implementation and
    returns its result. Raises ``NotImplementedError`` on any non-CUDA
    platform.
    """
    if not current_platform.is_cuda():
        raise NotImplementedError

    return get_attn_mask_q_cuda(
        cu_seqlens_q,
        cu_seqlens_k,
        attn_mask_kv,
        kv_token_num,
    )
|
||||||
@@ -19,6 +19,7 @@ import re
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from functools import cache
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
@@ -547,3 +548,11 @@ def rename_offline_ckpt_suffix_to_fd_suffix(
|
|||||||
return loaded_weight_name
|
return loaded_weight_name
|
||||||
|
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
@cache
def get_sm_version():
    """
    Return the compute capability of the current CUDA device as
    ``major * 10 + minor``, or 0 when CUDA is not available.

    The result is memoized by ``functools.cache``.
    """
    if not paddle.cuda.is_available():
        return 0
    props = paddle.device.cuda.get_device_properties()
    return 10 * props.major + props.minor
|
||||||
|
|||||||
@@ -47,3 +47,4 @@ aistudio_sdk
|
|||||||
p2pstore
|
p2pstore
|
||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
flashinfer-python-paddle
|
flashinfer-python-paddle
|
||||||
|
flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl
|
||||||
|
|||||||
@@ -66,55 +66,14 @@ class TestAttentionPerformance(unittest.TestCase):
|
|||||||
print("Setting up test environment...")
|
print("Setting up test environment...")
|
||||||
paddle.set_device("gpu")
|
paddle.set_device("gpu")
|
||||||
paddle.set_default_dtype("bfloat16")
|
paddle.set_default_dtype("bfloat16")
|
||||||
|
prop = paddle.device.cuda.get_device_properties()
|
||||||
|
self.sm_version = prop.major * 10 + prop.minor
|
||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
|
self.model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
|
||||||
self.model_dir = self.create_model_config_json()
|
self.create_model_config_json(self.model_dir)
|
||||||
tp_size = paddle.distributed.get_world_size()
|
|
||||||
|
|
||||||
self.fd_config = self.create_fd_config_from_model_path(self.model_dir, tensor_parallel_size=tp_size)
|
|
||||||
self.fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(tp_size))
|
|
||||||
|
|
||||||
# Initialize Attention Layer
|
|
||||||
attn_cls = get_attention_backend()
|
|
||||||
self.attn_backend = attn_cls(
|
|
||||||
self.fd_config,
|
|
||||||
kv_num_heads=self.fd_config.model_config.num_key_value_heads // tp_size,
|
|
||||||
num_heads=self.fd_config.model_config.num_attention_heads // tp_size,
|
|
||||||
head_dim=self.fd_config.model_config.head_dim,
|
|
||||||
encoder_block_shape_q=64,
|
|
||||||
decoder_block_shape_q=16,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_layers = self.fd_config.model_config.num_hidden_layers
|
|
||||||
self.attention_layer = [None] * num_layers
|
|
||||||
for i in range(num_layers):
|
|
||||||
self.attention_layer[i] = Ernie4_5_Attention(self.fd_config, layer_id=i, prefix="test_layer")
|
|
||||||
state_dict = self.create_random_attention_state_dict(self.fd_config, prefix="test_layer")
|
|
||||||
self.attention_layer[i].load_state_dict(state_dict)
|
|
||||||
|
|
||||||
def attn_forward(forward_meta, hidden_states):
|
|
||||||
for i in range(num_layers):
|
|
||||||
hidden_states = self.attention_layer[i](forward_meta, hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
self.attn_forward = attn_forward
|
|
||||||
|
|
||||||
self.cache_quant_type_str = getattr(self.attention_layer[0].attn, "cache_quant_type_str", "none")
|
|
||||||
|
|
||||||
print("===== Initialization Complete =====")
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
"""
|
|
||||||
Clean up the environment after each test.
|
|
||||||
"""
|
|
||||||
print("\nTearing down test environment...")
|
|
||||||
if os.path.exists(self.model_dir):
|
|
||||||
shutil.rmtree(self.model_dir)
|
|
||||||
print(f"Successfully removed temporary directory: {self.model_dir}")
|
|
||||||
|
|
||||||
# region Helper Functions
|
# region Helper Functions
|
||||||
def create_model_config_json(self) -> str:
|
def create_model_config_json(self, model_dir) -> str:
|
||||||
"""
|
"""
|
||||||
Creates a temporary directory and writes the model configuration to a 'config.json' file.
|
Creates a temporary directory and writes the model configuration to a 'config.json' file.
|
||||||
"""
|
"""
|
||||||
@@ -129,14 +88,13 @@ class TestAttentionPerformance(unittest.TestCase):
|
|||||||
"num_key_value_heads": 4,
|
"num_key_value_heads": 4,
|
||||||
"num_hidden_layers": 2,
|
"num_hidden_layers": 2,
|
||||||
}
|
}
|
||||||
model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
|
|
||||||
config_path = os.path.join(model_dir, "config.json")
|
config_path = os.path.join(model_dir, "config.json")
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_dict, f, indent=4)
|
json.dump(config_dict, f, indent=4)
|
||||||
print(f"Successfully created config.json at: {config_path}")
|
print(f"Successfully created config.json at: {config_path}")
|
||||||
return model_dir
|
return model_dir
|
||||||
|
|
||||||
def create_fd_config_from_model_path(self, model_path, tensor_parallel_size=1):
|
def create_fd_config_from_model_path(self, model_path, tensor_parallel_size=1, quantization: dict = None):
|
||||||
"""Creates a complete FDConfig from a model path."""
|
"""Creates a complete FDConfig from a model path."""
|
||||||
model_args = {"model": model_path, "dtype": "bfloat16"}
|
model_args = {"model": model_path, "dtype": "bfloat16"}
|
||||||
model_config = ModelConfig(model_args)
|
model_config = ModelConfig(model_args)
|
||||||
@@ -156,10 +114,9 @@ class TestAttentionPerformance(unittest.TestCase):
|
|||||||
scheduler_config=SchedulerConfig({}),
|
scheduler_config=SchedulerConfig({}),
|
||||||
load_config=LoadConfig({}),
|
load_config=LoadConfig({}),
|
||||||
quant_config=MixQuantConfig(
|
quant_config=MixQuantConfig(
|
||||||
dense_quant_type="block_wise_fp8",
|
dense_quant_type=quantization.get("dense_quant_type", None),
|
||||||
moe_quant_type="block_wise_fp8",
|
moe_quant_type=quantization.get("moe_quant_type", None),
|
||||||
kv_cache_quant_type="float8_e4m3fn",
|
kv_cache_quant_type=quantization.get("kv_cache_quant_type", None),
|
||||||
# kv_cache_quant_type=None,
|
|
||||||
),
|
),
|
||||||
graph_opt_config=GraphOptimizationConfig({}),
|
graph_opt_config=GraphOptimizationConfig({}),
|
||||||
commit_config=CommitConfig(),
|
commit_config=CommitConfig(),
|
||||||
@@ -207,14 +164,21 @@ class TestAttentionPerformance(unittest.TestCase):
|
|||||||
fd_config: FDConfig,
|
fd_config: FDConfig,
|
||||||
attn_backend: AttentionBackend,
|
attn_backend: AttentionBackend,
|
||||||
cache_quant_type_str: str = "none",
|
cache_quant_type_str: str = "none",
|
||||||
|
has_attn_mask: bool = False,
|
||||||
) -> ForwardMeta:
|
) -> ForwardMeta:
|
||||||
"""
|
"""
|
||||||
Creates a high-fidelity ForwardMeta object.
|
Creates a high-fidelity ForwardMeta object.
|
||||||
"""
|
"""
|
||||||
|
attn_mask_offsets = None
|
||||||
if mode == ForwardMode.EXTEND:
|
if mode == ForwardMode.EXTEND:
|
||||||
seq_lens_encoder = paddle.full([batch_size], seq_len, dtype="int32")
|
seq_lens_encoder = paddle.full([batch_size], seq_len, dtype="int32")
|
||||||
seq_lens_decoder = paddle.zeros([batch_size], dtype="int32")
|
seq_lens_decoder = paddle.zeros([batch_size], dtype="int32")
|
||||||
seq_lens_this_time = seq_lens_encoder
|
seq_lens_this_time = seq_lens_encoder
|
||||||
|
if has_attn_mask:
|
||||||
|
attn_mask_offsets_numpy = np.zeros([batch_size, seq_len, 2], dtype=np.int32)
|
||||||
|
for i in range(seq_len):
|
||||||
|
attn_mask_offsets_numpy[:, i, 1] = i + 1
|
||||||
|
attn_mask_offsets = paddle.to_tensor(attn_mask_offsets_numpy.reshape([-1, 2]))
|
||||||
elif mode == ForwardMode.DECODE:
|
elif mode == ForwardMode.DECODE:
|
||||||
seq_lens_encoder = paddle.zeros([batch_size], dtype="int32")
|
seq_lens_encoder = paddle.zeros([batch_size], dtype="int32")
|
||||||
seq_lens_decoder = paddle.full([batch_size], seq_len, dtype="int32")
|
seq_lens_decoder = paddle.full([batch_size], seq_len, dtype="int32")
|
||||||
@@ -292,31 +256,81 @@ class TestAttentionPerformance(unittest.TestCase):
|
|||||||
attn_backend=attn_backend,
|
attn_backend=attn_backend,
|
||||||
forward_mode=ForwardMode.MIXED,
|
forward_mode=ForwardMode.MIXED,
|
||||||
attn_mask=None,
|
attn_mask=None,
|
||||||
attn_mask_offsets=None,
|
attn_mask_offsets=attn_mask_offsets,
|
||||||
**attn_backend_buffers,
|
**attn_backend_buffers,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = paddle.randn([token_num, self.fd_config.model_config.hidden_size], dtype="bfloat16")
|
hidden_states = paddle.randn([token_num, fd_config.model_config.hidden_size], dtype="bfloat16")
|
||||||
return forward_meta, hidden_states
|
return forward_meta, hidden_states
|
||||||
|
|
||||||
def test_decode_performance_with_prefill(self):
|
def tearDown(self):
    """
    Remove the temporary model directory once a test has finished.

    Nothing is deleted when ``self.model_dir`` does not exist on disk.
    """
    print("\nTearing down test environment...")
    if not os.path.exists(self.model_dir):
        return
    shutil.rmtree(self.model_dir)
    print(f"Successfully removed temporary directory: {self.model_dir}")
|
||||||
|
|
||||||
|
def attn_forward(self, attention_layer, forward_meta, hidden_states):
    """
    Feed ``hidden_states`` through every entry of ``attention_layer`` in order.

    Each layer is called as ``layer(forward_meta, hidden_states)`` and its
    return value becomes the input of the next layer.

    Returns:
        The output of the last layer, or ``hidden_states`` unchanged when
        ``attention_layer`` is empty.
    """
    output = hidden_states
    for layer in attention_layer:
        output = layer(forward_meta, output)
    return output
|
||||||
|
|
||||||
|
def test_append_attn_backend_decode_performance_with_prefill(self):
|
||||||
# Test parameters
|
# Test parameters
|
||||||
test_steps = 100
|
test_steps = 100
|
||||||
|
|
||||||
prefill_batch_size = 1
|
prefill_batch_size = 1
|
||||||
prefill_seq_len = 4096 * 2
|
prefill_seq_len = 4096 * 2
|
||||||
|
|
||||||
|
model_dir = self.model_dir
|
||||||
|
tp_size = paddle.distributed.get_world_size()
|
||||||
|
quantization = {
|
||||||
|
"dense_quant_type": "block_wise_fp8",
|
||||||
|
"moe_quant_type": "block_wise_fp8",
|
||||||
|
"kv_cache_quant_type": "float8_e4m3fn",
|
||||||
|
}
|
||||||
|
fd_config = self.create_fd_config_from_model_path(
|
||||||
|
model_dir, tensor_parallel_size=tp_size, quantization=quantization
|
||||||
|
)
|
||||||
|
fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(tp_size))
|
||||||
|
|
||||||
|
# Initialize Attention Layer
|
||||||
|
os.environ["FD_ATTENTION_BACKEND"] = "APPEND_ATTN"
|
||||||
|
attn_cls = get_attention_backend()
|
||||||
|
attn_backend = attn_cls(
|
||||||
|
fd_config,
|
||||||
|
kv_num_heads=fd_config.model_config.num_key_value_heads // tp_size,
|
||||||
|
num_heads=fd_config.model_config.num_attention_heads // tp_size,
|
||||||
|
head_dim=fd_config.model_config.head_dim,
|
||||||
|
encoder_block_shape_q=64,
|
||||||
|
decoder_block_shape_q=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_layers = fd_config.model_config.num_hidden_layers
|
||||||
|
attention_layer = [None] * num_layers
|
||||||
|
for i in range(num_layers):
|
||||||
|
attention_layer[i] = Ernie4_5_Attention(fd_config, layer_id=i, prefix="test_layer")
|
||||||
|
state_dict = self.create_random_attention_state_dict(fd_config, prefix="test_layer")
|
||||||
|
attention_layer[i].load_state_dict(state_dict)
|
||||||
|
|
||||||
|
cache_quant_type_str = getattr(attention_layer[0].attn, "cache_quant_type_str", "none")
|
||||||
|
|
||||||
|
print("===== Initialization Complete =====")
|
||||||
|
|
||||||
forward_meta, prefill_hidden_states = self.create_forward_meta(
|
forward_meta, prefill_hidden_states = self.create_forward_meta(
|
||||||
batch_size=prefill_batch_size,
|
batch_size=prefill_batch_size,
|
||||||
seq_len=prefill_seq_len,
|
seq_len=prefill_seq_len,
|
||||||
mode=ForwardMode.EXTEND,
|
mode=ForwardMode.EXTEND,
|
||||||
fd_config=self.fd_config,
|
fd_config=fd_config,
|
||||||
attn_backend=self.attn_backend,
|
attn_backend=attn_backend,
|
||||||
cache_quant_type_str=self.cache_quant_type_str,
|
cache_quant_type_str=cache_quant_type_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.attn_backend.init_attention_metadata(forward_meta)
|
attn_backend.init_attention_metadata(forward_meta)
|
||||||
self.attn_forward(forward_meta, prefill_hidden_states)
|
self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
|
||||||
|
|
||||||
paddle.device.synchronize()
|
paddle.device.synchronize()
|
||||||
|
|
||||||
@@ -333,7 +347,7 @@ class TestAttentionPerformance(unittest.TestCase):
|
|||||||
for i in range(test_steps):
|
for i in range(test_steps):
|
||||||
start_events[i].record()
|
start_events[i].record()
|
||||||
|
|
||||||
self.attn_forward(forward_meta, prefill_hidden_states)
|
self.attn_forward(attention_layer, forward_meta, prefill_hidden_states)
|
||||||
|
|
||||||
end_events[i].record()
|
end_events[i].record()
|
||||||
paddle.device.synchronize()
|
paddle.device.synchronize()
|
||||||
@@ -357,22 +371,22 @@ class TestAttentionPerformance(unittest.TestCase):
|
|||||||
batch_size=decode_batch_size,
|
batch_size=decode_batch_size,
|
||||||
seq_len=10 * 1024,
|
seq_len=10 * 1024,
|
||||||
mode=ForwardMode.DECODE,
|
mode=ForwardMode.DECODE,
|
||||||
fd_config=self.fd_config,
|
fd_config=fd_config,
|
||||||
attn_backend=self.attn_backend,
|
attn_backend=attn_backend,
|
||||||
cache_quant_type_str=self.cache_quant_type_str,
|
cache_quant_type_str=cache_quant_type_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.attn_backend.init_attention_metadata(forward_meta)
|
attn_backend.init_attention_metadata(forward_meta)
|
||||||
|
|
||||||
paddle.device.synchronize()
|
paddle.device.synchronize()
|
||||||
|
|
||||||
# 必须要先预热一次!因为预处理被放到了第一层再做了!
|
# 必须要先预热一次!因为预处理被放到了第一层再做了!
|
||||||
self.attn_forward(forward_meta, hidden_states)
|
self.attn_forward(attention_layer, forward_meta, hidden_states)
|
||||||
|
|
||||||
attn_cuda_graphs = graphs.CUDAGraph()
|
attn_cuda_graphs = graphs.CUDAGraph()
|
||||||
attn_cuda_graphs.capture_begin()
|
attn_cuda_graphs.capture_begin()
|
||||||
|
|
||||||
self.attn_forward(forward_meta, hidden_states)
|
self.attn_forward(attention_layer, forward_meta, hidden_states)
|
||||||
|
|
||||||
attn_cuda_graphs.capture_end()
|
attn_cuda_graphs.capture_end()
|
||||||
|
|
||||||
@@ -393,6 +407,219 @@ class TestAttentionPerformance(unittest.TestCase):
|
|||||||
|
|
||||||
# p.stop()
|
# p.stop()
|
||||||
|
|
||||||
|
def test_flash_attn_v3(self):
    """
    Time repeated EXTEND-mode forward passes through the attention stack with
    FD_ATTENTION_BACKEND set to "FLASH_ATTN" and FLAGS_flash_attn_version set to 3.

    Skipped unless 89 <= sm_version < 100. The test makes no assertions; it
    prints the last five per-step timings.
    """
    if not (89 <= self.sm_version < 100):
        self.skipTest("Flash Attention V3 requires SM89+ but less than SM100.")

    # Test parameters
    num_steps = 100
    batch = 1
    tokens_per_seq = 4096 * 2

    model_path = self.model_dir
    world_size = paddle.distributed.get_world_size()
    fd_config = self.create_fd_config_from_model_path(
        model_path,
        tensor_parallel_size=world_size,
        quantization={
            "dense_quant_type": "block_wise_fp8",
            "moe_quant_type": "block_wise_fp8",
        },
    )
    fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(world_size))

    # Initialize Attention Layer: the backend is selected before the class is resolved.
    os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"
    paddle.set_flags({"FLAGS_flash_attn_version": 3})
    backend_cls = get_attention_backend()
    model_cfg = fd_config.model_config
    attn_backend = backend_cls(
        fd_config,
        kv_num_heads=model_cfg.num_key_value_heads // world_size,
        num_heads=model_cfg.num_attention_heads // world_size,
        head_dim=model_cfg.head_dim,
        encoder_block_shape_q=64,
        decoder_block_shape_q=16,
    )

    # Build one attention layer per hidden layer, each loaded with random weights.
    layers = []
    for layer_idx in range(model_cfg.num_hidden_layers):
        layer = Ernie4_5_Attention(fd_config, layer_id=layer_idx, prefix="test_layer")
        layer.load_state_dict(self.create_random_attention_state_dict(fd_config, prefix="test_layer"))
        layers.append(layer)

    kv_quant = getattr(layers[0].attn, "cache_quant_type_str", "none")

    print("===== flash_attn_v3 Initialization Complete =====")

    forward_meta, hidden = self.create_forward_meta(
        batch_size=batch,
        seq_len=tokens_per_seq,
        mode=ForwardMode.EXTEND,
        fd_config=fd_config,
        attn_backend=attn_backend,
        cache_quant_type_str=kv_quant,
    )

    attn_backend.init_attention_metadata(forward_meta)
    # One untimed pass before measuring.
    self.attn_forward(layers, forward_meta, hidden)

    paddle.device.synchronize()

    begin_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_steps)]
    finish_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_steps)]
    for begin_evt, finish_evt in zip(begin_events, finish_events):
        begin_evt.record()

        self.attn_forward(layers, forward_meta, hidden)

        finish_evt.record()
    # Wait for the queued work before reading the events.
    paddle.device.synchronize()

    # The first measurement is dropped.
    step_times = np.array(
        [round(begin_evt.elapsed_time(finish_evt), 1) for begin_evt, finish_evt in zip(begin_events, finish_events)]
    )[1:]
    print(step_times[-5:])
|
||||||
|
|
||||||
|
def test_flash_attn_v3_with_mask(self):
    """
    Same timing run as the unmasked V3 case, but the forward meta is built
    with ``has_attn_mask=True``.

    Uses FD_ATTENTION_BACKEND="FLASH_ATTN" and FLAGS_flash_attn_version=3.
    Skipped unless 89 <= sm_version < 100. The test makes no assertions; it
    prints the last five per-step timings.
    """
    if not (89 <= self.sm_version < 100):
        self.skipTest("Flash Attention V3 requires SM89+ but less than SM100.")

    # Test parameters
    num_steps = 100
    batch = 1
    tokens_per_seq = 4096 * 2

    model_path = self.model_dir
    world_size = paddle.distributed.get_world_size()
    fd_config = self.create_fd_config_from_model_path(
        model_path,
        tensor_parallel_size=world_size,
        quantization={
            "dense_quant_type": "block_wise_fp8",
            "moe_quant_type": "block_wise_fp8",
        },
    )
    fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(world_size))

    # Initialize Attention Layer: the backend is selected before the class is resolved.
    os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"
    paddle.set_flags({"FLAGS_flash_attn_version": 3})
    backend_cls = get_attention_backend()
    model_cfg = fd_config.model_config
    attn_backend = backend_cls(
        fd_config,
        kv_num_heads=model_cfg.num_key_value_heads // world_size,
        num_heads=model_cfg.num_attention_heads // world_size,
        head_dim=model_cfg.head_dim,
        encoder_block_shape_q=64,
        decoder_block_shape_q=16,
    )

    # Build one attention layer per hidden layer, each loaded with random weights.
    layers = []
    for layer_idx in range(model_cfg.num_hidden_layers):
        layer = Ernie4_5_Attention(fd_config, layer_id=layer_idx, prefix="test_layer")
        layer.load_state_dict(self.create_random_attention_state_dict(fd_config, prefix="test_layer"))
        layers.append(layer)

    kv_quant = getattr(layers[0].attn, "cache_quant_type_str", "none")

    print("===== flash_attn_v3_with_mask Initialization Complete =====")

    forward_meta, hidden = self.create_forward_meta(
        batch_size=batch,
        seq_len=tokens_per_seq,
        mode=ForwardMode.EXTEND,
        fd_config=fd_config,
        attn_backend=attn_backend,
        cache_quant_type_str=kv_quant,
        has_attn_mask=True,
    )

    attn_backend.init_attention_metadata(forward_meta)
    # One untimed pass before measuring.
    self.attn_forward(layers, forward_meta, hidden)

    paddle.device.synchronize()

    begin_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_steps)]
    finish_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_steps)]
    for begin_evt, finish_evt in zip(begin_events, finish_events):
        begin_evt.record()

        self.attn_forward(layers, forward_meta, hidden)

        finish_evt.record()
    # Wait for the queued work before reading the events.
    paddle.device.synchronize()

    # The first measurement is dropped.
    step_times = np.array(
        [round(begin_evt.elapsed_time(finish_evt), 1) for begin_evt, finish_evt in zip(begin_events, finish_events)]
    )[1:]
    print(step_times[-5:])
|
||||||
|
|
||||||
|
def test_flash_attn_v4(self):
    """
    Time repeated EXTEND-mode forward passes through the attention stack with
    FD_ATTENTION_BACKEND set to "FLASH_ATTN" on devices with sm_version >= 100.

    Unlike the V3 cases, this test does not set FLAGS_flash_attn_version.
    The test makes no assertions; it prints the last five per-step timings.
    """
    if self.sm_version < 100:
        self.skipTest("Flash Attention V4 requires SM100+.")

    # Test parameters
    num_steps = 100
    batch = 1
    tokens_per_seq = 4096 * 2

    model_path = self.model_dir
    world_size = paddle.distributed.get_world_size()
    fd_config = self.create_fd_config_from_model_path(
        model_path,
        tensor_parallel_size=world_size,
        quantization={
            "dense_quant_type": "block_wise_fp8",
            "moe_quant_type": "block_wise_fp8",
        },
    )
    fd_config.parallel_config.tp_group = paddle.distributed.new_group(range(world_size))

    # Initialize Attention Layer: the backend is selected before the class is resolved.
    os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN"
    backend_cls = get_attention_backend()
    model_cfg = fd_config.model_config
    attn_backend = backend_cls(
        fd_config,
        kv_num_heads=model_cfg.num_key_value_heads // world_size,
        num_heads=model_cfg.num_attention_heads // world_size,
        head_dim=model_cfg.head_dim,
        encoder_block_shape_q=64,
        decoder_block_shape_q=16,
    )

    # Build one attention layer per hidden layer, each loaded with random weights.
    layers = []
    for layer_idx in range(model_cfg.num_hidden_layers):
        layer = Ernie4_5_Attention(fd_config, layer_id=layer_idx, prefix="test_layer")
        layer.load_state_dict(self.create_random_attention_state_dict(fd_config, prefix="test_layer"))
        layers.append(layer)

    kv_quant = getattr(layers[0].attn, "cache_quant_type_str", "none")

    print("===== flash_attn_v4 Initialization Complete =====")

    forward_meta, hidden = self.create_forward_meta(
        batch_size=batch,
        seq_len=tokens_per_seq,
        mode=ForwardMode.EXTEND,
        fd_config=fd_config,
        attn_backend=attn_backend,
        cache_quant_type_str=kv_quant,
    )

    attn_backend.init_attention_metadata(forward_meta)
    # One untimed pass before measuring.
    self.attn_forward(layers, forward_meta, hidden)

    paddle.device.synchronize()

    begin_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_steps)]
    finish_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_steps)]
    for begin_evt, finish_evt in zip(begin_events, finish_events):
        begin_evt.record()

        self.attn_forward(layers, forward_meta, hidden)

        finish_evt.record()
    # Wait for the queued work before reading the events.
    paddle.device.synchronize()

    # The first measurement is dropped.
    step_times = np.array(
        [round(begin_evt.elapsed_time(finish_evt), 1) for begin_evt, finish_evt in zip(begin_events, finish_events)]
    )[1:]
    print(step_times[-5:])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -0,0 +1,206 @@
|
|||||||
|
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
|
||||||
|
flash_attn_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlashAttnFunc(unittest.TestCase):
    """
    Smoke tests that call ``flash_attn_func`` for versions 2, 3 and 4, with and
    without ``attn_mask_q``.

    Each test only checks that the call completes; outputs are not compared
    against a reference.
    """

    # Problem size shared by every test in this class.
    HEAD_DIM = 128
    NUM_HEADS = 12
    KV_NUM_HEADS = 4
    SEQ_LEN = 1024
    BATCH_SIZE = 4

    def setUp(self):
        """
        Set up the testing environment before each test.
        """
        paddle.set_device("gpu")
        paddle.set_default_dtype("bfloat16")
        prop = paddle.device.cuda.get_device_properties()
        self.sm_version = prop.major * 10 + prop.minor

    def _token_num(self):
        """Total number of tokens across the batch."""
        return self.BATCH_SIZE * self.SEQ_LEN

    def _make_qkv(self):
        """Create random q, k and v (in that order) as float32 cast to bfloat16."""
        token_num = self._token_num()
        q = paddle.rand((token_num, self.NUM_HEADS, self.HEAD_DIM), dtype=paddle.float32).cast("bfloat16")
        k = paddle.rand((token_num, self.KV_NUM_HEADS, self.HEAD_DIM), dtype=paddle.float32).cast("bfloat16")
        v = paddle.rand((token_num, self.KV_NUM_HEADS, self.HEAD_DIM), dtype=paddle.float32).cast("bfloat16")
        return q, k, v

    def _make_cu_seqlens(self):
        """Return int32 offsets 0, SEQ_LEN, ..., BATCH_SIZE * SEQ_LEN."""
        token_num = self._token_num()
        return paddle.arange(0, token_num + self.SEQ_LEN, self.SEQ_LEN, dtype=paddle.int32)

    def _make_attn_mask_q(self):
        """
        Build the int32 ``[1, 1, token_num, 4]`` mask passed as ``attn_mask_q``.

        For the tokens of batch ``bid`` columns 0 and 1 hold
        ``SEQ_LEN * (bid + 1)``, column 2 stays 0 and column 3 holds the token's
        own index.
        """
        seq_len = self.SEQ_LEN
        token_num = self._token_num()
        attn_mask_q = paddle.zeros([1, 1, token_num, 4], dtype=paddle.int32)
        for bid in range(self.BATCH_SIZE):
            attn_mask_q[:, :, seq_len * bid : seq_len * (bid + 1), :2] = seq_len * (bid + 1)
        # One vectorized write instead of token_num single-element assignments.
        attn_mask_q[:, :, :, 3] = paddle.arange(0, token_num, dtype=paddle.int32).reshape([1, 1, token_num])
        return attn_mask_q

    def _run_varlen(self, version, with_mask):
        """Call ``flash_attn_func`` with cu_seqlens inputs for ``version`` (2 or 3)."""
        q, k, v = self._make_qkv()
        cu_seqlens_q = self._make_cu_seqlens()
        cu_seqlens_k = self._make_cu_seqlens()
        attn_mask_q = self._make_attn_mask_q() if with_mask else None
        paddle.set_flags({"FLAGS_flash_attn_version": version})
        flash_attn_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q=self.SEQ_LEN,
            max_seqlen_k=self.SEQ_LEN,
            attn_mask_q=attn_mask_q,
            causal=True,
            num_heads=self.NUM_HEADS,
            kv_num_heads=self.KV_NUM_HEADS,
            head_dim=self.HEAD_DIM,
            version=version,
        )

    def _skip_unless_fa3_supported(self):
        """Skip the current test unless 89 <= sm_version < 100."""
        if self.sm_version < 89 or self.sm_version >= 100:
            self.skipTest("Flash Attention V3 requires SM89+ but less than SM100.")

    def test_fa3(self):
        self._skip_unless_fa3_supported()
        self._run_varlen(version=3, with_mask=False)

    def test_fa3_with_mask(self):
        self._skip_unless_fa3_supported()
        self._run_varlen(version=3, with_mask=True)

    def test_fa2(self):
        self._run_varlen(version=2, with_mask=False)

    def test_fa2_with_mask(self):
        self._run_varlen(version=2, with_mask=True)

    def test_fa4(self):
        if self.sm_version < 100:
            self.skipTest("Flash Attention V4 requires SM100+.")
        q, k, v = self._make_qkv()
        attn_mask_q = self._make_attn_mask_q()
        # The version-4 call passes only the mask: no cu_seqlens, head counts or version flag.
        flash_attn_func(
            q,
            k,
            v,
            attn_mask_q=attn_mask_q,
            version=4,
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -20,6 +20,7 @@ import unittest
|
|||||||
os.environ.setdefault("DG_NVCC_OVERRIDE_CPP_STANDARD", "17")
|
os.environ.setdefault("DG_NVCC_OVERRIDE_CPP_STANDARD", "17")
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
|
from utils import OpPerformanceTester
|
||||||
|
|
||||||
from fastdeploy.config import (
|
from fastdeploy.config import (
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
@@ -38,7 +39,6 @@ from fastdeploy.model_executor.layers.quantization.weight_only import (
|
|||||||
WINT8Config,
|
WINT8Config,
|
||||||
)
|
)
|
||||||
from fastdeploy.scheduler import SchedulerConfig
|
from fastdeploy.scheduler import SchedulerConfig
|
||||||
from tests.utils import OpPerformanceTester
|
|
||||||
|
|
||||||
paddle.set_default_dtype("bfloat16")
|
paddle.set_default_dtype("bfloat16")
|
||||||
paddle.seed(1024)
|
paddle.seed(1024)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import unittest
|
|||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from paddle.distributed import fleet
|
from paddle.distributed import fleet
|
||||||
|
from utils import OpPerformanceTester
|
||||||
|
|
||||||
from fastdeploy.config import (
|
from fastdeploy.config import (
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
@@ -35,7 +36,6 @@ from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
|||||||
from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config
|
from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config
|
||||||
from fastdeploy.scheduler import SchedulerConfig
|
from fastdeploy.scheduler import SchedulerConfig
|
||||||
from fastdeploy.worker.worker_process import init_distributed_environment
|
from fastdeploy.worker.worker_process import init_distributed_environment
|
||||||
from tests.utils import OpPerformanceTester
|
|
||||||
|
|
||||||
paddle.set_default_dtype("bfloat16")
|
paddle.set_default_dtype("bfloat16")
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ import unittest
|
|||||||
import paddle
|
import paddle
|
||||||
from paddle.distributed import fleet
|
from paddle.distributed import fleet
|
||||||
|
|
||||||
|
# from fastdeploy.worker.worker_process import init_distributed_environment
|
||||||
|
from utils import OpPerformanceTester
|
||||||
|
|
||||||
from fastdeploy.config import (
|
from fastdeploy.config import (
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
FDConfig,
|
FDConfig,
|
||||||
@@ -19,9 +22,6 @@ from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
|||||||
from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config
|
from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config
|
||||||
from fastdeploy.scheduler import SchedulerConfig
|
from fastdeploy.scheduler import SchedulerConfig
|
||||||
|
|
||||||
# from fastdeploy.worker.worker_process import init_distributed_environment
|
|
||||||
from tests.utils import OpPerformanceTester
|
|
||||||
|
|
||||||
paddle.set_default_dtype("bfloat16")
|
paddle.set_default_dtype("bfloat16")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,9 +26,9 @@ import paddle.distributed.fleet as fleet
|
|||||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||||
from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPForCausalLM
|
from fastdeploy.model_executor.models.ernie4_5_mtp import Ernie4_5_MTPForCausalLM
|
||||||
|
|
||||||
ROOT = Path(__file__).resolve().parents[2]
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
sys.path.insert(0, str(ROOT))
|
sys.path.insert(0, str(ROOT))
|
||||||
from tests.utils import get_default_test_fd_config
|
from utils import get_default_test_fd_config
|
||||||
|
|
||||||
strategy = fleet.DistributedStrategy()
|
strategy = fleet.DistributedStrategy()
|
||||||
fleet.init(strategy=strategy)
|
fleet.init(strategy=strategy)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ if project_root not in sys.path:
|
|||||||
|
|
||||||
os.environ["FD_USE_MACHETE"] = "0"
|
os.environ["FD_USE_MACHETE"] = "0"
|
||||||
|
|
||||||
from tests.model_loader.utils import (
|
from model_loader.utils import (
|
||||||
check_tokens_id_and_text_close,
|
check_tokens_id_and_text_close,
|
||||||
form_model_get_output_topp0,
|
form_model_get_output_topp0,
|
||||||
get_paddle_model_path,
|
get_paddle_model_path,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ project_root = os.path.abspath(os.path.join(current_dir, ".."))
|
|||||||
if project_root not in sys.path:
|
if project_root not in sys.path:
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
from tests.model_loader.utils import (
|
from model_loader.utils import (
|
||||||
form_model_get_output_topp0,
|
form_model_get_output_topp0,
|
||||||
get_torch_model_path,
|
get_torch_model_path,
|
||||||
run_with_timeout,
|
run_with_timeout,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ if project_root not in sys.path:
|
|||||||
|
|
||||||
os.environ["FD_USE_MACHETE"] = "0"
|
os.environ["FD_USE_MACHETE"] = "0"
|
||||||
|
|
||||||
from tests.model_loader.utils import (
|
from model_loader.utils import (
|
||||||
calculate_diff_rate,
|
calculate_diff_rate,
|
||||||
form_model_get_output_topp0,
|
form_model_get_output_topp0,
|
||||||
get_torch_model_path,
|
get_torch_model_path,
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
|
||||||
|
flash_attn_func,
|
||||||
|
)
|
||||||
|
from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q
|
||||||
from fastdeploy.model_executor.ops.gpu import flash_mask_attention
|
from fastdeploy.model_executor.ops.gpu import flash_mask_attention
|
||||||
|
|
||||||
|
|
||||||
@@ -31,6 +35,8 @@ class TestFlashMaskAttention(unittest.TestCase):
|
|||||||
self.k_len = 1024
|
self.k_len = 1024
|
||||||
self.head_dim = 128
|
self.head_dim = 128
|
||||||
np.random.seed(self.q_len)
|
np.random.seed(self.q_len)
|
||||||
|
prop = paddle.device.cuda.get_device_properties()
|
||||||
|
self.sm_version = prop.major * 10 + prop.minor
|
||||||
|
|
||||||
def naive_attn(self, q_input, k_input, v_input, mask):
|
def naive_attn(self, q_input, k_input, v_input, mask):
|
||||||
|
|
||||||
@@ -101,6 +107,132 @@ class TestFlashMaskAttention(unittest.TestCase):
|
|||||||
max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
|
max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
|
||||||
self.assertLessEqual(max_diff, 0.05)
|
self.assertLessEqual(max_diff, 0.05)
|
||||||
|
|
||||||
|
def test_fa4(
    self,
):
    """
    Compare ``flash_attn_func(..., version=4)`` with ``self.naive_attn`` on
    random inputs, using an ``attn_mask_q`` produced by ``get_attn_mask_q``.

    Skipped when sm_version < 100. Passes when the largest absolute
    difference between the two outputs is at most 0.05.
    """
    if self.sm_version < 100:
        self.skipTest("Flash Attention V4 requires SM100+.")

    hidden_dim = self.num_head * self.head_dim
    q_input = paddle.randn([self.q_len, hidden_dim], dtype="bfloat16")
    k_input = paddle.randn([self.q_len + self.k_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
    v_input = paddle.randn(k_input.shape, dtype="bfloat16")

    # Per query token: a (0, q_len + k_len) pair, flattened to one dimension.
    range_begin = paddle.zeros([self.q_len], dtype="int32")
    range_end = paddle.zeros([self.q_len], dtype="int32") + self.q_len + self.k_len
    mask = paddle.stack([range_begin, range_end], axis=-1).reshape([-1])

    expected = self.naive_attn(q_input, k_input, v_input, mask)

    batch = self.bsz
    cu_seq_q = (paddle.arange(batch + 1) * self.q_len).astype("int32")
    cu_seq_k = (paddle.arange(batch + 1) * (self.q_len + self.k_len)).astype("int32")

    attn_mask_q = get_attn_mask_q(
        cu_seqlens_q=cu_seq_q,
        cu_seqlens_k=cu_seq_k,
        attn_mask_kv=mask,
        kv_token_num=self.q_len + self.k_len,
    )

    flash_result = flash_attn_func(
        q_input,
        k_input,
        v_input,
        attn_mask_q=attn_mask_q,
        num_heads=self.num_head,
        kv_num_heads=self.num_kv_head,
        head_dim=self.head_dim,
        version=4,
    )
    actual = flash_result[0].reshape([self.q_len, self.num_head * self.head_dim])

    max_diff = (actual - expected).abs().max().item()
    self.assertLessEqual(max_diff, 0.05)
|
||||||
|
|
||||||
|
def test_fa3_with_mask(self):
    """Compare flash_attn_func(version=3) plus attn_mask_q with the naive reference.

    Skipped when ``self.sm_version`` is below 89 or at least 100. The test
    passes when the largest element-wise absolute difference between both
    outputs is at most 0.05.

    The process-wide ``FLAGS_flash_attn_version`` flag is switched to 3 for
    this test and put back to its previous value afterwards, so the setting
    does not leak into tests that run later in the same process.
    """
    if self.sm_version < 89 or self.sm_version >= 100:
        self.skipTest("Flash Attention V3 requires SM89+ but less than SM100.")

    kv_len = self.q_len + self.k_len
    hidden = self.num_head * self.head_dim

    # Random inputs; the q -> k -> v draw order is kept so the RNG stream is unchanged.
    q_input = paddle.randn([self.q_len, hidden], dtype="bfloat16")
    k_input = paddle.randn([kv_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
    v_input = paddle.randn(k_input.shape, dtype="bfloat16")

    # One (0, kv_len) pair per query position, flattened into a 1-D int32 tensor.
    # NOTE(review): presumably a [start, end) kv range per query -- confirm against naive_attn.
    mask_start = paddle.zeros([self.q_len], dtype="int32")
    mask_end = paddle.zeros([self.q_len], dtype="int32") + kv_len
    mask = paddle.stack([mask_start, mask_end], axis=-1).reshape([-1])

    naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)

    # Cumulative sequence offsets: multiples of q_len for q, of kv_len for k.
    bsz = self.bsz
    cu_seq_q = (paddle.arange(bsz + 1) * self.q_len).astype("int32")
    cu_seq_k = (paddle.arange(bsz + 1) * kv_len).astype("int32")

    attn_mask_q = get_attn_mask_q(
        cu_seqlens_q=cu_seq_q,
        cu_seqlens_k=cu_seq_k,
        attn_mask_kv=mask,
        kv_token_num=kv_len,
    )

    # Snapshot the global flag before overriding it; addCleanup restores it even
    # when the assertions below fail or the kernel call raises.
    self.addCleanup(paddle.set_flags, paddle.get_flags(["FLAGS_flash_attn_version"]))
    paddle.set_flags({"FLAGS_flash_attn_version": 3})

    paddle_attn_out = flash_attn_func(
        q_input,
        k_input,
        v_input,
        attn_mask_q=attn_mask_q,
        num_heads=self.num_head,
        kv_num_heads=self.num_kv_head,
        head_dim=self.head_dim,
        version=3,
    )[0].reshape([self.q_len, hidden])

    max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
    self.assertLessEqual(max_diff, 0.05)
|
||||||
|
|
||||||
|
def test_fa2_with_mask(self):
    """Compare flash_attn_func(version=2) plus attn_mask_q with the naive reference.

    The test passes when the largest element-wise absolute difference
    between both outputs is at most 0.05.

    The process-wide ``FLAGS_flash_attn_version`` flag is switched to 2 for
    this test and put back to its previous value afterwards, so the setting
    does not leak into tests that run later in the same process.
    """
    kv_len = self.q_len + self.k_len
    hidden = self.num_head * self.head_dim

    # Random inputs; the q -> k -> v draw order is kept so the RNG stream is unchanged.
    q_input = paddle.randn([self.q_len, hidden], dtype="bfloat16")
    k_input = paddle.randn([kv_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
    v_input = paddle.randn(k_input.shape, dtype="bfloat16")

    # One (0, kv_len) pair per query position, flattened into a 1-D int32 tensor.
    # NOTE(review): presumably a [start, end) kv range per query -- confirm against naive_attn.
    mask_start = paddle.zeros([self.q_len], dtype="int32")
    mask_end = paddle.zeros([self.q_len], dtype="int32") + kv_len
    mask = paddle.stack([mask_start, mask_end], axis=-1).reshape([-1])

    naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)

    # Cumulative sequence offsets: multiples of q_len for q, of kv_len for k.
    bsz = self.bsz
    cu_seq_q = (paddle.arange(bsz + 1) * self.q_len).astype("int32")
    cu_seq_k = (paddle.arange(bsz + 1) * kv_len).astype("int32")

    attn_mask_q = get_attn_mask_q(
        cu_seqlens_q=cu_seq_q,
        cu_seqlens_k=cu_seq_k,
        attn_mask_kv=mask,
        kv_token_num=kv_len,
    )

    # Snapshot the global flag before overriding it; addCleanup restores it even
    # when the assertions below fail or the kernel call raises.
    self.addCleanup(paddle.set_flags, paddle.get_flags(["FLAGS_flash_attn_version"]))
    paddle.set_flags({"FLAGS_flash_attn_version": 2})

    paddle_attn_out = flash_attn_func(
        q_input,
        k_input,
        v_input,
        attn_mask_q=attn_mask_q,
        num_heads=self.num_head,
        kv_num_heads=self.num_kv_head,
        head_dim=self.head_dim,
        version=2,
    )[0].reshape([self.q_len, hidden])

    max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
    self.assertLessEqual(max_diff, 0.05)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -18,9 +18,9 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
from utils import OpPerformanceTester
|
||||||
|
|
||||||
from fastdeploy.model_executor.ops.triton_ops import qk_rmsnorm_fused
|
from fastdeploy.model_executor.ops.triton_ops import qk_rmsnorm_fused
|
||||||
from tests.utils import OpPerformanceTester
|
|
||||||
|
|
||||||
paddle.set_default_dtype("bfloat16")
|
paddle.set_default_dtype("bfloat16")
|
||||||
paddle.seed(99)
|
paddle.seed(99)
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ project_root = os.path.abspath(os.path.join(current_dir, ".."))
|
|||||||
if project_root not in sys.path:
|
if project_root not in sys.path:
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
from tests.model_loader.utils import get_torch_model_path
|
from model_loader.utils import get_torch_model_path
|
||||||
|
|
||||||
test_model_configs = {
|
test_model_configs = {
|
||||||
"Qwen3-0.6B": {
|
"Qwen3-0.6B": {
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from fastdeploy.model_executor.layers.quantization.kv_cache import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
sys.path.append("../")
|
sys.path.append("../")
|
||||||
from tests.utils import get_default_test_fd_config
|
from utils import get_default_test_fd_config
|
||||||
|
|
||||||
|
|
||||||
class MockLayer(nn.Layer):
|
class MockLayer(nn.Layer):
|
||||||
|
|||||||
@@ -18,11 +18,11 @@ import unittest
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
|
from utils import FakeModelConfig, get_default_test_fd_config
|
||||||
|
|
||||||
from fastdeploy.config import SpeculativeConfig
|
from fastdeploy.config import SpeculativeConfig
|
||||||
from fastdeploy.engine.request import Request, RequestType
|
from fastdeploy.engine.request import Request, RequestType
|
||||||
from fastdeploy.spec_decode.mtp import MTPProposer
|
from fastdeploy.spec_decode.mtp import MTPProposer
|
||||||
from tests.utils import FakeModelConfig, get_default_test_fd_config
|
|
||||||
|
|
||||||
|
|
||||||
class TestMTPProposer(unittest.TestCase):
|
class TestMTPProposer(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user