diff --git a/custom_ops/gpu_ops/get_attn_mask_q.cu b/custom_ops/gpu_ops/get_attn_mask_q.cu index 4ee814178b..a485d04f6b 100644 --- a/custom_ops/gpu_ops/get_attn_mask_q.cu +++ b/custom_ops/gpu_ops/get_attn_mask_q.cu @@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel( const int max_batch_size) { constexpr int VecSize = 4; const uint32_t tid = threadIdx.x, bid = blockIdx.x; - int startend_row_vec[4]; + int startend_row_vec[2]; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); #endif @@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel( const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start; startend_row_vec[0] = this_batch_q_end; - startend_row_vec[1] = cu_seqlens_q[max_batch_size]; - startend_row_vec[2] = 0; - startend_row_vec[3] = this_batch_q_end; + // startend_row_vec[1] = cu_seqlens_q[max_batch_size]; + // startend_row_vec[2] = 0; + startend_row_vec[1] = this_batch_q_end; for (int this_batch_q_idx = this_batch_q_start; this_batch_q_idx < this_batch_q_end; ++this_batch_q_idx) { @@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel( : this_batch_q_idx - this_batch_q_start + kv_len - (this_batch_q_len); if (cache_k_idx <= append_mask_k_end) { - startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx); + startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx); // 可提前跳出循环 break; } } - reinterpret_cast<int4*>(startend_row_indices_ptr + - cu_seqlens_k_idx * 4)[0] = - reinterpret_cast<int4*>(startend_row_vec)[0]; + reinterpret_cast<int2*>(startend_row_indices_ptr + + cu_seqlens_k_idx * 2)[0] = + reinterpret_cast<int2*>(startend_row_vec)[0]; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif @@ -82,7 +82,7 @@ std::vector<paddle::Tensor> get_attn_mask_q( const paddle::optional<paddle::Tensor>& attn_mask_kv, const int kv_token_num) { paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor( - {1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place()); + {1, 1, kv_token_num, 2}, 
paddle::DataType::INT32, cu_seqlens_k.place()); const int max_batch_size = cu_seqlens_k.dims()[0] - 1; constexpr int block_size = 512; int grid_size = div_up(kv_token_num, block_size); @@ -123,7 +123,7 @@ std::vector<std::vector<int64_t>> GetAttnMaskQInferShape( const std::vector<int64_t>& cu_seqlens_k_shape, const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape, const int kv_token_num) { - return {{1, 1, kv_token_num, 4}}; + return {{1, 1, kv_token_num, 2}}; } PD_BUILD_STATIC_OP(get_attn_mask_q)