mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[OP][Optimization] Remove ENABLE_PREFILL template parameter in multi_query_append_attention_warp1_4_kernel (#7201)
This commit is contained in:
@@ -430,8 +430,7 @@ template <typename T,
           uint32_t num_frags_x,
           uint32_t num_frags_z,
           uint32_t num_frags_y,
-          typename OutT = T,
-          bool ENABLE_PREFILL = true>
+          typename OutT = T>
 __global__ void multi_query_append_attention_warp1_4_kernel(
     T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
     T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
@@ -525,17 +524,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
   if (!partition_kv || num_chunks_this_seq <= 1) {
     o_base_ptr_int8 = out + o_offset;
   } else {
-    if (ENABLE_PREFILL) {
-      o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride +
-                     chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-                     tid % 8 * num_elems_per_128b<T>();
-    } else {
-      o_base_ptr_T =
-          tmp_workspace +
-          batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
-          chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
-          tid % 8 * num_elems_per_128b<T>();
-    }
+    o_base_ptr_T =
+        tmp_workspace +
+        batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
+        chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
+        tid % 8 * num_elems_per_128b<T>();
   }
   const int *mask_offset_this_seq =
       mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
@@ -799,18 +792,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
         const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
 
         if (qo_idx - q_start_seq_id < q_len) {
-          uint32_t offset;
-          if (ENABLE_PREFILL) {
-            offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
-                     qo_head_idx;
-          } else {
-            offset = ((batch_id * speculate_max_draft_token_num +
-                       qo_idx_now / GROUP_SIZE) *
-                          num_chunks +
-                      chunk_idx) *
-                         q_num_heads +
-                     qo_head_idx;
-          }
+          const uint32_t offset = ((batch_id * speculate_max_draft_token_num +
+                                    qo_idx_now / GROUP_SIZE) *
+                                       num_chunks +
+                                   chunk_idx) *
+                                      q_num_heads +
+                                  qo_head_idx;
           tmp_m[offset] = m_frag[fx][j];
           tmp_d[offset] = d_frag[fx][j];
         }
@@ -1123,8 +1110,7 @@ void MultiQueryAppendAttention(
                                                     num_frags_x,
                                                     num_frags_z,
                                                     num_frags_y,
-                                                    OUT_NV_TYPE,
-                                                    ENABLE_PREFILL>;
+                                                    OUT_NV_TYPE>;
     if (smem_size >= 48 * 1024) {
       cudaFuncSetAttribute(split_kv_kernel,
                            cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -1169,8 +1155,7 @@ void MultiQueryAppendAttention(
                                                     num_frags_x,
                                                     num_frags_z,
                                                     num_frags_y,
-                                                    OUT_NV_TYPE,
-                                                    ENABLE_PREFILL>;
+                                                    OUT_NV_TYPE>;
       if (smem_size >= 48 * 1024) {
         cudaFuncSetAttribute(nosplit_kv_kernel,
                              cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -1222,43 +1207,18 @@ void MultiQueryAppendAttention(
           sink_size);
     } else {
       phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
-      if (is_decoder) {
-        tmp_workspace = allocator->Allocate(
-            phi::SizeOf(qkv.dtype()) *
-            static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM));
-        tmp_m = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(bsz * num_chunks * num_heads));
-        tmp_d = allocator->Allocate(
-            phi::SizeOf(paddle::DataType::FLOAT32) *
-            static_cast<size_t>(bsz * num_chunks * num_heads));
-      } else {
-        if (ENABLE_PREFILL) {
-          tmp_workspace =
-              allocator->Allocate(phi::SizeOf(qkv.dtype()) *
-                                  static_cast<size_t>(token_num * num_chunks *
-                                                      num_heads * HEAD_DIM));
-          tmp_m = allocator->Allocate(
-              phi::SizeOf(paddle::DataType::FLOAT32) *
-              static_cast<size_t>(token_num * num_chunks * num_heads));
-          tmp_d = allocator->Allocate(
-              phi::SizeOf(paddle::DataType::FLOAT32) *
-              static_cast<size_t>(token_num * num_chunks * num_heads));
-        } else {
-          tmp_workspace = allocator->Allocate(
-              phi::SizeOf(qkv.dtype()) *
-              static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                  num_chunks * num_heads * HEAD_DIM));
-          tmp_m = allocator->Allocate(
-              phi::SizeOf(paddle::DataType::FLOAT32) *
-              static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                  num_chunks * num_heads));
-          tmp_d = allocator->Allocate(
-              phi::SizeOf(paddle::DataType::FLOAT32) *
-              static_cast<size_t>(speculate_max_draft_token_num * bsz *
-                                  num_chunks * num_heads));
-        }
-      }
+      tmp_workspace = allocator->Allocate(
+          phi::SizeOf(qkv.dtype()) *
+          static_cast<size_t>(speculate_max_draft_token_num * bsz * num_chunks *
+                              num_heads * HEAD_DIM));
+      tmp_m = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(speculate_max_draft_token_num * bsz * num_chunks *
+                              num_heads));
+      tmp_d = allocator->Allocate(
+          phi::SizeOf(paddle::DataType::FLOAT32) *
+          static_cast<size_t>(speculate_max_draft_token_num * bsz * num_chunks *
+                              num_heads));
       launchWithPdlWhenEnabled(
           split_kv_kernel,
           grids,
Reference in New Issue
Block a user