mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
29a313a402
* support FA4 sm100 * flash attn backend support mask * flash attn backend run flashmask correct * add test for flash_attn_backend and flash_attn_func * check * add test for fa4 * requirements.txt add fa4 whl * check test on sm100 * fix CI conflict * add enable_torch_proxy for flash_mask * lazy import fa4 * check * fix tests import * check test_load_mpt import
138 lines
5.1 KiB
Plaintext
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "helper.h"
|
|
#include "paddle/extension.h"
|
|
|
|
// For every KV token, compute the query-row interval of the attention mask in
// "startend row indices" format (4 ints per KV token).
//
// Launch layout: 1D grid-stride loop over kv_token_num KV tokens.
// Output: startend_row_indices_ptr has VecSize (= 4) ints per KV token and is
// written with a single int4 vectorized store, so it must be 16-byte aligned.
// cu_seqlens_q / cu_seqlens_k hold max_batch_size + 1 cumulative offsets each.
// attn_mask_kv_ptr may be nullptr; then a causal mask aligned to the end of
// the KV sequence is assumed (last query row attends up to the last KV token).
// Uses programmatic dependent launch (PDL) primitives on SM90+.
__global__ void get_attn_mask_q_kernel(
    int* __restrict__ startend_row_indices_ptr,
    const int* attn_mask_kv_ptr,
    const int* cu_seqlens_q,
    const int* cu_seqlens_k,
    const int kv_token_num,
    const int max_batch_size) {
  constexpr int VecSize = 4;  // ints written per KV token (one int4 store)
  const uint32_t tid = threadIdx.x, bid = blockIdx.x;
  int startend_row_vec[VecSize];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // PDL: wait for the dependent preceding kernel to flush its outputs.
  cudaGridDependencySynchronize();
#endif
  for (uint32_t cu_seqlens_k_idx = bid * blockDim.x + tid;
       cu_seqlens_k_idx < static_cast<uint32_t>(kv_token_num);
       cu_seqlens_k_idx += blockDim.x * gridDim.x) {
    // Linear search for the batch this KV token falls into.
    uint32_t batch_id = 0;
    for (int i = 0; i < max_batch_size; ++i) {
      if (cu_seqlens_k_idx >= static_cast<uint32_t>(cu_seqlens_k[i]) &&
          cu_seqlens_k_idx < static_cast<uint32_t>(cu_seqlens_k[i + 1])) {
        batch_id = i;
        break;
      }
    }
    const uint32_t this_batch_q_start = cu_seqlens_q[batch_id];
    const uint32_t this_batch_q_end = cu_seqlens_q[batch_id + 1];
    const uint32_t this_batch_q_len = this_batch_q_end - this_batch_q_start;
    const uint32_t kv_start = cu_seqlens_k[batch_id];
    const uint32_t kv_end = cu_seqlens_k[batch_id + 1];
    const uint32_t kv_len = kv_end - kv_start;
    // Position of this KV token inside its own batch.
    const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;

    startend_row_vec[0] = this_batch_q_end;
    startend_row_vec[1] = cu_seqlens_q[max_batch_size];  // total Q token count
    startend_row_vec[2] = 0;
    startend_row_vec[3] = this_batch_q_end;
    for (int this_batch_q_idx = this_batch_q_start;
         this_batch_q_idx < static_cast<int>(this_batch_q_end);
         ++this_batch_q_idx) {
      // const int append_mask_k_start = attn_mask_kv_ptr ?
      //     attn_mask_kv_ptr[this_batch_q_idx * 2 + 0] : 0;
      // Last KV index this query row may attend to (inclusive). May be
      // negative: e.g. an explicit mask end of 0 yields -1.
      const int append_mask_k_end =
          attn_mask_kv_ptr ? attn_mask_kv_ptr[this_batch_q_idx * 2 + 1] - 1
                           : this_batch_q_idx - this_batch_q_start + kv_len -
                                 (this_batch_q_len);
      // Compare as signed ints: the unmodified uint32_t comparison would
      // promote a negative append_mask_k_end to a huge unsigned value and
      // wrongly accept every KV token.
      if (static_cast<int>(cache_k_idx) <= append_mask_k_end) {
        startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx);
        // Rows are scanned in ascending order, so the first hit is the
        // minimum — we can break out of the loop early.
        break;
      }
    }
    // One vectorized 16-byte store per KV token.
    reinterpret_cast<int4*>(startend_row_indices_ptr +
                            cu_seqlens_k_idx * VecSize)[0] =
        reinterpret_cast<int4*>(startend_row_vec)[0];
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // PDL: signal dependent kernels that this grid's stores are complete.
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}
|
|
|
|
// Host-side launcher for get_attn_mask_q_kernel.
//
// Allocates the [1, 1, kv_token_num, 4] INT32 output tensor on the device of
// cu_seqlens_k and launches the kernel on that tensor's stream. attn_mask_kv
// is optional; when absent the kernel falls back to a causal mask.
std::vector<paddle::Tensor> get_attn_mask_q(
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& cu_seqlens_k,
    const paddle::optional<paddle::Tensor>& attn_mask_kv,
    const int kv_token_num) {
  // One 4-int row per KV token.
  paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
      {1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place());

  // cu_seqlens_k holds batch_size + 1 cumulative offsets.
  const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
  constexpr int block_size = 512;
  const int num_blocks = div_up(kv_token_num, block_size);

  int* out_ptr = attn_mask_startend_row_indices.data<int>();
  const int* mask_kv_ptr =
      attn_mask_kv ? attn_mask_kv.get().data<int>() : nullptr;
  const int* seqlens_q_ptr = cu_seqlens_q.data<int>();
  const int* seqlens_k_ptr = cu_seqlens_k.data<int>();
  auto stream = cu_seqlens_k.stream();

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  // Metax custom-device build: plain launch (no programmatic dependent
  // launch support).
  get_attn_mask_q_kernel<<<num_blocks, block_size, 0, stream>>>(
      out_ptr,
      mask_kv_ptr,
      seqlens_q_ptr,
      seqlens_k_ptr,
      kv_token_num,
      max_batch_size);
#else
  // Default path: helper enables PDL when the device/config supports it.
  launchWithPdlWhenEnabled(get_attn_mask_q_kernel,
                           num_blocks,
                           block_size,
                           0,
                           stream,
                           out_ptr,
                           mask_kv_ptr,
                           seqlens_q_ptr,
                           seqlens_k_ptr,
                           kv_token_num,
                           max_batch_size);
#endif
  return {attn_mask_startend_row_indices};
}
|
|
|
|
// Infer-dtype hook for the get_attn_mask_q op: the output row-indices tensor
// is always INT32, regardless of the input dtypes.
std::vector<paddle::DataType> GetAttnMaskQInferDtype(
    const paddle::DataType& cu_seqlens_q_dtype,
    const paddle::DataType& cu_seqlens_k_dtype,
    const paddle::optional<paddle::DataType>& attn_mask_kv_dtype) {
  std::vector<paddle::DataType> out_dtypes{paddle::DataType::INT32};
  return out_dtypes;
}
|
|
|
|
// Infer-shape hook for the get_attn_mask_q op: output is the startend row
// indices tensor, [1, 1, kv_token_num, 4].
std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& cu_seqlens_k_shape,
    const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
    const int kv_token_num) {
  const std::vector<int64_t> out_shape{1, 1, kv_token_num, 4};
  return {out_shape};
}
|
|
|
|
// Register the custom op with Paddle.
// Inputs: the two cumulative-seqlen tensors plus an optional per-query mask
// tensor (exposed to Python as "attn_mask_offsets").
// Output: "attn_mask_q" — the [1, 1, kv_token_num, 4] startend row indices.
// Attr: kv_token_num — total number of KV tokens across the batch.
PD_BUILD_STATIC_OP(get_attn_mask_q)
    .Inputs({"cu_seqlens_q",
             "cu_seqlens_k",
             paddle::Optional("attn_mask_offsets")})
    .Outputs({"attn_mask_q"})
    .Attrs({"kv_token_num: int"})
    .SetKernelFn(PD_KERNEL(get_attn_mask_q))
    .SetInferShapeFn(PD_INFER_SHAPE(GetAttnMaskQInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetAttnMaskQInferDtype));