[Optimization] [OP] [Models] dsk del prefill mask (#7313)

* dsk del prefill mask

* dsk support 1M+ seq_len rope

* update rope tests
This commit is contained in:
AIbin
2026-04-11 19:32:27 +08:00
committed by GitHub
parent 076ab07528
commit ba01d7a823
4 changed files with 83 additions and 49 deletions
@@ -53,9 +53,13 @@ __global__ void apply_rotary_embedding_kernel(
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
const int head_size,
const int num_tokens) { // 新增 num_tokens 参数用于边界检查
// 用2D grid表示token_idx,突破65535限制
const int token_idx = blockIdx.x + blockIdx.y * gridDim.x;
if (token_idx >= num_tokens) return; // 边界保护
int pos = position_ids[token_idx];
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
@@ -99,13 +103,13 @@ void FusedRotaryPositionEncoding(
int64_t query_stride = num_heads * head_size;
int64_t key_stride = num_kv_heads * head_size;
if (num_tokens > 65535) {
PD_THROW(
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
}
dim3 grid(num_tokens);
// 拆成2D grid:每维最大65535,总计支持 65535*65535 >> 1024*1024
constexpr int MAX_GRID_X = 65535;
int grid_x = std::min<int64_t>(num_tokens, MAX_GRID_X);
int grid_y = (num_tokens + MAX_GRID_X - 1) / MAX_GRID_X;
dim3 grid(grid_x, grid_y);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
query.dtype(), "apply_rotary_embedding_kernel", [&] {
if (is_neox) {
@@ -119,7 +123,8 @@ void FusedRotaryPositionEncoding(
key_stride,
num_heads,
num_kv_heads,
head_size);
head_size,
num_tokens);
} else {
apply_rotary_embedding_kernel<data_t, false>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
@@ -131,7 +136,8 @@ void FusedRotaryPositionEncoding(
key_stride,
num_heads,
num_kv_heads,
head_size);
head_size,
num_tokens);
}
});
}