mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[OP][Models][Optimization] 优化 RoPE CUDA kernel 并更新 DeepSeek V3 配置 (#7359)
* dsk del prefill mask * dsk support 1M+ seq_len rope * update rope tests * Replace max_position_embeddings with max_model_len * 1D grid: gridDim.x has a maximum size of 2^31-1, far exceeding the actual number of tokens.
This commit is contained in:
@@ -53,12 +53,8 @@ __global__ void apply_rotary_embedding_kernel(
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size,
|
||||
const int num_tokens) { // 新增 num_tokens 参数用于边界检查
|
||||
|
||||
// 用2D grid表示token_idx,突破65535限制
|
||||
const int token_idx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (token_idx >= num_tokens) return; // 边界保护
|
||||
const int head_size) {
|
||||
const int token_idx = blockIdx.x;
|
||||
|
||||
int pos = position_ids[token_idx];
|
||||
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
@@ -103,11 +99,8 @@ void FusedRotaryPositionEncoding(
|
||||
int64_t query_stride = num_heads * head_size;
|
||||
int64_t key_stride = num_kv_heads * head_size;
|
||||
|
||||
// 拆成2D grid:每维最大65535,总计支持 65535*65535 >> 1024*1024
|
||||
constexpr int MAX_GRID_X = 65535;
|
||||
int grid_x = std::min<int64_t>(num_tokens, MAX_GRID_X);
|
||||
int grid_y = (num_tokens + MAX_GRID_X - 1) / MAX_GRID_X;
|
||||
dim3 grid(grid_x, grid_y);
|
||||
// 1D grid:gridDim.x 最大 2^31-1,远超实际 token 数
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
|
||||
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
|
||||
@@ -123,8 +116,7 @@ void FusedRotaryPositionEncoding(
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
num_tokens);
|
||||
head_size);
|
||||
} else {
|
||||
apply_rotary_embedding_kernel<data_t, false>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
@@ -136,8 +128,7 @@ void FusedRotaryPositionEncoding(
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
num_tokens);
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user