[MTP] refactor MTP pre_process (#6358)

This commit is contained in:
周周周
2026-02-09 10:47:15 +08:00
committed by GitHub
parent 18e79dd660
commit 2b4748de4f
24 changed files with 411 additions and 533 deletions
+33 -41
View File
@@ -428,9 +428,9 @@ paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor& seq_len_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::optional<paddle::Tensor>& output_padding_offset,
const paddle::optional<paddle::Tensor>& batch_id_per_token_output,
const paddle::optional<paddle::Tensor>& cu_seqlens_q_output,
const paddle::optional<paddle::Tensor>& first_token_out,
int max_input_length,
bool enable_logprob);
void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
@@ -747,28 +747,23 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
const paddle::Tensor& seq_lens_output,
void SpecTokenPenaltyMultiScores(
const paddle::Tensor& pre_ids,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& bad_tokens_len,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const int max_seq_len);
void SpecTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& bad_tokens_len,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const int max_seq_len);
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& pre_ids,
@@ -794,7 +789,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_tokens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& actual_candidate_len,
const paddle::Tensor& actual_draft_token_nums,
const paddle::Tensor& topp,
@@ -922,7 +917,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
@@ -1102,19 +1097,20 @@ std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
const paddle::Tensor& pre_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& enable_thinking,
int64_t think_end_id,
int64_t line_break_id);
void ReasoningPhaseTokenConstraint(
const paddle::Tensor& logits,
const paddle::Tensor& pre_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& enable_thinking,
int64_t think_end_id,
int64_t line_break_id);
std::vector<paddle::Tensor> get_attn_mask_q(
const paddle::Tensor& cu_seqlens_q,
@@ -1612,10 +1608,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&SpeculateGetSeqLensOutput,
"speculate_get_seq_lens_output function");
m.def("speculate_get_output_padding_offset",
&SpeculateGetOutputPaddingOffset,
"speculate_get_output_padding_offset function");
m.def("speculate_get_token_penalty_multi_scores",
&SpecTokenPenaltyMultiScores,
"speculate_get_token_penalty_multi_scores function");
@@ -125,8 +125,8 @@ __global__ void apply_token_enforce_generation_scores_kernel(
T* __restrict__ logits_dst, // logits (output)
const int64_t* __restrict__ allowed_tokens, // [allowed_len]
const int32_t* __restrict__ reasoning_status,
const int* output_padding_offset,
const int* output_cum_offsets,
const int* batch_id_per_token_output,
const int* cu_seqlens_q_output,
const int max_bsz,
const int max_seq_len,
const int vocab_size,
@@ -134,10 +134,8 @@ __global__ void apply_token_enforce_generation_scores_kernel(
int token_idx = blockIdx.x;
int tid = threadIdx.x;
const int bs_idx =
(token_idx + output_padding_offset[token_idx]) / max_seq_len;
const int query_start_token_idx =
bs_idx * max_seq_len - output_cum_offsets[bs_idx];
const int bs_idx = batch_id_per_token_output[token_idx];
const int query_start_token_idx = cu_seqlens_q_output[bs_idx];
bool is_batch_first_token = (token_idx == query_start_token_idx);
if (allowed_tokens_len == 0 || !is_batch_first_token) {
@@ -177,8 +175,8 @@ void reasoning_phase_token_constraint(
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& enable_thinking,
int64_t think_end_id,
int64_t line_break_id) {
@@ -233,27 +231,28 @@ void reasoning_phase_token_constraint(
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
allowed_tokens.data<int64_t>(),
reasoning_status.data<int32_t>(),
output_padding_offset.data<int32_t>(),
output_cum_offsets.data<int32_t>(),
batch_id_per_token_output.data<int32_t>(),
cu_seqlens_q_output.data<int32_t>(),
bs,
max_seq_len,
vocab_size,
allowed_tokens_len);
}
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
const paddle::Tensor& pre_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& enable_thinking,
int64_t think_end_id,
int64_t line_break_id) {
void ReasoningPhaseTokenConstraint(
const paddle::Tensor& logits,
const paddle::Tensor& pre_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& enable_thinking,
int64_t think_end_id,
int64_t line_break_id) {
switch (logits.type()) {
case paddle::DataType::FLOAT16:
return reasoning_phase_token_constraint<paddle::DataType::FLOAT16>(
@@ -265,8 +264,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
step_idx,
allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
batch_id_per_token_output,
cu_seqlens_q_output,
enable_thinking,
think_end_id,
line_break_id);
@@ -280,8 +279,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
step_idx,
allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
batch_id_per_token_output,
cu_seqlens_q_output,
enable_thinking,
think_end_id,
line_break_id);
@@ -295,8 +294,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
step_idx,
allowed_tokens,
reasoning_status,
output_padding_offset,
output_cum_offsets,
batch_id_per_token_output,
cu_seqlens_q_output,
enable_thinking,
think_end_id,
line_break_id);
@@ -317,8 +316,8 @@ PD_BUILD_STATIC_OP(reasoning_phase_token_constraint)
"step_idx",
"allowed_tokens",
"reasoning_status",
"output_padding_offset",
"output_cum_offsets",
"batch_id_per_token_output",
"cu_seqlens_q_output",
"enable_thinking"})
.Outputs({"logits_out", "reasoning_status_out"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
+64 -57
View File
@@ -21,7 +21,6 @@ __global__ void RebuildPaddingKernel(T *output_data,
const int *seq_len_this_time,
const int *seq_len_decoder,
const int *seq_len_encoder,
const int max_input_length,
const int dim_embed,
const int elem_nums) {
using LoadT = AlignedVector<T, VecSize>;
@@ -51,8 +50,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
const int *seq_len_this_time,
const int *seq_len_decoder,
const int *seq_len_encoder,
const int *output_padding_offset,
const int max_input_length,
const int *batch_id_per_token_output,
const int *cu_seqlens_q_output,
const int dim_embed,
const int64_t output_elem_nums,
const int bsz,
@@ -62,17 +61,18 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
for (int64_t i = global_idx * VecSize; i < output_elem_nums;
i += gridDim.x * blockDim.x * VecSize) {
const int out_token_id = i / dim_embed;
const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
const int bi = ori_token_id / max_input_length;
int seq_id = 0;
const int bi = batch_id_per_token_output[out_token_id];
if (seq_len_this_time[bi] == 0) continue;
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
int seq_id = 0;
if (seq_len_encoder[bi] > 0) {
seq_id = seq_len_encoder[bi] - 1;
} else {
seq_id = out_token_id - cu_seqlens_q_output[bi];
}
const int input_token_id = cu_seqlens_q[bi] + seq_id;
const int bias_idx = i % dim_embed;
Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
@@ -95,9 +95,10 @@ std::vector<paddle::Tensor> rebuild_padding(
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
@@ -117,8 +118,8 @@ std::vector<paddle::Tensor> rebuild_padding(
paddle::Tensor out;
int output_token_num;
if (output_padding_offset) {
output_token_num = output_padding_offset.get().shape()[0];
if (batch_id_per_token_output) {
output_token_num = batch_id_per_token_output.get().shape()[0];
} else {
output_token_num = bsz;
}
@@ -131,7 +132,7 @@ std::vector<paddle::Tensor> rebuild_padding(
int pack_num = elem_nums / PackSize;
const int blocksize = 128;
const int grid_size = (pack_num + blocksize - 1) / blocksize;
if (output_padding_offset) {
if (batch_id_per_token_output) {
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
@@ -144,8 +145,8 @@ std::vector<paddle::Tensor> rebuild_padding(
seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
output_padding_offset.get_ptr()->data<int>(),
max_input_length,
batch_id_per_token_output.get_ptr()->data<int>(),
cu_seqlens_q_output.get_ptr()->data<int>(),
dim_embed,
elem_nums,
bsz,
@@ -160,7 +161,6 @@ std::vector<paddle::Tensor> rebuild_padding(
seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
max_input_length,
dim_embed,
elem_nums);
}
@@ -173,43 +173,46 @@ paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
switch (tmp_out.type()) {
case paddle::DataType::BFLOAT16: {
return rebuild_padding<paddle::DataType::BFLOAT16>(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
first_token_out,
max_input_length,
enable_logprob)[0];
return rebuild_padding<paddle::DataType::BFLOAT16>(
tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token_output,
cu_seqlens_q_output,
first_token_out,
enable_logprob)[0];
}
case paddle::DataType::FLOAT16: {
return rebuild_padding<paddle::DataType::FLOAT16>(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
first_token_out,
max_input_length,
enable_logprob)[0];
return rebuild_padding<paddle::DataType::FLOAT16>(
tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token_output,
cu_seqlens_q_output,
first_token_out,
enable_logprob)[0];
}
case paddle::DataType::FLOAT32: {
return rebuild_padding<paddle::DataType::FLOAT32>(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
first_token_out,
max_input_length,
enable_logprob)[0];
return rebuild_padding<paddle::DataType::FLOAT32>(
tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token_output,
cu_seqlens_q_output,
first_token_out,
enable_logprob)[0];
}
default: {
PD_THROW(
@@ -226,18 +229,18 @@ std::vector<paddle::Tensor> RebuildPadding(
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {
return {RebuildPaddingFunc(tmp_out,
cu_seqlens_q,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
batch_id_per_token_output,
cu_seqlens_q_output,
first_token_out,
max_input_length,
enable_logprob)};
}
@@ -247,10 +250,12 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
const std::vector<int64_t> &seq_len_this_time_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_encoder_shape,
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
const paddle::optional<std::vector<int64_t>>
&batch_id_per_token_output_shape,
const paddle::optional<std::vector<int64_t>> &cu_seqlens_q_output_shape) {
int64_t dim_embed = tmp_out_shape[1];
// whether speculative decoding
if (output_padding_offset_shape) {
if (batch_id_per_token_output_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cu_seqlens_q_shape[0] - 1;
@@ -264,7 +269,8 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
const paddle::DataType &seq_len_this_time_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_encoder_dtype,
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
const paddle::optional<paddle::DataType> &batch_id_per_token_output_dtype,
const paddle::optional<paddle::DataType> &cu_seqlens_q_output_dtype) {
return {tmp_out_dtype};
}
@@ -274,10 +280,11 @@ PD_BUILD_STATIC_OP(rebuild_padding)
"seq_len_this_time",
"seq_lens_decoder",
"seq_lens_encoder",
paddle::Optional("output_padding_offset"),
paddle::Optional("batch_id_per_token_output"),
paddle::Optional("cu_seqlens_q_output"),
paddle::Optional("first_token_out")})
.Outputs({"out"})
.Attrs({"max_input_length: int", "enable_logprob: bool"})
.Attrs({"enable_logprob: bool"})
.SetKernelFn(PD_KERNEL(RebuildPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));
@@ -23,7 +23,7 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
const int* cu_seqlens_q_output,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
@@ -47,8 +47,7 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
auto* pre_ids_now = pre_ids + tid * pre_id_length;
auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * max_base_model_draft_token;
const int next_tokens_start_id =
tid * max_seq_len - output_cum_offsets[tid];
const int next_tokens_start_id = cu_seqlens_q_output[tid];
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
auto seq_len_this_time = seq_lens_this_time[tid];
auto seq_len_encoder = seq_lens_encoder[tid];
@@ -66,7 +65,6 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
step_idx[tid] += seq_len_this_time;
pre_ids_now[step_idx[tid]] = token_this_time;
} else {
token_this_time = next_tokens_start[0];
@@ -80,7 +78,8 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
}
// multi_end
if (is_in_end(token_this_time, end_ids, end_ids_len) || prefill_one_step_stop) {
if (is_in_end(token_this_time, end_ids, end_ids_len) ||
prefill_one_step_stop) {
stop_flags[tid] = true;
stop_flag_now_int = 1;
// max_dec_len
@@ -112,7 +111,6 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
}
}
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids,
@@ -120,7 +118,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
@@ -140,11 +138,11 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
constexpr int BlockSize = 512;
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
draft_model_update_kernel<BlockSize><<<1, BlockSize, 0, cu_stream>>>(
@@ -155,7 +153,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
output_cum_offsets.data<int>(),
cu_seqlens_q_output.data<int>(),
const_cast<bool*>(stop_flags.data<bool>()),
not_need_stop_gpu.data<bool>(),
max_dec_len.data<int64_t>(),
@@ -170,14 +168,12 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
substep,
prefill_one_step_stop);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(draft_model_update)
.Inputs({"inter_next_tokens",
"draft_tokens",
@@ -186,7 +182,7 @@ PD_BUILD_STATIC_OP(draft_model_update)
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx",
"output_cum_offsets",
"cu_seqlens_q_output",
"stop_flags",
"not_need_stop",
"max_dec_len",
@@ -1,88 +0,0 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// Fills the per-token padding offsets and per-batch cumulative offsets used
// to map packed speculative-decoding output tokens to/from the padded
// (batch-major, max_seq_len-strided) layout.
//   output_padding_offset  : written; one entry per packed output token
//   output_cum_offsets     : written; one entry per batch
//   output_cum_offsets_tmp : read; running offsets, indexed by batch
//   seq_lens_output        : read; number of output tokens per batch
// Launch: one block per batch (blockIdx.x == batch index); threads within a
// block stride over that batch's tokens.
__global__ void SpeculateGetOutputPaddingOffsetKernel(
int* output_padding_offset,
int* output_cum_offsets,
const int* output_cum_offsets_tmp,
const int* seq_lens_output,
const int max_seq_len) {
// get padding offset of each batch
const int bi = blockIdx.x;
const int ti = threadIdx.x;
// Offset accumulated before this batch: 0 for batch 0, otherwise the
// running total recorded for the previous batch.
int cum_offset = bi == 0 ? 0 : output_cum_offsets_tmp[bi - 1];
// Each packed slot of this batch (bi * max_seq_len - cum_offset + i)
// stores the batch's offset, so packed index + offset recovers the
// padded index.
for (int i = ti; i < seq_lens_output[bi]; i += blockDim.x) {
output_padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
}
// A single thread per block publishes the batch's cumulative offset.
if (ti == 0) {
output_cum_offsets[bi] = cum_offset;
}
}
// Host wrapper: allocates and populates the output-padding-offset and
// cumulative-offset tensors for speculative decoding.
//   output_cum_offsets_tmp : [bsz] running offsets (device tensor)
//   out_token_num          : scalar total number of packed output tokens
//   seq_lens_output        : [bsz] output tokens per batch
//   max_seq_len            : padded per-batch stride
// Returns {output_padding_offset, output_cum_offsets}.
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
const paddle::Tensor& seq_lens_output,
const int max_seq_len) {
// Run the kernel on the same stream as the input tensor.
auto cu_stream = output_cum_offsets_tmp.stream();
std::vector<int64_t> output_cum_offsets_tmp_shape =
output_cum_offsets_tmp.shape();
// Batch size drives the grid dimension (one block per batch below).
const int bsz = output_cum_offsets_tmp_shape[0];
// The token count is brought to the host to size the output tensor.
// NOTE(review): the 'false' flag presumably selects a blocking copy so the
// value is valid before use — confirm against paddle's copy_to semantics.
auto cpu_out_token_num = out_token_num.copy_to(paddle::CPUPlace(), false);
auto output_padding_offset = paddle::full({cpu_out_token_num},
0,
paddle::DataType::INT32,
output_cum_offsets_tmp.place());
// Device-side copy seeds the per-batch offsets; the kernel then overwrites
// every entry with the offset preceding each batch.
auto output_cum_offsets =
output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false);
SpeculateGetOutputPaddingOffsetKernel<<<bsz, 256, 0, cu_stream>>>(
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
output_cum_offsets_tmp.data<int>(),
seq_lens_output.data<int>(),
max_seq_len);
return {output_padding_offset, output_cum_offsets};
}
// Shape inference for speculate_get_output_padding_offset.
// The packed per-token offset tensor has a runtime-dependent length (-1);
// the cumulative-offset tensor carries one entry per batch.
std::vector<std::vector<int64_t>> SpeculateGetOutputPaddingOffsetInferShape(
    const std::vector<int64_t>& output_cum_offsets_tmp_shape,
    const std::vector<int64_t>& out_token_num_shape,
    const std::vector<int64_t>& seq_lens_output_shape) {
  const int64_t batch_size = output_cum_offsets_tmp_shape.front();
  std::vector<std::vector<int64_t>> inferred;
  inferred.push_back({-1});          // output_padding_offset: dynamic length
  inferred.push_back({batch_size});  // output_cum_offsets: one per batch
  return inferred;
}
// Dtype inference for speculate_get_output_padding_offset: both outputs
// share the dtype of the incoming cumulative-offset tensor.
std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
    const paddle::DataType& output_cum_offsets_tmp_dtype,
    const paddle::DataType& out_token_num_dtype,
    const paddle::DataType& seq_lens_output_dtype) {
  const auto out_dtype = output_cum_offsets_tmp_dtype;
  return {out_dtype, out_dtype};
}
// Static-op registration for speculate_get_output_padding_offset.
// Inputs : running offsets, total output-token count, per-batch output
//          sequence lengths.
// Outputs: per-token padding offsets and per-batch cumulative offsets.
// Attr   : max_seq_len — padded per-batch stride used by the kernel.
PD_BUILD_STATIC_OP(speculate_get_output_padding_offset)
.Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
.Outputs({"output_padding_offset", "output_cum_offsets"})
.Attrs({"max_seq_len: int"})
.SetKernelFn(PD_KERNEL(SpeculateGetOutputPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetOutputPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetOutputPaddingOffsetInferDtype));
@@ -20,8 +20,8 @@ __global__ inline void min_length_logits_process(
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int *output_padding_offset,
const int *output_cum_offsets,
const int *batch_id_per_token_output,
const int *cu_seqlens_q_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
@@ -29,9 +29,9 @@ __global__ inline void min_length_logits_process(
const int max_seq_len) {
const int token_idx = threadIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
const int bi = batch_id_per_token_output[token_idx];
if (bi >= bs) return;
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
const int query_start_token_idx = cu_seqlens_q_output[bi];
if (cur_len[bi] < 0) {
return;
@@ -49,8 +49,8 @@ __global__ inline void min_length_logits_process<half>(
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int *output_padding_offset,
const int *output_cum_offsets,
const int *batch_id_per_token_output,
const int *cu_seqlens_q_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
@@ -58,9 +58,9 @@ __global__ inline void min_length_logits_process<half>(
const int max_seq_len) {
const int token_idx = threadIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
const int bi = batch_id_per_token_output[token_idx];
if (bi >= bs) return;
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
const int query_start_token_idx = cu_seqlens_q_output[bi];
if (cur_len[bi] < 0) {
return;
@@ -75,7 +75,7 @@ __global__ inline void min_length_logits_process<half>(
__global__ void update_repeat_times(const int64_t *pre_ids,
const int64_t *cur_len,
int *repeat_times,
const int *output_padding_offset,
const int *batch_id_per_token_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
@@ -83,7 +83,7 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
const int max_seq_len) {
const int token_idx = blockIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
const int bi = batch_id_per_token_output[token_idx];
if (bi >= bs) return;
if (cur_len[bi] < 0) {
return;
@@ -99,20 +99,21 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
}
template <typename T>
__global__ void update_value_by_repeat_times(const int *repeat_times,
const T *penalty_scores,
const T *frequency_score,
const T *presence_score,
const float *temperatures,
T *logits,
const int *output_padding_offset,
const int64_t token_num,
const int64_t bs,
const int64_t length,
const int max_seq_len) {
__global__ void update_value_by_repeat_times(
const int *repeat_times,
const T *penalty_scores,
const T *frequency_score,
const T *presence_score,
const float *temperatures,
T *logits,
const int *batch_id_per_token_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
const int max_seq_len) {
const int token_idx = blockIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
const int bi = batch_id_per_token_output[token_idx];
if (bi >= bs) return;
int tid = threadIdx.x;
T *logits_now = logits + token_idx * length;
@@ -135,7 +136,7 @@ template <typename T>
__global__ void ban_bad_words(T *logits,
const int64_t *bad_tokens,
const int64_t *bad_tokens_len,
const int *output_padding_offset,
const int *batch_id_per_token_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
@@ -143,7 +144,7 @@ __global__ void ban_bad_words(T *logits,
const int max_seq_len) {
const int token_idx = blockIdx.x;
if (token_idx >= token_num) return;
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
const int bi = batch_id_per_token_output[token_idx];
if (bi >= bs) return;
int tid = threadIdx.x;
@@ -172,8 +173,8 @@ void token_penalty_multi_scores_kernel(
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &output_padding_offset,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &batch_id_per_token_output,
const paddle::Tensor &cu_seqlens_q_output,
const int max_seq_len) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
@@ -196,8 +197,8 @@ void token_penalty_multi_scores_kernel(
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
output_padding_offset.data<int>(),
output_cum_offsets.data<int>(),
batch_id_per_token_output.data<int>(),
cu_seqlens_q_output.data<int>(),
token_num,
bs,
length,
@@ -210,7 +211,7 @@ void token_penalty_multi_scores_kernel(
pre_ids.data<int64_t>(),
cur_len.data<int64_t>(),
repeat_times.data<int>(),
output_padding_offset.data<int>(),
batch_id_per_token_output.data<int>(),
token_num,
bs,
length,
@@ -231,7 +232,7 @@ void token_penalty_multi_scores_kernel(
temperatures.data<float>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
output_padding_offset.data<int>(),
batch_id_per_token_output.data<int>(),
token_num,
bs,
length,
@@ -243,7 +244,7 @@ void token_penalty_multi_scores_kernel(
const_cast<data_t *>(logits.data<data_t>())),
bad_tokens.data<int64_t>(),
bad_tokens_len.data<int64_t>(),
output_padding_offset.data<int>(),
batch_id_per_token_output.data<int>(),
token_num,
bs,
length,
@@ -251,21 +252,22 @@ void token_penalty_multi_scores_kernel(
max_seq_len);
}
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
const paddle::Tensor &presence_scores,
const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &bad_tokens_len,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &output_padding_offset,
const paddle::Tensor &output_cum_offsets,
const int max_seq_len) {
void SpecTokenPenaltyMultiScores(
const paddle::Tensor &pre_ids,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
const paddle::Tensor &presence_scores,
const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &bad_tokens_len,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &batch_id_per_token_output,
const paddle::Tensor &cu_seqlens_q_output,
const int max_seq_len) {
switch (logits.type()) {
case paddle::DataType::BFLOAT16: {
return token_penalty_multi_scores_kernel<paddle::DataType::BFLOAT16>(
@@ -281,8 +283,8 @@ void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
min_len,
eos_token_id,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
batch_id_per_token_output,
cu_seqlens_q_output,
max_seq_len);
}
case paddle::DataType::FLOAT16: {
@@ -299,8 +301,8 @@ void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
min_len,
eos_token_id,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
batch_id_per_token_output,
cu_seqlens_q_output,
max_seq_len);
}
case paddle::DataType::FLOAT32: {
@@ -317,8 +319,8 @@ void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
min_len,
eos_token_id,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
batch_id_per_token_output,
cu_seqlens_q_output,
max_seq_len);
}
default: {
@@ -343,8 +345,8 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
"min_len",
"eos_token_id",
"seq_lens_this_time",
"output_padding_offset",
"output_cum_offsets"})
"batch_id_per_token_output",
"cu_seqlens_q_output"})
.Outputs({"logits_out"})
.Attrs({"max_seq_len: int"})
.SetInplaceMap({{"logits", "logits_out"}})
@@ -82,7 +82,7 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids,
const int64_t *max_dec_len,
const int64_t *end_tokens,
const bool *is_block_step,
const int *output_cum_offsets,
const int *cu_seqlens_q_output,
const int *actual_candidate_len,
const int *reasoning_status,
const int real_bsz,
@@ -101,7 +101,7 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids,
int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) {
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
const int start_token_id = cu_seqlens_q_output[bid];
if (stop_flags[bid]) {
stop_flag_now_int = 1;
@@ -312,7 +312,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
const paddle::Tensor &max_dec_len,
const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &cu_seqlens_q_output,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &topp,
@@ -374,7 +374,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
cu_seqlens_q_output.data<int>(),
actual_candidate_len.data<int>(),
reasoning_status.data<int>(),
real_bsz,
@@ -407,7 +407,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
cu_seqlens_q_output.data<int>(),
actual_candidate_len.data<int>(),
reasoning_status.data<int>(),
real_bsz,
@@ -442,7 +442,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
cu_seqlens_q_output.data<int>(),
actual_candidate_len.data<int>(),
reasoning_status.data<int>(),
real_bsz,
@@ -475,7 +475,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(),
is_block_step.data<bool>(),
output_cum_offsets.data<int>(),
cu_seqlens_q_output.data<int>(),
actual_candidate_len.data<int>(),
reasoning_status.data<int>(),
real_bsz,
@@ -509,7 +509,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
"max_dec_len",
"end_tokens",
"is_block_step",
"output_cum_offsets",
"cu_seqlens_q_output",
"actual_candidate_len",
"actual_draft_token_nums",
"topp",
@@ -407,7 +407,7 @@ template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopKFt(
const T* src,
const T* top_ps,
const int* output_padding_offset,
const int* batch_id_per_token_output,
int64_t* out_id, // [max_cadidate_len, 1]
T* out_val, // [max_cadidate_len, 1]
int* actual_candidates_lens,
@@ -418,8 +418,7 @@ __global__ void KeMatrixTopPBeamTopKFt(
const int wid = tid / 32;
const int lane = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + output_padding_offset[token_id];
const int bid = ori_token_id / max_seq_len;
const int bid = batch_id_per_token_output[token_id];
int top_num = TopPBeamTopK;
float top_p_value = static_cast<float>(top_ps[bid]);
@@ -479,7 +478,7 @@ __global__ void KeMatrixTopPBeamTopKFt(
template <typename T, int TopKMaxLength>
void DispatchTopK(const T* src,
const T* top_ps,
const int* output_padding_offset,
const int* batch_id_per_token_output,
int64_t* out_id, // topk id
T* out_val, // topk val
int* actual_candidates_lens_data,
@@ -495,7 +494,7 @@ void DispatchTopK(const T* src,
KeMatrixTopPBeamTopKFt<T, TopKMaxLength, kTopK, kBlockDim>
<<<token_num, kBlockDim, 0, stream>>>(src,
top_ps,
output_padding_offset,
batch_id_per_token_output,
out_id,
out_val,
actual_candidates_lens_data,
@@ -516,7 +515,7 @@ template <paddle::DataType D>
std::vector<paddle::Tensor> LaunchTopPCandidates(
const paddle::Tensor& probs, // [token_num, vocab_size]
const paddle::Tensor& top_p, // [token_num]
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& batch_id_per_token_output,
const int candidates_len,
const int max_seq_len) {
typedef PDTraits<D> traits_;
@@ -540,7 +539,7 @@ std::vector<paddle::Tensor> LaunchTopPCandidates(
DispatchTopK<DataType_, TopKMaxLength>(
reinterpret_cast<const DataType_*>(probs.data<data_t>()),
reinterpret_cast<const DataType_*>(top_p.data<data_t>()),
output_padding_offset.data<int>(),
batch_id_per_token_output.data<int>(),
verify_tokens.data<int64_t>(),
reinterpret_cast<DataType_*>(verify_scores.data<data_t>()),
actual_candidate_lens.data<int>(),
@@ -556,21 +555,21 @@ std::vector<paddle::Tensor> LaunchTopPCandidates(
std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
const paddle::Tensor& probs,
const paddle::Tensor& top_p,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& batch_id_per_token_output,
int candidates_len,
int max_seq_len) {
switch (probs.type()) {
case paddle::DataType::BFLOAT16:
return LaunchTopPCandidates<paddle::DataType::BFLOAT16>(
probs, top_p, output_padding_offset, candidates_len, max_seq_len);
probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
break;
case paddle::DataType::FLOAT16:
return LaunchTopPCandidates<paddle::DataType::FLOAT16>(
probs, top_p, output_padding_offset, candidates_len, max_seq_len);
probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
break;
case paddle::DataType::FLOAT32:
return LaunchTopPCandidates<paddle::DataType::FLOAT32>(
probs, top_p, output_padding_offset, candidates_len, max_seq_len);
probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
break;
default:
PD_THROW(
@@ -583,17 +582,17 @@ std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
std::vector<paddle::Tensor> TopPCandidates(
const paddle::Tensor& probs,
const paddle::Tensor& top_p,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& batch_id_per_token_output,
int candidates_len,
int max_seq_len) {
return DispatchTopPCandidatesWithDtype(
probs, top_p, output_padding_offset, candidates_len, max_seq_len);
probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
}
std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
const std::vector<int64_t>& probs_shape,
const std::vector<int64_t>& top_p_shape,
const std::vector<int64_t>& output_padding_offset_shape,
const std::vector<int64_t>& batch_id_per_token_output_shape,
int max_candidates_len) {
int token_num = probs_shape[0];
return {{token_num, max_candidates_len},
@@ -604,12 +603,12 @@ std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
std::vector<paddle::DataType> TopPCandidatesInferDtype(
const paddle::DataType& probs_dtype,
const paddle::DataType& top_p_dtype,
const paddle::DataType& output_padding_offset_dtype) {
const paddle::DataType& batch_id_per_token_output_dtype) {
return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
}
PD_BUILD_STATIC_OP(top_p_candidates)
.Inputs({"probs", "top_p", "output_padding_offset"})
.Inputs({"probs", "top_p", "batch_id_per_token_output"})
.Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
.Attrs({"candidates_len: int", "max_seq_len: int"})
.SetKernelFn(PD_KERNEL(TopPCandidates))