[Optimization][DeepSeekV3.2] Reduce slot_mapping computation frequency from twice per layer to a single pre-processing step. (#7367)

This commit is contained in:
ShaneGZhu
2026-04-16 19:54:12 +08:00
committed by GitHub
parent d2d633b05c
commit 2d8338f9e4
10 changed files with 73 additions and 146 deletions
+2 -4
View File
@@ -540,12 +540,10 @@ std::vector<paddle::Tensor> count_tokens_per_expert_func(
const paddle::Tensor& topk_ids,
int64_t num_experts,
bool compute_padded_cumsum = false);
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch);
const paddle::Tensor& position_ids);
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
@@ -20,7 +20,6 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度
const int* seq_lens_this_time,
int* position_ids, // 输出的一维 position_ids
int* mask_encoder_batch,
const int bsz) { // 批次大小
// 当前线程索引(每个线程对应一个批次)
int tid = threadIdx.x;
@@ -43,7 +42,6 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
// 写入 encoder 的 position_ids
for (int i = 0; i < encoder_len; i++) {
position_ids[offset + i] = i;
mask_encoder_batch[offset + i] = 1;
}
offset += encoder_len;
@@ -51,17 +49,14 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
if (decoder_len > 0) {
for (int i = 0; i < seq_len_this_time; i++) {
position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身
mask_encoder_batch[offset + i] = 0;
}
}
}
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch) {
const paddle::Tensor& position_ids) {
const int bsz = seq_lens_this_time.shape()[0];
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
@@ -69,17 +64,16 @@ void GetPositionIdsAndMaskEncoderBatch(
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
const_cast<int*>(position_ids.data<int>()),
const_cast<int*>(mask_encoder_batch.data<int>()),
bsz);
}
PD_BUILD_STATIC_OP(get_position_ids_and_mask_encoder_batch)
.Inputs({"seq_lens_encoder",
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"position_ids",
"mask_encoder_batch"})
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"},
{"mask_encoder_batch", "mask_encoder_batch_out"}})
})
.Outputs({"position_ids_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"}})
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
+2 -1
View File
@@ -160,7 +160,8 @@ class ForwardMeta:
# for mla & dsa
position_ids: Optional[paddle.Tensor] = None
mask_encoder_batch: Optional[paddle.Tensor] = None
# for kvcache slot
slot_mapping: Optional[paddle.Tensor] = None
real_bsz: int = 0
@@ -54,33 +54,6 @@ def yarn_get_mscale(scale=1, mscale=1):
return 0.1 * mscale * math.log(scale) + 1.0
def compute_slot_mapping(
block_tables: paddle.Tensor, # [num_reqs, max_blocks_per_req]
positions: paddle.Tensor, # [num_tokens] 每个token的位置
batch_id_per_token: paddle.Tensor, # [num_tokens] 每个token属于哪个请求
block_size: int,
) -> paddle.Tensor:
"""
计算 slot_mapping
公式: slot = block_id * block_size + offset_in_block
"""
# 1. 计算每个 token 对应的 block 索引
block_idx = positions // block_size # [num_tokens]
# 2. 从 block_tables 中查表获取 block_id
# block_tables[batch_id_per_token, block_idx]
block_ids = block_tables[batch_id_per_token, block_idx] # [num_tokens]
# 3. 计算在 block 内的偏移
block_offset = positions % block_size # [num_tokens]
# 4. 计算 slot_mapping
slot_mapping = block_ids * block_size + block_offset
return slot_mapping.cast(paddle.int64)
@dataclass
class DSAAttentionMetadata(AttentionMetadata):
"""
@@ -346,18 +319,11 @@ class DSAAttentionBackend(AttentionBackend):
scale = paddle.abs(compressed_kv).max() / 200.0
slot_mapping = compute_slot_mapping(
forward_meta.block_tables,
forward_meta.position_ids,
forward_meta.batch_id_per_token,
self.block_size,
)
dsk_attn_write_cache(
compressed_kv,
k_pe,
latent_cache,
slot_mapping,
forward_meta.slot_mapping,
scale.cast(paddle.float32),
"fp8_ds_mla",
)
@@ -46,6 +46,9 @@ from fastdeploy.model_executor.layers.linear import (
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import LayerNorm, RMSNorm
from fastdeploy.model_executor.layers.quantization.fp8_utils import (
per_token_group_quant_fp8,
)
from fastdeploy.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding,
)
@@ -59,16 +62,6 @@ from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda() or current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
get_position_ids_and_mask_encoder_batch,
)
from fastdeploy.model_executor.layers.quantization.fp8_utils import (
per_token_group_quant_fp8,
)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
cp_gather_indexer_k_quant_cache,
@@ -471,33 +464,6 @@ class DeepseekV3MLAAttention(nn.Layer):
return output
def compute_slot_mapping(
block_tables: paddle.Tensor, # [num_reqs, max_blocks_per_req]
positions: paddle.Tensor, # [num_tokens] 每个token的位置
batch_id_per_token: paddle.Tensor, # [num_tokens] 每个token属于哪个请求
block_size: int,
) -> paddle.Tensor:
"""
计算 slot_mapping
公式: slot = block_id * block_size + offset_in_block
"""
# 1. 计算每个 token 对应的 block 索引
block_idx = positions // block_size # [num_tokens]
# 2. 从 block_tables 中查表获取 block_id
# block_tables[batch_id_per_token, block_idx]
block_ids = block_tables[batch_id_per_token, block_idx] # [num_tokens]
# 3. 计算在 block 内的偏移
block_offset = positions % block_size # [num_tokens]
# 4. 计算 slot_mapping
slot_mapping = block_ids * block_size + block_offset
return slot_mapping.cast(paddle.int64)
import triton
import triton.language as tl
@@ -686,17 +652,12 @@ class Indexer(nn.Layer):
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.index_n_heads**-0.5
weights = weights.squeeze(-1)
slot_mapping = compute_slot_mapping(
forward_meta.block_tables,
forward_meta.position_ids,
forward_meta.batch_id_per_token,
64,
)
indexer_top_k = paddle.full([q_fp8.shape[0], self.index_topk], -1, dtype="int32")
# indexer write_cache
indexer_k_quant_and_cache(k, self.indexer_cache, slot_mapping, self.quant_block_size, self.scale_fmt)
indexer_k_quant_and_cache(
k, self.indexer_cache, forward_meta.slot_mapping, self.quant_block_size, self.scale_fmt
)
from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
@@ -1172,12 +1133,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
num_embeddings=fd_config.model_config.vocab_size,
prefix="lm_head",
)
self.position_ids_buffer = paddle.empty(
[fd_config.scheduler_config.max_num_batched_tokens], dtype=paddle.int32
)
self.mask_encoder_batch_buffer = paddle.empty(
[fd_config.scheduler_config.max_num_batched_tokens, 1], dtype=paddle.int32
)
@classmethod
def name(cls):
@@ -1274,25 +1229,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
logits[:, self.ori_vocab_size :] = -float("inf")
return logits
def pre_process(self, forward_meta):
""" """
seq_lens_encoder = forward_meta.seq_lens_encoder
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
current_total_tokens = forward_meta.ids_remove_padding.shape[0]
position_ids = self.position_ids_buffer[:current_total_tokens]
mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens]
get_position_ids_and_mask_encoder_batch(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
position_ids,
mask_encoder_batch,
)
return position_ids, mask_encoder_batch
def empty_input_forward(self, forward_meta):
"""
empty_input_forward
@@ -1313,7 +1249,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
forward_meta: ForwardMeta,
):
ids_remove_padding = inputs["ids_remove_padding"]
forward_meta.position_ids, forward_meta.mask_encoder_batch = self.pre_process(forward_meta)
hidden_states = self.model(
ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta,
+38
View File
@@ -45,6 +45,12 @@ from fastdeploy.model_executor.layers.attention.append_attn_backend import (
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
)
from fastdeploy.model_executor.layers.attention.dsa_attention_backend import (
DSAAttentionBackend,
)
from fastdeploy.model_executor.layers.attention.mla_attention_backend import (
MLAAttentionBackend,
)
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
RoutingReplayManager,
)
@@ -79,6 +85,7 @@ else:
speculate_schedule_cache,
set_data_ipc,
unset_data_ipc,
get_position_ids_and_mask_encoder_batch,
)
import zmq
@@ -1267,6 +1274,33 @@ class GPUModelRunner(ModelRunnerBase):
)
return token_num, token_num_event
def _compute_position_ids_and_slot_mapping(self) -> None:
"""Compute position_ids and slot_mapping for KV cache addressing.
This is a general computation based on sequence length info and block tables,
applicable to all models that need per-token KV cache physical slot addresses.
Results are stored in self.forward_meta.
"""
# NOTE(zhushengguang): Only support MLAAttentionBackend and DSAAttentionBackend currently.
if not isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)):
return
current_total_tokens = self.forward_meta.ids_remove_padding.shape[0]
position_ids = self.share_inputs["position_ids_buffer"][:current_total_tokens]
get_position_ids_and_mask_encoder_batch(
self.forward_meta.seq_lens_encoder,
self.forward_meta.seq_lens_decoder,
self.forward_meta.seq_lens_this_time,
position_ids,
)
block_size = self.cache_config.block_size
block_idx = position_ids // block_size # [num_tokens]
assert self.forward_meta.batch_id_per_token.shape == block_idx.shape
block_ids = self.forward_meta.block_tables[self.forward_meta.batch_id_per_token, block_idx] # [num_tokens]
block_offset = position_ids % block_size # [num_tokens]
slot_mapping = self.share_inputs["slot_mapping_buffer"][:current_total_tokens]
paddle.assign((block_ids * block_size + block_offset).cast(paddle.int64), slot_mapping)
self.forward_meta.position_ids = position_ids
self.forward_meta.slot_mapping = slot_mapping
def _process_reorder(self) -> None:
if self.attn_backends and getattr(self.attn_backends[0], "enable_ids_reorder", False):
self.share_inputs.enable_pd_reorder = True
@@ -1860,6 +1894,8 @@ class GPUModelRunner(ModelRunnerBase):
# 2. Padding inputs for cuda graph
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
self.padding_cudagraph_inputs()
# Compute position_ids and slot_mapping
self._compute_position_ids_and_slot_mapping()
model_inputs = {}
model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"]
@@ -2197,6 +2233,8 @@ class GPUModelRunner(ModelRunnerBase):
# Padding inputs for cuda graph
self.padding_cudagraph_inputs()
# Compute position_ids and slot_mapping
self._compute_position_ids_and_slot_mapping()
model_inputs = {}
model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"]
+5
View File
@@ -188,6 +188,11 @@ class InputBatch:
self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32")
self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32")
# Initialize addressing buffers
_max_batched_tokens = self.scheduler_config.max_num_batched_tokens
self.position_ids_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int32)
self.slot_mapping_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int64)
# Declare AttentionBackend buffers
self.decoder_batch_ids = None
self.decoder_tile_ids_per_batch = None
+1
View File
@@ -85,6 +85,7 @@ class MockFDConfig:
name = "default"
splitwise_role = "mixed"
max_num_seqs = 2
max_num_batched_tokens = 2048
parallel_config = ParallelConfig()
scheduler_config = SchedulerConfig()
@@ -33,24 +33,17 @@ class TestGetPositionIdsAndMaskEncoderBatch(unittest.TestCase):
total_len = int(seq_lens_encoder.numpy().sum() + seq_lens_this_time.numpy().sum())
position_ids = paddle.zeros([total_len], dtype="int32")
mask_encoder_batch = paddle.zeros([total_len], dtype="int32")
# Call the custom operator
get_position_ids_and_mask_encoder_batch(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids, mask_encoder_batch
)
get_position_ids_and_mask_encoder_batch(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids)
expected_position_ids = np.array([0, 1, 2, 1, 0, 1, 2, 3], dtype=np.int32)
expected_mask = np.array([1, 1, 1, 0, 1, 1, 0, 0], dtype=np.int32)
# Convert to numpy for comparison
position_ids_np = position_ids.numpy()
mask_encoder_batch_np = mask_encoder_batch.numpy()
# Assert equality
np.testing.assert_array_equal(position_ids_np, expected_position_ids)
np.testing.assert_array_equal(mask_encoder_batch_np, expected_mask)
def test_empty_decoder(self):
# Test case where decoder length is 0
@@ -59,17 +52,12 @@ class TestGetPositionIdsAndMaskEncoderBatch(unittest.TestCase):
seq_lens_this_time = paddle.to_tensor([0], dtype="int32")
position_ids = paddle.zeros([2], dtype="int32")
mask_encoder_batch = paddle.zeros([2], dtype="int32")
get_position_ids_and_mask_encoder_batch(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids, mask_encoder_batch
)
get_position_ids_and_mask_encoder_batch(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids)
expected_position_ids = np.array([0, 1], dtype=np.int32)
expected_mask = np.array([1, 1], dtype=np.int32)
np.testing.assert_array_equal(position_ids.numpy(), expected_position_ids)
np.testing.assert_array_equal(mask_encoder_batch.numpy(), expected_mask)
if __name__ == "__main__":
@@ -59,6 +59,7 @@ def create_mock_config():
scheduler_config = Mock(spec=SchedulerConfig)
scheduler_config.max_num_seqs = 10
scheduler_config.max_num_batched_tokens = 2048
speculative_config = Mock(spec=SpeculativeConfig)
speculative_config.method = None