[MTP] refactor MTP pre_process (#6358)

This commit is contained in:
周周周
2026-02-09 10:47:15 +08:00
committed by GitHub
parent 18e79dd660
commit 2b4748de4f
24 changed files with 411 additions and 533 deletions
+64 -57
View File
@@ -21,7 +21,6 @@ __global__ void RebuildPaddingKernel(T *output_data,
const int *seq_len_this_time,
const int *seq_len_decoder,
const int *seq_len_encoder,
const int max_input_length,
const int dim_embed,
const int elem_nums) {
using LoadT = AlignedVector<T, VecSize>;
@@ -51,8 +50,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
const int *seq_len_this_time,
const int *seq_len_decoder,
const int *seq_len_encoder,
const int *output_padding_offset,
const int max_input_length,
const int *batch_id_per_token_output,
const int *cu_seqlens_q_output,
const int dim_embed,
const int64_t output_elem_nums,
const int bsz,
@@ -62,17 +61,18 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
for (int64_t i = global_idx * VecSize; i < output_elem_nums;
i += gridDim.x * blockDim.x * VecSize) {
const int out_token_id = i / dim_embed;
const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
const int bi = ori_token_id / max_input_length;
int seq_id = 0;
const int bi = batch_id_per_token_output[out_token_id];
if (seq_len_this_time[bi] == 0) continue;
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
int seq_id = 0;
if (seq_len_encoder[bi] > 0) {
seq_id = seq_len_encoder[bi] - 1;
} else {
seq_id = out_token_id - cu_seqlens_q_output[bi];
}
const int input_token_id = cu_seqlens_q[bi] + seq_id;
const int bias_idx = i % dim_embed;
Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
@@ -95,9 +95,10 @@ std::vector<paddle::Tensor> rebuild_padding(
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
@@ -117,8 +118,8 @@ std::vector<paddle::Tensor> rebuild_padding(
paddle::Tensor out;
int output_token_num;
if (output_padding_offset) {
output_token_num = output_padding_offset.get().shape()[0];
if (batch_id_per_token_output) {
output_token_num = batch_id_per_token_output.get().shape()[0];
} else {
output_token_num = bsz;
}
@@ -131,7 +132,7 @@ std::vector<paddle::Tensor> rebuild_padding(
int pack_num = elem_nums / PackSize;
const int blocksize = 128;
const int grid_size = (pack_num + blocksize - 1) / blocksize;
if (output_padding_offset) {
if (batch_id_per_token_output) {
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
@@ -144,8 +145,8 @@ std::vector<paddle::Tensor> rebuild_padding(
seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
output_padding_offset.get_ptr()->data<int>(),
max_input_length,
batch_id_per_token_output.get_ptr()->data<int>(),
cu_seqlens_q_output.get_ptr()->data<int>(),
dim_embed,
elem_nums,
bsz,
@@ -160,7 +161,6 @@ std::vector<paddle::Tensor> rebuild_padding(
seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
max_input_length,
dim_embed,
elem_nums);
}
@@ -173,43 +173,46 @@ paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
switch (tmp_out.type()) {
case paddle::DataType::BFLOAT16: {
return rebuild_padding<paddle::DataType::BFLOAT16>(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
first_token_out,
max_input_length,
enable_logprob)[0];
return rebuild_padding<paddle::DataType::BFLOAT16>(
tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token_output,
cu_seqlens_q_output,
first_token_out,
enable_logprob)[0];
}
case paddle::DataType::FLOAT16: {
return rebuild_padding<paddle::DataType::FLOAT16>(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
first_token_out,
max_input_length,
enable_logprob)[0];
return rebuild_padding<paddle::DataType::FLOAT16>(
tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token_output,
cu_seqlens_q_output,
first_token_out,
enable_logprob)[0];
}
case paddle::DataType::FLOAT32: {
return rebuild_padding<paddle::DataType::FLOAT32>(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
first_token_out,
max_input_length,
enable_logprob)[0];
return rebuild_padding<paddle::DataType::FLOAT32>(
tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token_output,
cu_seqlens_q_output,
first_token_out,
enable_logprob)[0];
}
default: {
PD_THROW(
@@ -226,18 +229,18 @@ std::vector<paddle::Tensor> RebuildPadding(
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
return {RebuildPaddingFunc(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
batch_id_per_token_output,
cu_seqlens_q_output,
first_token_out,
max_input_length,
enable_logprob)};
}
@@ -247,10 +250,12 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
const std::vector<int64_t> &seq_len_this_time_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_encoder_shape,
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
const paddle::optional<std::vector<int64_t>>
&batch_id_per_token_output_shape,
const paddle::optional<std::vector<int64_t>> &cu_seqlens_q_output_shape) {
int64_t dim_embed = tmp_out_shape[1];
// whether speculative decoding
if (output_padding_offset_shape) {
if (batch_id_per_token_output_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cu_seqlens_q_shape[0] - 1;
@@ -264,7 +269,8 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
const paddle::DataType &seq_len_this_time_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_encoder_dtype,
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
const paddle::optional<paddle::DataType> &batch_id_per_token_output_dtype,
const paddle::optional<paddle::DataType> &cu_seqlens_q_output_dtype) {
return {tmp_out_dtype};
}
@@ -274,10 +280,11 @@ PD_BUILD_STATIC_OP(rebuild_padding)
"seq_len_this_time",
"seq_lens_decoder",
"seq_lens_encoder",
paddle::Optional("output_padding_offset"),
paddle::Optional("batch_id_per_token_output"),
paddle::Optional("cu_seqlens_q_output"),
paddle::Optional("first_token_out")})
.Outputs({"out"})
.Attrs({"max_input_length: int", "enable_logprob: bool"})
.Attrs({"enable_logprob: bool"})
.SetKernelFn(PD_KERNEL(RebuildPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));