[Optimization][DeepSeekV3.2] Reduce the slot_mapping computation frequency from twice per layer to a single pre-processing step. (#7367)

This commit is contained in:
ShaneGZhu
2026-04-16 19:54:12 +08:00
committed by GitHub
parent d2d633b05c
commit 2d8338f9e4
10 changed files with 73 additions and 146 deletions
+4 -6
View File
@@ -540,12 +540,10 @@ std::vector<paddle::Tensor> count_tokens_per_expert_func(
const paddle::Tensor& topk_ids,
int64_t num_experts,
bool compute_padded_cumsum = false);
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch);
void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids);
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
@@ -20,8 +20,7 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度
const int* seq_lens_this_time,
int* position_ids, // 输出的一维 position_ids
int* mask_encoder_batch,
const int bsz) { // 批次大小
const int bsz) { // 批次大小
// 当前线程索引(每个线程对应一个批次)
int tid = threadIdx.x;
if (tid >= bsz) return;
@@ -43,7 +42,6 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
// 写入 encoder 的 position_ids
for (int i = 0; i < encoder_len; i++) {
position_ids[offset + i] = i;
mask_encoder_batch[offset + i] = 1;
}
offset += encoder_len;
@@ -51,17 +49,14 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
if (decoder_len > 0) {
for (int i = 0; i < seq_len_this_time; i++) {
position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身
mask_encoder_batch[offset + i] = 0;
}
}
}
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch) {
void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids) {
const int bsz = seq_lens_this_time.shape()[0];
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
@@ -69,17 +64,16 @@ void GetPositionIdsAndMaskEncoderBatch(
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
const_cast<int*>(position_ids.data<int>()),
const_cast<int*>(mask_encoder_batch.data<int>()),
bsz);
}
PD_BUILD_STATIC_OP(get_position_ids_and_mask_encoder_batch)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"position_ids",
"mask_encoder_batch"})
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"},
{"mask_encoder_batch", "mask_encoder_batch_out"}})
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"position_ids",
})
.Outputs({"position_ids_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"}})
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));