mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] [OP] [Models] dsk del prefill mask (#7313)
* dsk del prefill mask * dsk support 1M+ seq_len rope * update rope tests
This commit is contained in:
@@ -53,9 +53,13 @@ __global__ void apply_rotary_embedding_kernel(
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
const int head_size,
|
||||
const int num_tokens) { // 新增 num_tokens 参数用于边界检查
|
||||
|
||||
// 用2D grid表示token_idx,突破65535限制
|
||||
const int token_idx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (token_idx >= num_tokens) return; // 边界保护
|
||||
|
||||
int pos = position_ids[token_idx];
|
||||
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
@@ -99,13 +103,13 @@ void FusedRotaryPositionEncoding(
|
||||
int64_t query_stride = num_heads * head_size;
|
||||
int64_t key_stride = num_kv_heads * head_size;
|
||||
|
||||
if (num_tokens > 65535) {
|
||||
PD_THROW(
|
||||
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
|
||||
}
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
// 拆成2D grid:每维最大65535,总计支持 65535*65535 >> 1024*1024
|
||||
constexpr int MAX_GRID_X = 65535;
|
||||
int grid_x = std::min<int64_t>(num_tokens, MAX_GRID_X);
|
||||
int grid_y = (num_tokens + MAX_GRID_X - 1) / MAX_GRID_X;
|
||||
dim3 grid(grid_x, grid_y);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
|
||||
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
|
||||
query.dtype(), "apply_rotary_embedding_kernel", [&] {
|
||||
if (is_neox) {
|
||||
@@ -119,7 +123,8 @@ void FusedRotaryPositionEncoding(
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
head_size,
|
||||
num_tokens);
|
||||
} else {
|
||||
apply_rotary_embedding_kernel<data_t, false>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
@@ -131,7 +136,8 @@ void FusedRotaryPositionEncoding(
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
head_size,
|
||||
num_tokens);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -44,13 +44,49 @@ __global__ void FillEncoderDecoderResKernel(T *encoder_res_data,
|
||||
return;
|
||||
}
|
||||
|
||||
const int load_idx =
|
||||
((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + land_id * 4;
|
||||
const int base_idx =
|
||||
((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim;
|
||||
|
||||
*reinterpret_cast<float2 *>(encoder_res_data + load_idx) =
|
||||
*reinterpret_cast<float2 *>(decoder_res_data + load_idx);
|
||||
if (head_dim == 128) {
|
||||
const int load_idx = base_idx + land_id * 4;
|
||||
*reinterpret_cast<float2 *>(encoder_res_data + load_idx) =
|
||||
*reinterpret_cast<float2 *>(decoder_res_data + load_idx);
|
||||
} else if (head_dim == 192) {
|
||||
const int load_idx = base_idx + land_id * 4;
|
||||
*reinterpret_cast<float2 *>(encoder_res_data + load_idx) =
|
||||
*reinterpret_cast<float2 *>(decoder_res_data + load_idx);
|
||||
if (land_id < 16) {
|
||||
*reinterpret_cast<float2 *>(encoder_res_data + load_idx + 128) =
|
||||
*reinterpret_cast<float2 *>(decoder_res_data + load_idx + 128);
|
||||
}
|
||||
} else if (head_dim == 256) {
|
||||
// float4 = 单条LDG.128,性能最优
|
||||
const int load_idx = base_idx + land_id * 8;
|
||||
*reinterpret_cast<float4 *>(encoder_res_data + load_idx) =
|
||||
*reinterpret_cast<float4 *>(decoder_res_data + load_idx);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_KERNEL(T, WARPS) \
|
||||
FillEncoderDecoderResKernel<WARPS> \
|
||||
<<<grid_dims, head_dim, 0, encoder_res.stream()>>>( \
|
||||
const_cast<T *>(encoder_res.data<T>()), \
|
||||
const_cast<T *>(decoder_res.data<T>()), \
|
||||
seq_lens_encoder.data<int>(), \
|
||||
seq_lens_decoder.data<int>(), \
|
||||
seq_lens_this_time.data<int>(), \
|
||||
cu_seq_q.data<int>(), \
|
||||
head_num, \
|
||||
head_dim)
|
||||
|
||||
#define LAUNCH_KERNEL_BY_HEAD_DIM(T) \
|
||||
if (head_dim == 128) \
|
||||
LAUNCH_KERNEL(T, 4); \
|
||||
else if (head_dim == 192) \
|
||||
LAUNCH_KERNEL(T, 6); \
|
||||
else if (head_dim == 256) \
|
||||
LAUNCH_KERNEL(T, 8)
|
||||
|
||||
void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res,
|
||||
const paddle::Tensor &decoder_res,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -60,41 +96,20 @@ void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res,
|
||||
const int head_num,
|
||||
const int head_dim,
|
||||
const int max_token) {
|
||||
if (head_dim != 128) {
|
||||
PD_THROW("Only supported head_dim = 128");
|
||||
if (head_dim != 128 && head_dim != 192 && head_dim != 256) {
|
||||
PD_THROW("Only supported head_dim = 128, 192 or 256");
|
||||
}
|
||||
const int batch_size = seq_lens_encoder.shape()[0];
|
||||
constexpr int warps = 4;
|
||||
const int warps = head_dim / 32;
|
||||
const int tokens_block = (max_token + warps - 1) / warps;
|
||||
dim3 grid_dims;
|
||||
grid_dims.x = batch_size;
|
||||
grid_dims.y = head_num;
|
||||
grid_dims.z = tokens_block;
|
||||
dim3 grid_dims(batch_size, head_num, tokens_block);
|
||||
|
||||
if (encoder_res.dtype() == paddle::DataType::FLOAT16) {
|
||||
using T = phi::dtype::float16;
|
||||
FillEncoderDecoderResKernel<warps>
|
||||
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
|
||||
const_cast<T *>(encoder_res.data<T>()),
|
||||
const_cast<T *>(decoder_res.data<T>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
head_num,
|
||||
head_dim);
|
||||
LAUNCH_KERNEL_BY_HEAD_DIM(T);
|
||||
} else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) {
|
||||
using T = phi::dtype::bfloat16;
|
||||
FillEncoderDecoderResKernel<warps>
|
||||
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
|
||||
const_cast<T *>(encoder_res.data<T>()),
|
||||
const_cast<T *>(decoder_res.data<T>()),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
cu_seq_q.data<int>(),
|
||||
head_num,
|
||||
head_dim);
|
||||
LAUNCH_KERNEL_BY_HEAD_DIM(T);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user