[Optimization] [OP] [Models] dsk del prefill mask (#7313)

* dsk del prefill mask

* dsk support 1M+ seq_len rope

* update rope tests
This commit is contained in:
AIbin
2026-04-11 19:32:27 +08:00
committed by GitHub
parent 076ab07528
commit ba01d7a823
4 changed files with 83 additions and 49 deletions
@@ -53,9 +53,13 @@ __global__ void apply_rotary_embedding_kernel(
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
const int head_size,
const int num_tokens) { // 新增 num_tokens 参数用于边界检查
// 用2D grid表示token_idx,突破65535限制
const int token_idx = blockIdx.x + blockIdx.y * gridDim.x;
if (token_idx >= num_tokens) return; // 边界保护
int pos = position_ids[token_idx];
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
@@ -99,13 +103,13 @@ void FusedRotaryPositionEncoding(
int64_t query_stride = num_heads * head_size;
int64_t key_stride = num_kv_heads * head_size;
if (num_tokens > 65535) {
PD_THROW(
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
}
dim3 grid(num_tokens);
// 拆成2D grid:每维最大65535,总计支持 65535*65535 >> 1024*1024
constexpr int MAX_GRID_X = 65535;
int grid_x = std::min<int64_t>(num_tokens, MAX_GRID_X);
int grid_y = (num_tokens + MAX_GRID_X - 1) / MAX_GRID_X;
dim3 grid(grid_x, grid_y);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
query.dtype(), "apply_rotary_embedding_kernel", [&] {
if (is_neox) {
@@ -119,7 +123,8 @@ void FusedRotaryPositionEncoding(
key_stride,
num_heads,
num_kv_heads,
head_size);
head_size,
num_tokens);
} else {
apply_rotary_embedding_kernel<data_t, false>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
@@ -131,7 +136,8 @@ void FusedRotaryPositionEncoding(
key_stride,
num_heads,
num_kv_heads,
head_size);
head_size,
num_tokens);
}
});
}
@@ -44,13 +44,49 @@ __global__ void FillEncoderDecoderResKernel(T *encoder_res_data,
return;
}
const int load_idx =
((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + land_id * 4;
const int base_idx =
((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim;
*reinterpret_cast<float2 *>(encoder_res_data + load_idx) =
*reinterpret_cast<float2 *>(decoder_res_data + load_idx);
if (head_dim == 128) {
const int load_idx = base_idx + land_id * 4;
*reinterpret_cast<float2 *>(encoder_res_data + load_idx) =
*reinterpret_cast<float2 *>(decoder_res_data + load_idx);
} else if (head_dim == 192) {
const int load_idx = base_idx + land_id * 4;
*reinterpret_cast<float2 *>(encoder_res_data + load_idx) =
*reinterpret_cast<float2 *>(decoder_res_data + load_idx);
if (land_id < 16) {
*reinterpret_cast<float2 *>(encoder_res_data + load_idx + 128) =
*reinterpret_cast<float2 *>(decoder_res_data + load_idx + 128);
}
} else if (head_dim == 256) {
// float4 = 单条LDG.128,性能最优
const int load_idx = base_idx + land_id * 8;
*reinterpret_cast<float4 *>(encoder_res_data + load_idx) =
*reinterpret_cast<float4 *>(decoder_res_data + load_idx);
}
}
// Launches FillEncoderDecoderResKernel<WARPS> on encoder_res's stream with
// `head_dim` threads per block (second launch-config argument), forwarding
// the per-sequence length arrays and the cu_seq_q prefix offsets.
// NOTE(review): relies on `grid_dims`, `encoder_res`, `decoder_res`,
// `seq_lens_*`, `cu_seq_q`, `head_num` and `head_dim` being in scope at the
// expansion site; no cudaGetLastError() follows the launch here — confirm
// errors are checked elsewhere.
#define LAUNCH_KERNEL(T, WARPS)                                        \
  FillEncoderDecoderResKernel<WARPS>                                   \
      <<<grid_dims, head_dim, 0, encoder_res.stream()>>>(              \
          const_cast<T *>(encoder_res.data<T>()),                      \
          const_cast<T *>(decoder_res.data<T>()),                      \
          seq_lens_encoder.data<int>(),                                \
          seq_lens_decoder.data<int>(),                                \
          seq_lens_this_time.data<int>(),                              \
          cu_seq_q.data<int>(),                                        \
          head_num,                                                    \
          head_dim)
// Selects the WARPS template argument from head_dim: 128 -> 4, 192 -> 6,
// 256 -> 8 (i.e. WARPS == head_dim / 32, consistent with the host-side
// `warps = head_dim / 32` used to size the grid). Unsupported head_dim
// values expand to nothing; the caller rejects them up front via PD_THROW.
#define LAUNCH_KERNEL_BY_HEAD_DIM(T) \
  if (head_dim == 128)               \
    LAUNCH_KERNEL(T, 4);             \
  else if (head_dim == 192)          \
    LAUNCH_KERNEL(T, 6);             \
  else if (head_dim == 256)          \
    LAUNCH_KERNEL(T, 8)
void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
@@ -60,41 +96,20 @@ void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res,
const int head_num,
const int head_dim,
const int max_token) {
if (head_dim != 128) {
PD_THROW("Only supported head_dim = 128");
if (head_dim != 128 && head_dim != 192 && head_dim != 256) {
PD_THROW("Only supported head_dim = 128, 192 or 256");
}
const int batch_size = seq_lens_encoder.shape()[0];
constexpr int warps = 4;
const int warps = head_dim / 32;
const int tokens_block = (max_token + warps - 1) / warps;
dim3 grid_dims;
grid_dims.x = batch_size;
grid_dims.y = head_num;
grid_dims.z = tokens_block;
dim3 grid_dims(batch_size, head_num, tokens_block);
if (encoder_res.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
FillEncoderDecoderResKernel<warps>
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
const_cast<T *>(encoder_res.data<T>()),
const_cast<T *>(decoder_res.data<T>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seq_q.data<int>(),
head_num,
head_dim);
LAUNCH_KERNEL_BY_HEAD_DIM(T);
} else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
FillEncoderDecoderResKernel<warps>
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
const_cast<T *>(encoder_res.data<T>()),
const_cast<T *>(decoder_res.data<T>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seq_q.data<int>(),
head_num,
head_dim);
LAUNCH_KERNEL_BY_HEAD_DIM(T);
}
}
@@ -72,6 +72,7 @@ if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
cp_gather_indexer_k_quant_cache,
indexer_k_quant_and_cache,
merge_prefill_decode_output,
radix_topk_ragged_transform,
)
@@ -398,7 +399,6 @@ class DeepseekV3MLAAttention(nn.Layer):
fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * forward_meta.mask_encoder_batch.cast(fmha_out_prefill.dtype)
fmha_out = fmha_out_prefill
if need_do_decode: # max_dec_len_this_time
@@ -433,7 +433,17 @@ class DeepseekV3MLAAttention(nn.Layer):
)
if need_do_prefill:
fmha_out += fmha_out_decode
merge_prefill_decode_output(
fmha_out,
fmha_out_decode,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
self.num_attention_heads_tp,
self.v_head_dim,
1,
)
else:
fmha_out = fmha_out_decode
@@ -116,9 +116,10 @@ class TestFusedRotaryPositionEncoding(unittest.TestCase):
self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True)
def test_large_num_tokens(self):
self._check_correctness(num_tokens=10, num_heads=2, num_kv_heads=2, head_size=4, rot_dim=4, is_neox=False)
def test_exceed_max_tokens(self):
"""
测试算子支持大量 tokens(超过 65535
算子使用 2D grid,理论上可支持 65535*65535 个 tokens
"""
num_tokens, num_heads, head_size = 65537, 1, 4
num_kv_heads, rot_dim = 1, 4
query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
@@ -126,8 +127,10 @@ class TestFusedRotaryPositionEncoding(unittest.TestCase):
position_ids_np = np.arange(num_tokens, dtype="int32")
cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)
with self.assertRaises(Exception):
self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False)
# 不应该抛出异常,算子应该能处理大量 tokens
query_out, key_out = self._run_op(
query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False
)
if __name__ == "__main__":