mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
update attn_mask_q 2 (#7371)
This commit is contained in:
@@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel(
|
||||
const int max_batch_size) {
|
||||
constexpr int VecSize = 4;
|
||||
const uint32_t tid = threadIdx.x, bid = blockIdx.x;
|
||||
int startend_row_vec[4];
|
||||
int startend_row_vec[2];
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
@@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel(
|
||||
const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;
|
||||
|
||||
startend_row_vec[0] = this_batch_q_end;
|
||||
startend_row_vec[1] = cu_seqlens_q[max_batch_size];
|
||||
startend_row_vec[2] = 0;
|
||||
startend_row_vec[3] = this_batch_q_end;
|
||||
// startend_row_vec[1] = cu_seqlens_q[max_batch_size];
|
||||
// startend_row_vec[2] = 0;
|
||||
startend_row_vec[1] = this_batch_q_end;
|
||||
for (int this_batch_q_idx = this_batch_q_start;
|
||||
this_batch_q_idx < this_batch_q_end;
|
||||
++this_batch_q_idx) {
|
||||
@@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel(
|
||||
: this_batch_q_idx - this_batch_q_start + kv_len -
|
||||
(this_batch_q_len);
|
||||
if (cache_k_idx <= append_mask_k_end) {
|
||||
startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx);
|
||||
startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx);
|
||||
// 可提前跳出循环
|
||||
break;
|
||||
}
|
||||
}
|
||||
reinterpret_cast<int4*>(startend_row_indices_ptr +
|
||||
cu_seqlens_k_idx * 4)[0] =
|
||||
reinterpret_cast<int4*>(startend_row_vec)[0];
|
||||
reinterpret_cast<int2*>(startend_row_indices_ptr +
|
||||
cu_seqlens_k_idx * 2)[0] =
|
||||
reinterpret_cast<int2*>(startend_row_vec)[0];
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
@@ -82,7 +82,7 @@ std::vector<paddle::Tensor> get_attn_mask_q(
|
||||
const paddle::optional<paddle::Tensor>& attn_mask_kv,
|
||||
const int kv_token_num) {
|
||||
paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
|
||||
{1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place());
|
||||
{1, 1, kv_token_num, 2}, paddle::DataType::INT32, cu_seqlens_k.place());
|
||||
const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
|
||||
constexpr int block_size = 512;
|
||||
int grid_size = div_up(kv_token_num, block_size);
|
||||
@@ -123,7 +123,7 @@ std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
|
||||
const std::vector<int64_t>& cu_seqlens_k_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
|
||||
const int kv_token_num) {
|
||||
return {{1, 1, kv_token_num, 4}};
|
||||
return {{1, 1, kv_token_num, 2}};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_attn_mask_q)
|
||||
|
||||
Reference in New Issue
Block a user