mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[MTP] refactor MTP pre_process (#6358)
This commit is contained in:
@@ -21,7 +21,6 @@ __global__ void RebuildPaddingKernel(T *output_data,
|
||||
const int *seq_len_this_time,
|
||||
const int *seq_len_decoder,
|
||||
const int *seq_len_encoder,
|
||||
const int max_input_length,
|
||||
const int dim_embed,
|
||||
const int elem_nums) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
@@ -51,8 +50,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
|
||||
const int *seq_len_this_time,
|
||||
const int *seq_len_decoder,
|
||||
const int *seq_len_encoder,
|
||||
const int *output_padding_offset,
|
||||
const int max_input_length,
|
||||
const int *batch_id_per_token_output,
|
||||
const int *cu_seqlens_q_output,
|
||||
const int dim_embed,
|
||||
const int64_t output_elem_nums,
|
||||
const int bsz,
|
||||
@@ -62,17 +61,18 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
|
||||
for (int64_t i = global_idx * VecSize; i < output_elem_nums;
|
||||
i += gridDim.x * blockDim.x * VecSize) {
|
||||
const int out_token_id = i / dim_embed;
|
||||
const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
|
||||
|
||||
const int bi = ori_token_id / max_input_length;
|
||||
|
||||
int seq_id = 0;
|
||||
const int bi = batch_id_per_token_output[out_token_id];
|
||||
if (seq_len_this_time[bi] == 0) continue;
|
||||
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
|
||||
|
||||
if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
|
||||
const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
|
||||
const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
|
||||
int seq_id = 0;
|
||||
if (seq_len_encoder[bi] > 0) {
|
||||
seq_id = seq_len_encoder[bi] - 1;
|
||||
} else {
|
||||
seq_id = out_token_id - cu_seqlens_q_output[bi];
|
||||
}
|
||||
|
||||
const int input_token_id = cu_seqlens_q[bi] + seq_id;
|
||||
const int bias_idx = i % dim_embed;
|
||||
|
||||
Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
|
||||
@@ -95,9 +95,10 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
const paddle::Tensor &seq_len_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
|
||||
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
|
||||
|
||||
const paddle::optional<paddle::Tensor> &first_token_out,
|
||||
int max_input_length,
|
||||
bool enable_logprob) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
@@ -117,8 +118,8 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
|
||||
paddle::Tensor out;
|
||||
int output_token_num;
|
||||
if (output_padding_offset) {
|
||||
output_token_num = output_padding_offset.get().shape()[0];
|
||||
if (batch_id_per_token_output) {
|
||||
output_token_num = batch_id_per_token_output.get().shape()[0];
|
||||
} else {
|
||||
output_token_num = bsz;
|
||||
}
|
||||
@@ -131,7 +132,7 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
const int grid_size = (pack_num + blocksize - 1) / blocksize;
|
||||
if (output_padding_offset) {
|
||||
if (batch_id_per_token_output) {
|
||||
RebuildAppendPaddingKernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, cu_stream>>>(
|
||||
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
||||
@@ -144,8 +145,8 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
seq_len_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
output_padding_offset.get_ptr()->data<int>(),
|
||||
max_input_length,
|
||||
batch_id_per_token_output.get_ptr()->data<int>(),
|
||||
cu_seqlens_q_output.get_ptr()->data<int>(),
|
||||
dim_embed,
|
||||
elem_nums,
|
||||
bsz,
|
||||
@@ -160,7 +161,6 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
seq_len_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
}
|
||||
@@ -173,43 +173,46 @@ paddle::Tensor RebuildPaddingFunc(
|
||||
const paddle::Tensor &seq_len_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
|
||||
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
|
||||
const paddle::optional<paddle::Tensor> &first_token_out,
|
||||
int max_input_length,
|
||||
bool enable_logprob) {
|
||||
switch (tmp_out.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return rebuild_padding<paddle::DataType::BFLOAT16>(tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob)[0];
|
||||
return rebuild_padding<paddle::DataType::BFLOAT16>(
|
||||
tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
first_token_out,
|
||||
enable_logprob)[0];
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return rebuild_padding<paddle::DataType::FLOAT16>(tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob)[0];
|
||||
return rebuild_padding<paddle::DataType::FLOAT16>(
|
||||
tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
first_token_out,
|
||||
enable_logprob)[0];
|
||||
}
|
||||
case paddle::DataType::FLOAT32: {
|
||||
return rebuild_padding<paddle::DataType::FLOAT32>(tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob)[0];
|
||||
return rebuild_padding<paddle::DataType::FLOAT32>(
|
||||
tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
first_token_out,
|
||||
enable_logprob)[0];
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
@@ -226,18 +229,18 @@ std::vector<paddle::Tensor> RebuildPadding(
|
||||
const paddle::Tensor &seq_len_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
|
||||
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
|
||||
const paddle::optional<paddle::Tensor> &first_token_out,
|
||||
int max_input_length,
|
||||
bool enable_logprob) {
|
||||
return {RebuildPaddingFunc(tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob)};
|
||||
}
|
||||
|
||||
@@ -247,10 +250,12 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
||||
const std::vector<int64_t> &seq_len_this_time_shape,
|
||||
const std::vector<int64_t> &seq_lens_decoder_shape,
|
||||
const std::vector<int64_t> &seq_lens_encoder_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
|
||||
const paddle::optional<std::vector<int64_t>>
|
||||
&batch_id_per_token_output_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &cu_seqlens_q_output_shape) {
|
||||
int64_t dim_embed = tmp_out_shape[1];
|
||||
// whether speculative decoding
|
||||
if (output_padding_offset_shape) {
|
||||
if (batch_id_per_token_output_shape) {
|
||||
return {{-1, dim_embed}};
|
||||
} else {
|
||||
int64_t bsz = cu_seqlens_q_shape[0] - 1;
|
||||
@@ -264,7 +269,8 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
||||
const paddle::DataType &seq_len_this_time_dtype,
|
||||
const paddle::DataType &seq_lens_decoder_dtype,
|
||||
const paddle::DataType &seq_lens_encoder_dtype,
|
||||
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
|
||||
const paddle::optional<paddle::DataType> &batch_id_per_token_output_dtype,
|
||||
const paddle::optional<paddle::DataType> &cu_seqlens_q_output_dtype) {
|
||||
return {tmp_out_dtype};
|
||||
}
|
||||
|
||||
@@ -274,10 +280,11 @@ PD_BUILD_STATIC_OP(rebuild_padding)
|
||||
"seq_len_this_time",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_encoder",
|
||||
paddle::Optional("output_padding_offset"),
|
||||
paddle::Optional("batch_id_per_token_output"),
|
||||
paddle::Optional("cu_seqlens_q_output"),
|
||||
paddle::Optional("first_token_out")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"max_input_length: int", "enable_logprob: bool"})
|
||||
.Attrs({"enable_logprob: bool"})
|
||||
.SetKernelFn(PD_KERNEL(RebuildPadding))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));
|
||||
|
||||
Reference in New Issue
Block a user