update attn_mask_q 2 (#7371)

This commit is contained in:
chen
2026-04-13 23:06:04 +08:00
committed by GitHub
parent 0ddb6e461c
commit 26c47c2afc
+10 -10
View File
@@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel(
const int max_batch_size) {
constexpr int VecSize = 4;
const uint32_t tid = threadIdx.x, bid = blockIdx.x;
int startend_row_vec[4];
int startend_row_vec[2];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
@@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel(
const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;
startend_row_vec[0] = this_batch_q_end;
startend_row_vec[1] = cu_seqlens_q[max_batch_size];
startend_row_vec[2] = 0;
startend_row_vec[3] = this_batch_q_end;
// startend_row_vec[1] = cu_seqlens_q[max_batch_size];
// startend_row_vec[2] = 0;
startend_row_vec[1] = this_batch_q_end;
for (int this_batch_q_idx = this_batch_q_start;
this_batch_q_idx < this_batch_q_end;
++this_batch_q_idx) {
@@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel(
: this_batch_q_idx - this_batch_q_start + kv_len -
(this_batch_q_len);
if (cache_k_idx <= append_mask_k_end) {
startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx);
startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx);
// 可提前跳出循环
break;
}
}
reinterpret_cast<int4*>(startend_row_indices_ptr +
cu_seqlens_k_idx * 4)[0] =
reinterpret_cast<int4*>(startend_row_vec)[0];
reinterpret_cast<int2*>(startend_row_indices_ptr +
cu_seqlens_k_idx * 2)[0] =
reinterpret_cast<int2*>(startend_row_vec)[0];
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
@@ -82,7 +82,7 @@ std::vector<paddle::Tensor> get_attn_mask_q(
const paddle::optional<paddle::Tensor>& attn_mask_kv,
const int kv_token_num) {
paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
{1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place());
{1, 1, kv_token_num, 2}, paddle::DataType::INT32, cu_seqlens_k.place());
const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
constexpr int block_size = 512;
int grid_size = div_up(kv_token_num, block_size);
@@ -123,7 +123,7 @@ std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
const std::vector<int64_t>& cu_seqlens_k_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
const int kv_token_num) {
return {{1, 1, kv_token_num, 4}};
return {{1, 1, kv_token_num, 2}};
}
PD_BUILD_STATIC_OP(get_attn_mask_q)