mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization][DeepSeekV3.2]Reducing slot_mapping compute frequency from twice per layer to a single pre-processing step. (#7367)
This commit is contained in:
@@ -540,12 +540,10 @@ std::vector<paddle::Tensor> count_tokens_per_expert_func(
|
|||||||
const paddle::Tensor& topk_ids,
|
const paddle::Tensor& topk_ids,
|
||||||
int64_t num_experts,
|
int64_t num_experts,
|
||||||
bool compute_padded_cumsum = false);
|
bool compute_padded_cumsum = false);
|
||||||
void GetPositionIdsAndMaskEncoderBatch(
|
void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
const paddle::Tensor& seq_lens_decoder,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor& position_ids);
|
||||||
const paddle::Tensor& position_ids,
|
|
||||||
const paddle::Tensor& mask_encoder_batch);
|
|
||||||
|
|
||||||
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||||
const paddle::Tensor& kv_nope,
|
const paddle::Tensor& kv_nope,
|
||||||
|
|||||||
@@ -20,8 +20,7 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
|
|||||||
const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度
|
const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度
|
||||||
const int* seq_lens_this_time,
|
const int* seq_lens_this_time,
|
||||||
int* position_ids, // 输出的一维 position_ids
|
int* position_ids, // 输出的一维 position_ids
|
||||||
int* mask_encoder_batch,
|
const int bsz) { // 批次大小
|
||||||
const int bsz) { // 批次大小
|
|
||||||
// 当前线程索引(每个线程对应一个批次)
|
// 当前线程索引(每个线程对应一个批次)
|
||||||
int tid = threadIdx.x;
|
int tid = threadIdx.x;
|
||||||
if (tid >= bsz) return;
|
if (tid >= bsz) return;
|
||||||
@@ -43,7 +42,6 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
|
|||||||
// 写入 encoder 的 position_ids
|
// 写入 encoder 的 position_ids
|
||||||
for (int i = 0; i < encoder_len; i++) {
|
for (int i = 0; i < encoder_len; i++) {
|
||||||
position_ids[offset + i] = i;
|
position_ids[offset + i] = i;
|
||||||
mask_encoder_batch[offset + i] = 1;
|
|
||||||
}
|
}
|
||||||
offset += encoder_len;
|
offset += encoder_len;
|
||||||
|
|
||||||
@@ -51,17 +49,14 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel(
|
|||||||
if (decoder_len > 0) {
|
if (decoder_len > 0) {
|
||||||
for (int i = 0; i < seq_len_this_time; i++) {
|
for (int i = 0; i < seq_len_this_time; i++) {
|
||||||
position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身
|
position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身
|
||||||
mask_encoder_batch[offset + i] = 0;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetPositionIdsAndMaskEncoderBatch(
|
void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
const paddle::Tensor& seq_lens_decoder,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor& position_ids) {
|
||||||
const paddle::Tensor& position_ids,
|
|
||||||
const paddle::Tensor& mask_encoder_batch) {
|
|
||||||
const int bsz = seq_lens_this_time.shape()[0];
|
const int bsz = seq_lens_this_time.shape()[0];
|
||||||
|
|
||||||
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
|
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
|
||||||
@@ -69,17 +64,16 @@ void GetPositionIdsAndMaskEncoderBatch(
|
|||||||
seq_lens_decoder.data<int>(),
|
seq_lens_decoder.data<int>(),
|
||||||
seq_lens_this_time.data<int>(),
|
seq_lens_this_time.data<int>(),
|
||||||
const_cast<int*>(position_ids.data<int>()),
|
const_cast<int*>(position_ids.data<int>()),
|
||||||
const_cast<int*>(mask_encoder_batch.data<int>()),
|
|
||||||
bsz);
|
bsz);
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(get_position_ids_and_mask_encoder_batch)
|
PD_BUILD_STATIC_OP(get_position_ids_and_mask_encoder_batch)
|
||||||
.Inputs({"seq_lens_encoder",
|
.Inputs({
|
||||||
"seq_lens_decoder",
|
"seq_lens_encoder",
|
||||||
"seq_lens_this_time",
|
"seq_lens_decoder",
|
||||||
"position_ids",
|
"seq_lens_this_time",
|
||||||
"mask_encoder_batch"})
|
"position_ids",
|
||||||
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
|
})
|
||||||
.SetInplaceMap({{"position_ids", "position_ids_out"},
|
.Outputs({"position_ids_out"})
|
||||||
{"mask_encoder_batch", "mask_encoder_batch_out"}})
|
.SetInplaceMap({{"position_ids", "position_ids_out"}})
|
||||||
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
|
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
|
||||||
|
|||||||
@@ -160,7 +160,8 @@ class ForwardMeta:
|
|||||||
|
|
||||||
# for mla & dsa
|
# for mla & dsa
|
||||||
position_ids: Optional[paddle.Tensor] = None
|
position_ids: Optional[paddle.Tensor] = None
|
||||||
mask_encoder_batch: Optional[paddle.Tensor] = None
|
# for kvcache slot
|
||||||
|
slot_mapping: Optional[paddle.Tensor] = None
|
||||||
|
|
||||||
real_bsz: int = 0
|
real_bsz: int = 0
|
||||||
|
|
||||||
|
|||||||
@@ -54,33 +54,6 @@ def yarn_get_mscale(scale=1, mscale=1):
|
|||||||
return 0.1 * mscale * math.log(scale) + 1.0
|
return 0.1 * mscale * math.log(scale) + 1.0
|
||||||
|
|
||||||
|
|
||||||
def compute_slot_mapping(
|
|
||||||
block_tables: paddle.Tensor, # [num_reqs, max_blocks_per_req]
|
|
||||||
positions: paddle.Tensor, # [num_tokens] 每个token的位置
|
|
||||||
batch_id_per_token: paddle.Tensor, # [num_tokens] 每个token属于哪个请求
|
|
||||||
block_size: int,
|
|
||||||
) -> paddle.Tensor:
|
|
||||||
"""
|
|
||||||
计算 slot_mapping
|
|
||||||
|
|
||||||
公式: slot = block_id * block_size + offset_in_block
|
|
||||||
"""
|
|
||||||
# 1. 计算每个 token 对应的 block 索引
|
|
||||||
block_idx = positions // block_size # [num_tokens]
|
|
||||||
|
|
||||||
# 2. 从 block_tables 中查表获取 block_id
|
|
||||||
# block_tables[batch_id_per_token, block_idx]
|
|
||||||
block_ids = block_tables[batch_id_per_token, block_idx] # [num_tokens]
|
|
||||||
|
|
||||||
# 3. 计算在 block 内的偏移
|
|
||||||
block_offset = positions % block_size # [num_tokens]
|
|
||||||
|
|
||||||
# 4. 计算 slot_mapping
|
|
||||||
slot_mapping = block_ids * block_size + block_offset
|
|
||||||
|
|
||||||
return slot_mapping.cast(paddle.int64)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DSAAttentionMetadata(AttentionMetadata):
|
class DSAAttentionMetadata(AttentionMetadata):
|
||||||
"""
|
"""
|
||||||
@@ -346,18 +319,11 @@ class DSAAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
scale = paddle.abs(compressed_kv).max() / 200.0
|
scale = paddle.abs(compressed_kv).max() / 200.0
|
||||||
|
|
||||||
slot_mapping = compute_slot_mapping(
|
|
||||||
forward_meta.block_tables,
|
|
||||||
forward_meta.position_ids,
|
|
||||||
forward_meta.batch_id_per_token,
|
|
||||||
self.block_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
dsk_attn_write_cache(
|
dsk_attn_write_cache(
|
||||||
compressed_kv,
|
compressed_kv,
|
||||||
k_pe,
|
k_pe,
|
||||||
latent_cache,
|
latent_cache,
|
||||||
slot_mapping,
|
forward_meta.slot_mapping,
|
||||||
scale.cast(paddle.float32),
|
scale.cast(paddle.float32),
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -46,6 +46,9 @@ from fastdeploy.model_executor.layers.linear import (
|
|||||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||||
from fastdeploy.model_executor.layers.normalization import LayerNorm, RMSNorm
|
from fastdeploy.model_executor.layers.normalization import LayerNorm, RMSNorm
|
||||||
|
from fastdeploy.model_executor.layers.quantization.fp8_utils import (
|
||||||
|
per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
from fastdeploy.model_executor.layers.rotary_embedding import (
|
from fastdeploy.model_executor.layers.rotary_embedding import (
|
||||||
DeepseekScalingRotaryEmbedding,
|
DeepseekScalingRotaryEmbedding,
|
||||||
)
|
)
|
||||||
@@ -59,16 +62,6 @@ from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
|
|||||||
)
|
)
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
if current_platform.is_cuda() or current_platform.is_maca():
|
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
|
||||||
get_position_ids_and_mask_encoder_batch,
|
|
||||||
)
|
|
||||||
|
|
||||||
from fastdeploy.model_executor.layers.quantization.fp8_utils import (
|
|
||||||
per_token_group_quant_fp8,
|
|
||||||
)
|
|
||||||
from fastdeploy.platforms import current_platform
|
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
cp_gather_indexer_k_quant_cache,
|
cp_gather_indexer_k_quant_cache,
|
||||||
@@ -471,33 +464,6 @@ class DeepseekV3MLAAttention(nn.Layer):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def compute_slot_mapping(
|
|
||||||
block_tables: paddle.Tensor, # [num_reqs, max_blocks_per_req]
|
|
||||||
positions: paddle.Tensor, # [num_tokens] 每个token的位置
|
|
||||||
batch_id_per_token: paddle.Tensor, # [num_tokens] 每个token属于哪个请求
|
|
||||||
block_size: int,
|
|
||||||
) -> paddle.Tensor:
|
|
||||||
"""
|
|
||||||
计算 slot_mapping
|
|
||||||
|
|
||||||
公式: slot = block_id * block_size + offset_in_block
|
|
||||||
"""
|
|
||||||
# 1. 计算每个 token 对应的 block 索引
|
|
||||||
block_idx = positions // block_size # [num_tokens]
|
|
||||||
|
|
||||||
# 2. 从 block_tables 中查表获取 block_id
|
|
||||||
# block_tables[batch_id_per_token, block_idx]
|
|
||||||
block_ids = block_tables[batch_id_per_token, block_idx] # [num_tokens]
|
|
||||||
|
|
||||||
# 3. 计算在 block 内的偏移
|
|
||||||
block_offset = positions % block_size # [num_tokens]
|
|
||||||
|
|
||||||
# 4. 计算 slot_mapping
|
|
||||||
slot_mapping = block_ids * block_size + block_offset
|
|
||||||
|
|
||||||
return slot_mapping.cast(paddle.int64)
|
|
||||||
|
|
||||||
|
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
@@ -686,17 +652,12 @@ class Indexer(nn.Layer):
|
|||||||
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.index_n_heads**-0.5
|
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.index_n_heads**-0.5
|
||||||
weights = weights.squeeze(-1)
|
weights = weights.squeeze(-1)
|
||||||
|
|
||||||
slot_mapping = compute_slot_mapping(
|
|
||||||
forward_meta.block_tables,
|
|
||||||
forward_meta.position_ids,
|
|
||||||
forward_meta.batch_id_per_token,
|
|
||||||
64,
|
|
||||||
)
|
|
||||||
|
|
||||||
indexer_top_k = paddle.full([q_fp8.shape[0], self.index_topk], -1, dtype="int32")
|
indexer_top_k = paddle.full([q_fp8.shape[0], self.index_topk], -1, dtype="int32")
|
||||||
|
|
||||||
# indexer write_cache
|
# indexer write_cache
|
||||||
indexer_k_quant_and_cache(k, self.indexer_cache, slot_mapping, self.quant_block_size, self.scale_fmt)
|
indexer_k_quant_and_cache(
|
||||||
|
k, self.indexer_cache, forward_meta.slot_mapping, self.quant_block_size, self.scale_fmt
|
||||||
|
)
|
||||||
|
|
||||||
from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
|
from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm
|
||||||
|
|
||||||
@@ -1172,12 +1133,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
|||||||
num_embeddings=fd_config.model_config.vocab_size,
|
num_embeddings=fd_config.model_config.vocab_size,
|
||||||
prefix="lm_head",
|
prefix="lm_head",
|
||||||
)
|
)
|
||||||
self.position_ids_buffer = paddle.empty(
|
|
||||||
[fd_config.scheduler_config.max_num_batched_tokens], dtype=paddle.int32
|
|
||||||
)
|
|
||||||
self.mask_encoder_batch_buffer = paddle.empty(
|
|
||||||
[fd_config.scheduler_config.max_num_batched_tokens, 1], dtype=paddle.int32
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name(cls):
|
def name(cls):
|
||||||
@@ -1274,25 +1229,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
|||||||
logits[:, self.ori_vocab_size :] = -float("inf")
|
logits[:, self.ori_vocab_size :] = -float("inf")
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def pre_process(self, forward_meta):
|
|
||||||
""" """
|
|
||||||
seq_lens_encoder = forward_meta.seq_lens_encoder
|
|
||||||
seq_lens_decoder = forward_meta.seq_lens_decoder
|
|
||||||
seq_lens_this_time = forward_meta.seq_lens_this_time
|
|
||||||
|
|
||||||
current_total_tokens = forward_meta.ids_remove_padding.shape[0]
|
|
||||||
position_ids = self.position_ids_buffer[:current_total_tokens]
|
|
||||||
mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens]
|
|
||||||
|
|
||||||
get_position_ids_and_mask_encoder_batch(
|
|
||||||
seq_lens_encoder,
|
|
||||||
seq_lens_decoder,
|
|
||||||
seq_lens_this_time,
|
|
||||||
position_ids,
|
|
||||||
mask_encoder_batch,
|
|
||||||
)
|
|
||||||
return position_ids, mask_encoder_batch
|
|
||||||
|
|
||||||
def empty_input_forward(self, forward_meta):
|
def empty_input_forward(self, forward_meta):
|
||||||
"""
|
"""
|
||||||
empty_input_forward
|
empty_input_forward
|
||||||
@@ -1313,7 +1249,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
|||||||
forward_meta: ForwardMeta,
|
forward_meta: ForwardMeta,
|
||||||
):
|
):
|
||||||
ids_remove_padding = inputs["ids_remove_padding"]
|
ids_remove_padding = inputs["ids_remove_padding"]
|
||||||
forward_meta.position_ids, forward_meta.mask_encoder_batch = self.pre_process(forward_meta)
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
ids_remove_padding=ids_remove_padding,
|
ids_remove_padding=ids_remove_padding,
|
||||||
forward_meta=forward_meta,
|
forward_meta=forward_meta,
|
||||||
|
|||||||
@@ -45,6 +45,12 @@ from fastdeploy.model_executor.layers.attention.append_attn_backend import (
|
|||||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
)
|
)
|
||||||
|
from fastdeploy.model_executor.layers.attention.dsa_attention_backend import (
|
||||||
|
DSAAttentionBackend,
|
||||||
|
)
|
||||||
|
from fastdeploy.model_executor.layers.attention.mla_attention_backend import (
|
||||||
|
MLAAttentionBackend,
|
||||||
|
)
|
||||||
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
|
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
|
||||||
RoutingReplayManager,
|
RoutingReplayManager,
|
||||||
)
|
)
|
||||||
@@ -79,6 +85,7 @@ else:
|
|||||||
speculate_schedule_cache,
|
speculate_schedule_cache,
|
||||||
set_data_ipc,
|
set_data_ipc,
|
||||||
unset_data_ipc,
|
unset_data_ipc,
|
||||||
|
get_position_ids_and_mask_encoder_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
@@ -1267,6 +1274,33 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
)
|
)
|
||||||
return token_num, token_num_event
|
return token_num, token_num_event
|
||||||
|
|
||||||
|
def _compute_position_ids_and_slot_mapping(self) -> None:
|
||||||
|
"""Compute position_ids and slot_mapping for KV cache addressing.
|
||||||
|
This is a general computation based on sequence length info and block tables,
|
||||||
|
applicable to all models that need per-token KV cache physical slot addresses.
|
||||||
|
Results are stored in self.forward_meta.
|
||||||
|
"""
|
||||||
|
# NOTE(zhushengguang): Only support MLAAttentionBackend and DSAAttentionBackend currently.
|
||||||
|
if not isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)):
|
||||||
|
return
|
||||||
|
current_total_tokens = self.forward_meta.ids_remove_padding.shape[0]
|
||||||
|
position_ids = self.share_inputs["position_ids_buffer"][:current_total_tokens]
|
||||||
|
get_position_ids_and_mask_encoder_batch(
|
||||||
|
self.forward_meta.seq_lens_encoder,
|
||||||
|
self.forward_meta.seq_lens_decoder,
|
||||||
|
self.forward_meta.seq_lens_this_time,
|
||||||
|
position_ids,
|
||||||
|
)
|
||||||
|
block_size = self.cache_config.block_size
|
||||||
|
block_idx = position_ids // block_size # [num_tokens]
|
||||||
|
assert self.forward_meta.batch_id_per_token.shape == block_idx.shape
|
||||||
|
block_ids = self.forward_meta.block_tables[self.forward_meta.batch_id_per_token, block_idx] # [num_tokens]
|
||||||
|
block_offset = position_ids % block_size # [num_tokens]
|
||||||
|
slot_mapping = self.share_inputs["slot_mapping_buffer"][:current_total_tokens]
|
||||||
|
paddle.assign((block_ids * block_size + block_offset).cast(paddle.int64), slot_mapping)
|
||||||
|
self.forward_meta.position_ids = position_ids
|
||||||
|
self.forward_meta.slot_mapping = slot_mapping
|
||||||
|
|
||||||
def _process_reorder(self) -> None:
|
def _process_reorder(self) -> None:
|
||||||
if self.attn_backends and getattr(self.attn_backends[0], "enable_ids_reorder", False):
|
if self.attn_backends and getattr(self.attn_backends[0], "enable_ids_reorder", False):
|
||||||
self.share_inputs.enable_pd_reorder = True
|
self.share_inputs.enable_pd_reorder = True
|
||||||
@@ -1860,6 +1894,8 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
# 2. Padding inputs for cuda graph
|
# 2. Padding inputs for cuda graph
|
||||||
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
|
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
|
||||||
self.padding_cudagraph_inputs()
|
self.padding_cudagraph_inputs()
|
||||||
|
# Compute position_ids and slot_mapping
|
||||||
|
self._compute_position_ids_and_slot_mapping()
|
||||||
|
|
||||||
model_inputs = {}
|
model_inputs = {}
|
||||||
model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"]
|
model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"]
|
||||||
@@ -2197,6 +2233,8 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
# Padding inputs for cuda graph
|
# Padding inputs for cuda graph
|
||||||
self.padding_cudagraph_inputs()
|
self.padding_cudagraph_inputs()
|
||||||
|
# Compute position_ids and slot_mapping
|
||||||
|
self._compute_position_ids_and_slot_mapping()
|
||||||
|
|
||||||
model_inputs = {}
|
model_inputs = {}
|
||||||
model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"]
|
model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"]
|
||||||
|
|||||||
@@ -188,6 +188,11 @@ class InputBatch:
|
|||||||
self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32")
|
self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32")
|
||||||
self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32")
|
self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32")
|
||||||
|
|
||||||
|
# Initialize addressing buffers
|
||||||
|
_max_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
|
self.position_ids_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int32)
|
||||||
|
self.slot_mapping_buffer = paddle.zeros([_max_batched_tokens], dtype=paddle.int64)
|
||||||
|
|
||||||
# Declare AttentionBackend buffers
|
# Declare AttentionBackend buffers
|
||||||
self.decoder_batch_ids = None
|
self.decoder_batch_ids = None
|
||||||
self.decoder_tile_ids_per_batch = None
|
self.decoder_tile_ids_per_batch = None
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ class MockFDConfig:
|
|||||||
name = "default"
|
name = "default"
|
||||||
splitwise_role = "mixed"
|
splitwise_role = "mixed"
|
||||||
max_num_seqs = 2
|
max_num_seqs = 2
|
||||||
|
max_num_batched_tokens = 2048
|
||||||
|
|
||||||
parallel_config = ParallelConfig()
|
parallel_config = ParallelConfig()
|
||||||
scheduler_config = SchedulerConfig()
|
scheduler_config = SchedulerConfig()
|
||||||
|
|||||||
@@ -33,24 +33,17 @@ class TestGetPositionIdsAndMaskEncoderBatch(unittest.TestCase):
|
|||||||
|
|
||||||
total_len = int(seq_lens_encoder.numpy().sum() + seq_lens_this_time.numpy().sum())
|
total_len = int(seq_lens_encoder.numpy().sum() + seq_lens_this_time.numpy().sum())
|
||||||
position_ids = paddle.zeros([total_len], dtype="int32")
|
position_ids = paddle.zeros([total_len], dtype="int32")
|
||||||
mask_encoder_batch = paddle.zeros([total_len], dtype="int32")
|
|
||||||
|
|
||||||
# Call the custom operator
|
# Call the custom operator
|
||||||
get_position_ids_and_mask_encoder_batch(
|
get_position_ids_and_mask_encoder_batch(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids)
|
||||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids, mask_encoder_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_position_ids = np.array([0, 1, 2, 1, 0, 1, 2, 3], dtype=np.int32)
|
expected_position_ids = np.array([0, 1, 2, 1, 0, 1, 2, 3], dtype=np.int32)
|
||||||
|
|
||||||
expected_mask = np.array([1, 1, 1, 0, 1, 1, 0, 0], dtype=np.int32)
|
|
||||||
|
|
||||||
# Convert to numpy for comparison
|
# Convert to numpy for comparison
|
||||||
position_ids_np = position_ids.numpy()
|
position_ids_np = position_ids.numpy()
|
||||||
mask_encoder_batch_np = mask_encoder_batch.numpy()
|
|
||||||
|
|
||||||
# Assert equality
|
# Assert equality
|
||||||
np.testing.assert_array_equal(position_ids_np, expected_position_ids)
|
np.testing.assert_array_equal(position_ids_np, expected_position_ids)
|
||||||
np.testing.assert_array_equal(mask_encoder_batch_np, expected_mask)
|
|
||||||
|
|
||||||
def test_empty_decoder(self):
|
def test_empty_decoder(self):
|
||||||
# Test case where decoder length is 0
|
# Test case where decoder length is 0
|
||||||
@@ -59,17 +52,12 @@ class TestGetPositionIdsAndMaskEncoderBatch(unittest.TestCase):
|
|||||||
seq_lens_this_time = paddle.to_tensor([0], dtype="int32")
|
seq_lens_this_time = paddle.to_tensor([0], dtype="int32")
|
||||||
|
|
||||||
position_ids = paddle.zeros([2], dtype="int32")
|
position_ids = paddle.zeros([2], dtype="int32")
|
||||||
mask_encoder_batch = paddle.zeros([2], dtype="int32")
|
|
||||||
|
|
||||||
get_position_ids_and_mask_encoder_batch(
|
get_position_ids_and_mask_encoder_batch(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids)
|
||||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids, mask_encoder_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_position_ids = np.array([0, 1], dtype=np.int32)
|
expected_position_ids = np.array([0, 1], dtype=np.int32)
|
||||||
expected_mask = np.array([1, 1], dtype=np.int32)
|
|
||||||
|
|
||||||
np.testing.assert_array_equal(position_ids.numpy(), expected_position_ids)
|
np.testing.assert_array_equal(position_ids.numpy(), expected_position_ids)
|
||||||
np.testing.assert_array_equal(mask_encoder_batch.numpy(), expected_mask)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ def create_mock_config():
|
|||||||
|
|
||||||
scheduler_config = Mock(spec=SchedulerConfig)
|
scheduler_config = Mock(spec=SchedulerConfig)
|
||||||
scheduler_config.max_num_seqs = 10
|
scheduler_config.max_num_seqs = 10
|
||||||
|
scheduler_config.max_num_batched_tokens = 2048
|
||||||
|
|
||||||
speculative_config = Mock(spec=SpeculativeConfig)
|
speculative_config = Mock(spec=SpeculativeConfig)
|
||||||
speculative_config.method = None
|
speculative_config.method = None
|
||||||
|
|||||||
Reference in New Issue
Block a user