[Feature] support mtp logprob (#4464)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* support mtp logprob

* fix unitest
This commit is contained in:
GoldPancake
2025-10-20 15:18:12 +08:00
committed by GitHub
parent 1b9f351d21
commit 47595a2480
14 changed files with 1181 additions and 32 deletions
+41 -12
View File
@@ -46,6 +46,7 @@ __global__ void RebuildPaddingKernel(T *output_data,
template <typename T, int VecSize>
__global__ void RebuildAppendPaddingKernel(T *output_data,
T *first_token_out,
const T *input_data,
const int *cu_seqlens_q,
const int *seq_len_this_time,
@@ -55,7 +56,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
const int max_input_length,
const int dim_embed,
const int64_t output_elem_nums,
const int bsz) {
const int bsz,
const bool enable_logprob) {
AlignedVector<T, VecSize> src_vec;
const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = global_idx * VecSize; i < output_elem_nums;
@@ -70,13 +72,20 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
const int bias_idx = i % dim_embed;
Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
&src_vec);
Store<T, VecSize>(src_vec, &output_data[i]);
if (enable_logprob && seq_len_encoder[bi] > 0) {
const int first_input_token_id = input_token_id - 1;
Load<T, VecSize>(&input_data[first_input_token_id * dim_embed + bias_idx],
&src_vec);
Store<T, VecSize>(src_vec, &first_token_out[bi * dim_embed + bias_idx]);
}
}
}
@@ -89,7 +98,9 @@ std::vector<paddle::Tensor> rebuild_padding(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -134,6 +145,10 @@ std::vector<paddle::Tensor> rebuild_padding(
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
first_token_out.is_initialized()
? reinterpret_cast<DataType_ *>(const_cast<data_t *>(
first_token_out.get_ptr()->data<data_t>()))
: nullptr,
reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
cu_seqlens_q.data<int>(),
seq_len_this_time.data<int>(),
@@ -143,7 +158,8 @@ std::vector<paddle::Tensor> rebuild_padding(
max_input_length,
dim_embed,
elem_nums,
bsz);
bsz,
enable_logprob);
} else {
RebuildPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
@@ -168,7 +184,9 @@ paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
switch (tmp_out.type()) {
case paddle::DataType::BFLOAT16: {
return rebuild_padding<paddle::DataType::BFLOAT16>(
@@ -178,7 +196,9 @@ paddle::Tensor RebuildPaddingFunc(
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)[0];
first_token_out,
max_input_length,
enable_logprob)[0];
}
case paddle::DataType::FLOAT16: {
return rebuild_padding<paddle::DataType::FLOAT16>(
@@ -188,7 +208,9 @@ paddle::Tensor RebuildPaddingFunc(
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)[0];
first_token_out,
max_input_length,
enable_logprob)[0];
}
case paddle::DataType::FLOAT32: {
return rebuild_padding<paddle::DataType::FLOAT32>(
@@ -198,7 +220,9 @@ paddle::Tensor RebuildPaddingFunc(
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)[0];
first_token_out,
max_input_length,
enable_logprob)[0];
}
default: {
PD_THROW(
@@ -216,14 +240,18 @@ std::vector<paddle::Tensor> RebuildPadding(
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
return {RebuildPaddingFunc(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length)};
first_token_out,
max_input_length,
enable_logprob)};
}
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
@@ -259,9 +287,10 @@ PD_BUILD_STATIC_OP(rebuild_padding)
"seq_len_this_time",
"seq_lens_decoder",
"seq_lens_encoder",
paddle::Optional("output_padding_offset")})
paddle::Optional("output_padding_offset"),
paddle::Optional("first_token_out")})
.Outputs({"out"})
.Attrs({"max_input_length: int"})
.Attrs({"max_input_length: int", "enable_logprob: bool"})
.SetKernelFn(PD_KERNEL(RebuildPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));