mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[MTP] refactor MTP pre_process (#6358)
This commit is contained in:
@@ -428,9 +428,9 @@ paddle::Tensor RebuildPaddingFunc(
|
||||
const paddle::Tensor& seq_len_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor>& output_padding_offset,
|
||||
const paddle::optional<paddle::Tensor>& batch_id_per_token_output,
|
||||
const paddle::optional<paddle::Tensor>& cu_seqlens_q_output,
|
||||
const paddle::optional<paddle::Tensor>& first_token_out,
|
||||
int max_input_length,
|
||||
bool enable_logprob);
|
||||
|
||||
void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
|
||||
@@ -747,28 +747,23 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
|
||||
const paddle::Tensor& output_cum_offsets_tmp,
|
||||
const paddle::Tensor& out_token_num,
|
||||
const paddle::Tensor& seq_lens_output,
|
||||
void SpecTokenPenaltyMultiScores(
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& logits,
|
||||
const paddle::Tensor& penalty_scores,
|
||||
const paddle::Tensor& frequency_scores,
|
||||
const paddle::Tensor& presence_scores,
|
||||
const paddle::Tensor& temperatures,
|
||||
const paddle::Tensor& bad_tokens,
|
||||
const paddle::Tensor& bad_tokens_len,
|
||||
const paddle::Tensor& cur_len,
|
||||
const paddle::Tensor& min_len,
|
||||
const paddle::Tensor& eos_token_id,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const int max_seq_len);
|
||||
|
||||
void SpecTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& logits,
|
||||
const paddle::Tensor& penalty_scores,
|
||||
const paddle::Tensor& frequency_scores,
|
||||
const paddle::Tensor& presence_scores,
|
||||
const paddle::Tensor& temperatures,
|
||||
const paddle::Tensor& bad_tokens,
|
||||
const paddle::Tensor& bad_tokens_len,
|
||||
const paddle::Tensor& cur_len,
|
||||
const paddle::Tensor& min_len,
|
||||
const paddle::Tensor& eos_token_id,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const int max_seq_len);
|
||||
|
||||
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& pre_ids,
|
||||
@@ -794,7 +789,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
const paddle::Tensor& end_tokens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& actual_candidate_len,
|
||||
const paddle::Tensor& actual_draft_token_nums,
|
||||
const paddle::Tensor& topp,
|
||||
@@ -922,7 +917,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
@@ -1102,19 +1097,20 @@ std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
|
||||
|
||||
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
|
||||
|
||||
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& allowed_tokens,
|
||||
const paddle::Tensor& reasoning_status,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& enable_thinking,
|
||||
int64_t think_end_id,
|
||||
int64_t line_break_id);
|
||||
void ReasoningPhaseTokenConstraint(
|
||||
const paddle::Tensor& logits,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& allowed_tokens,
|
||||
const paddle::Tensor& reasoning_status,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& enable_thinking,
|
||||
int64_t think_end_id,
|
||||
int64_t line_break_id);
|
||||
|
||||
std::vector<paddle::Tensor> get_attn_mask_q(
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
@@ -1612,10 +1608,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
&SpeculateGetSeqLensOutput,
|
||||
"speculate_get_seq_lens_output function");
|
||||
|
||||
m.def("speculate_get_output_padding_offset",
|
||||
&SpeculateGetOutputPaddingOffset,
|
||||
"speculate_get_output_padding_offset function");
|
||||
|
||||
m.def("speculate_get_token_penalty_multi_scores",
|
||||
&SpecTokenPenaltyMultiScores,
|
||||
"speculate_get_token_penalty_multi_scores function");
|
||||
|
||||
@@ -125,8 +125,8 @@ __global__ void apply_token_enforce_generation_scores_kernel(
|
||||
T* __restrict__ logits_dst, // logits (output)
|
||||
const int64_t* __restrict__ allowed_tokens, // [allowed_len]
|
||||
const int32_t* __restrict__ reasoning_status,
|
||||
const int* output_padding_offset,
|
||||
const int* output_cum_offsets,
|
||||
const int* batch_id_per_token_output,
|
||||
const int* cu_seqlens_q_output,
|
||||
const int max_bsz,
|
||||
const int max_seq_len,
|
||||
const int vocab_size,
|
||||
@@ -134,10 +134,8 @@ __global__ void apply_token_enforce_generation_scores_kernel(
|
||||
int token_idx = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
const int bs_idx =
|
||||
(token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
const int query_start_token_idx =
|
||||
bs_idx * max_seq_len - output_cum_offsets[bs_idx];
|
||||
const int bs_idx = batch_id_per_token_output[token_idx];
|
||||
const int query_start_token_idx = cu_seqlens_q_output[bs_idx];
|
||||
bool is_batch_first_token = (token_idx == query_start_token_idx);
|
||||
|
||||
if (allowed_tokens_len == 0 || !is_batch_first_token) {
|
||||
@@ -177,8 +175,8 @@ void reasoning_phase_token_constraint(
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& allowed_tokens,
|
||||
const paddle::Tensor& reasoning_status,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& enable_thinking,
|
||||
int64_t think_end_id,
|
||||
int64_t line_break_id) {
|
||||
@@ -233,27 +231,28 @@ void reasoning_phase_token_constraint(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
|
||||
allowed_tokens.data<int64_t>(),
|
||||
reasoning_status.data<int32_t>(),
|
||||
output_padding_offset.data<int32_t>(),
|
||||
output_cum_offsets.data<int32_t>(),
|
||||
batch_id_per_token_output.data<int32_t>(),
|
||||
cu_seqlens_q_output.data<int32_t>(),
|
||||
bs,
|
||||
max_seq_len,
|
||||
vocab_size,
|
||||
allowed_tokens_len);
|
||||
}
|
||||
|
||||
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& allowed_tokens,
|
||||
const paddle::Tensor& reasoning_status,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& enable_thinking,
|
||||
int64_t think_end_id,
|
||||
int64_t line_break_id) {
|
||||
void ReasoningPhaseTokenConstraint(
|
||||
const paddle::Tensor& logits,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& allowed_tokens,
|
||||
const paddle::Tensor& reasoning_status,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& enable_thinking,
|
||||
int64_t think_end_id,
|
||||
int64_t line_break_id) {
|
||||
switch (logits.type()) {
|
||||
case paddle::DataType::FLOAT16:
|
||||
return reasoning_phase_token_constraint<paddle::DataType::FLOAT16>(
|
||||
@@ -265,8 +264,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
step_idx,
|
||||
allowed_tokens,
|
||||
reasoning_status,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
enable_thinking,
|
||||
think_end_id,
|
||||
line_break_id);
|
||||
@@ -280,8 +279,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
step_idx,
|
||||
allowed_tokens,
|
||||
reasoning_status,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
enable_thinking,
|
||||
think_end_id,
|
||||
line_break_id);
|
||||
@@ -295,8 +294,8 @@ void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
step_idx,
|
||||
allowed_tokens,
|
||||
reasoning_status,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
enable_thinking,
|
||||
think_end_id,
|
||||
line_break_id);
|
||||
@@ -317,8 +316,8 @@ PD_BUILD_STATIC_OP(reasoning_phase_token_constraint)
|
||||
"step_idx",
|
||||
"allowed_tokens",
|
||||
"reasoning_status",
|
||||
"output_padding_offset",
|
||||
"output_cum_offsets",
|
||||
"batch_id_per_token_output",
|
||||
"cu_seqlens_q_output",
|
||||
"enable_thinking"})
|
||||
.Outputs({"logits_out", "reasoning_status_out"})
|
||||
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
|
||||
|
||||
@@ -21,7 +21,6 @@ __global__ void RebuildPaddingKernel(T *output_data,
|
||||
const int *seq_len_this_time,
|
||||
const int *seq_len_decoder,
|
||||
const int *seq_len_encoder,
|
||||
const int max_input_length,
|
||||
const int dim_embed,
|
||||
const int elem_nums) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
@@ -51,8 +50,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
|
||||
const int *seq_len_this_time,
|
||||
const int *seq_len_decoder,
|
||||
const int *seq_len_encoder,
|
||||
const int *output_padding_offset,
|
||||
const int max_input_length,
|
||||
const int *batch_id_per_token_output,
|
||||
const int *cu_seqlens_q_output,
|
||||
const int dim_embed,
|
||||
const int64_t output_elem_nums,
|
||||
const int bsz,
|
||||
@@ -62,17 +61,18 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
|
||||
for (int64_t i = global_idx * VecSize; i < output_elem_nums;
|
||||
i += gridDim.x * blockDim.x * VecSize) {
|
||||
const int out_token_id = i / dim_embed;
|
||||
const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
|
||||
|
||||
const int bi = ori_token_id / max_input_length;
|
||||
|
||||
int seq_id = 0;
|
||||
const int bi = batch_id_per_token_output[out_token_id];
|
||||
if (seq_len_this_time[bi] == 0) continue;
|
||||
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
|
||||
|
||||
if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
|
||||
const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
|
||||
const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
|
||||
int seq_id = 0;
|
||||
if (seq_len_encoder[bi] > 0) {
|
||||
seq_id = seq_len_encoder[bi] - 1;
|
||||
} else {
|
||||
seq_id = out_token_id - cu_seqlens_q_output[bi];
|
||||
}
|
||||
|
||||
const int input_token_id = cu_seqlens_q[bi] + seq_id;
|
||||
const int bias_idx = i % dim_embed;
|
||||
|
||||
Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
|
||||
@@ -95,9 +95,10 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
const paddle::Tensor &seq_len_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
|
||||
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
|
||||
|
||||
const paddle::optional<paddle::Tensor> &first_token_out,
|
||||
int max_input_length,
|
||||
bool enable_logprob) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
@@ -117,8 +118,8 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
|
||||
paddle::Tensor out;
|
||||
int output_token_num;
|
||||
if (output_padding_offset) {
|
||||
output_token_num = output_padding_offset.get().shape()[0];
|
||||
if (batch_id_per_token_output) {
|
||||
output_token_num = batch_id_per_token_output.get().shape()[0];
|
||||
} else {
|
||||
output_token_num = bsz;
|
||||
}
|
||||
@@ -131,7 +132,7 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
const int grid_size = (pack_num + blocksize - 1) / blocksize;
|
||||
if (output_padding_offset) {
|
||||
if (batch_id_per_token_output) {
|
||||
RebuildAppendPaddingKernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, cu_stream>>>(
|
||||
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
||||
@@ -144,8 +145,8 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
seq_len_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
output_padding_offset.get_ptr()->data<int>(),
|
||||
max_input_length,
|
||||
batch_id_per_token_output.get_ptr()->data<int>(),
|
||||
cu_seqlens_q_output.get_ptr()->data<int>(),
|
||||
dim_embed,
|
||||
elem_nums,
|
||||
bsz,
|
||||
@@ -160,7 +161,6 @@ std::vector<paddle::Tensor> rebuild_padding(
|
||||
seq_len_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_input_length,
|
||||
dim_embed,
|
||||
elem_nums);
|
||||
}
|
||||
@@ -173,43 +173,46 @@ paddle::Tensor RebuildPaddingFunc(
|
||||
const paddle::Tensor &seq_len_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
|
||||
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
|
||||
const paddle::optional<paddle::Tensor> &first_token_out,
|
||||
int max_input_length,
|
||||
bool enable_logprob) {
|
||||
switch (tmp_out.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return rebuild_padding<paddle::DataType::BFLOAT16>(tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob)[0];
|
||||
return rebuild_padding<paddle::DataType::BFLOAT16>(
|
||||
tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
first_token_out,
|
||||
enable_logprob)[0];
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return rebuild_padding<paddle::DataType::FLOAT16>(tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob)[0];
|
||||
return rebuild_padding<paddle::DataType::FLOAT16>(
|
||||
tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
first_token_out,
|
||||
enable_logprob)[0];
|
||||
}
|
||||
case paddle::DataType::FLOAT32: {
|
||||
return rebuild_padding<paddle::DataType::FLOAT32>(tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob)[0];
|
||||
return rebuild_padding<paddle::DataType::FLOAT32>(
|
||||
tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
first_token_out,
|
||||
enable_logprob)[0];
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
@@ -226,18 +229,18 @@ std::vector<paddle::Tensor> RebuildPadding(
|
||||
const paddle::Tensor &seq_len_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
const paddle::optional<paddle::Tensor> &batch_id_per_token_output,
|
||||
const paddle::optional<paddle::Tensor> &cu_seqlens_q_output,
|
||||
const paddle::optional<paddle::Tensor> &first_token_out,
|
||||
int max_input_length,
|
||||
bool enable_logprob) {
|
||||
return {RebuildPaddingFunc(tmp_out,
|
||||
cu_seqlens_q,
|
||||
seq_len_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
output_padding_offset,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
first_token_out,
|
||||
max_input_length,
|
||||
enable_logprob)};
|
||||
}
|
||||
|
||||
@@ -247,10 +250,12 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
||||
const std::vector<int64_t> &seq_len_this_time_shape,
|
||||
const std::vector<int64_t> &seq_lens_decoder_shape,
|
||||
const std::vector<int64_t> &seq_lens_encoder_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
|
||||
const paddle::optional<std::vector<int64_t>>
|
||||
&batch_id_per_token_output_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &cu_seqlens_q_output_shape) {
|
||||
int64_t dim_embed = tmp_out_shape[1];
|
||||
// whether speculative decoding
|
||||
if (output_padding_offset_shape) {
|
||||
if (batch_id_per_token_output_shape) {
|
||||
return {{-1, dim_embed}};
|
||||
} else {
|
||||
int64_t bsz = cu_seqlens_q_shape[0] - 1;
|
||||
@@ -264,7 +269,8 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
||||
const paddle::DataType &seq_len_this_time_dtype,
|
||||
const paddle::DataType &seq_lens_decoder_dtype,
|
||||
const paddle::DataType &seq_lens_encoder_dtype,
|
||||
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
|
||||
const paddle::optional<paddle::DataType> &batch_id_per_token_output_dtype,
|
||||
const paddle::optional<paddle::DataType> &cu_seqlens_q_output_dtype) {
|
||||
return {tmp_out_dtype};
|
||||
}
|
||||
|
||||
@@ -274,10 +280,11 @@ PD_BUILD_STATIC_OP(rebuild_padding)
|
||||
"seq_len_this_time",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_encoder",
|
||||
paddle::Optional("output_padding_offset"),
|
||||
paddle::Optional("batch_id_per_token_output"),
|
||||
paddle::Optional("cu_seqlens_q_output"),
|
||||
paddle::Optional("first_token_out")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"max_input_length: int", "enable_logprob: bool"})
|
||||
.Attrs({"enable_logprob: bool"})
|
||||
.SetKernelFn(PD_KERNEL(RebuildPadding))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));
|
||||
|
||||
@@ -23,7 +23,7 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
const int* output_cum_offsets,
|
||||
const int* cu_seqlens_q_output,
|
||||
bool* stop_flags,
|
||||
bool* not_need_stop,
|
||||
const int64_t* max_dec_len,
|
||||
@@ -47,8 +47,7 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
auto* pre_ids_now = pre_ids + tid * pre_id_length;
|
||||
auto* base_model_draft_tokens_now =
|
||||
base_model_draft_tokens + tid * max_base_model_draft_token;
|
||||
const int next_tokens_start_id =
|
||||
tid * max_seq_len - output_cum_offsets[tid];
|
||||
const int next_tokens_start_id = cu_seqlens_q_output[tid];
|
||||
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
|
||||
auto seq_len_this_time = seq_lens_this_time[tid];
|
||||
auto seq_len_encoder = seq_lens_encoder[tid];
|
||||
@@ -66,7 +65,6 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
step_idx[tid] += seq_len_this_time;
|
||||
pre_ids_now[step_idx[tid]] = token_this_time;
|
||||
|
||||
|
||||
} else {
|
||||
token_this_time = next_tokens_start[0];
|
||||
|
||||
@@ -80,7 +78,8 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
}
|
||||
|
||||
// multi_end
|
||||
if (is_in_end(token_this_time, end_ids, end_ids_len) || prefill_one_step_stop) {
|
||||
if (is_in_end(token_this_time, end_ids, end_ids_len) ||
|
||||
prefill_one_step_stop) {
|
||||
stop_flags[tid] = true;
|
||||
stop_flag_now_int = 1;
|
||||
// max_dec_len
|
||||
@@ -112,7 +111,6 @@ __global__ void draft_model_update_kernel(const int64_t* inter_next_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& pre_ids,
|
||||
@@ -120,7 +118,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
@@ -140,11 +138,11 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
constexpr int BlockSize = 512;
|
||||
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
||||
// std::cout << "Your PATH is: " << env_p << '\n';
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
||||
// std::cout << "Your PATH is: " << env_p << '\n';
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
}
|
||||
|
||||
draft_model_update_kernel<BlockSize><<<1, BlockSize, 0, cu_stream>>>(
|
||||
@@ -155,7 +153,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
output_cum_offsets.data<int>(),
|
||||
cu_seqlens_q_output.data<int>(),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
not_need_stop_gpu.data<bool>(),
|
||||
max_dec_len.data<int64_t>(),
|
||||
@@ -170,14 +168,12 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
substep,
|
||||
prefill_one_step_stop);
|
||||
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(draft_model_update)
|
||||
.Inputs({"inter_next_tokens",
|
||||
"draft_tokens",
|
||||
@@ -186,7 +182,7 @@ PD_BUILD_STATIC_OP(draft_model_update)
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"step_idx",
|
||||
"output_cum_offsets",
|
||||
"cu_seqlens_q_output",
|
||||
"stop_flags",
|
||||
"not_need_stop",
|
||||
"max_dec_len",
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// Fallback definition: when the build has not already supplied
// PD_BUILD_STATIC_OP, register the op through PD_BUILD_OP under a name with
// the "static_op_" prefix.
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
|
||||
|
||||
// Writes the per-token padding offset and the per-batch cumulative offset
// for the output tokens of each batch.
//
// Work split used by this kernel: blockIdx.x selects the batch `bi`; the
// threads of that block stride over the batch's seq_lens_output[bi] tokens.
//
// output_padding_offset  [out] element (bi * max_seq_len - cum_offset + i)
//                        receives cum_offset for every token i of batch bi.
// output_cum_offsets     [out] entry bi receives the cum_offset of batch bi.
// output_cum_offsets_tmp [in]  entry bi - 1 is read as the offset of batch bi
//                        (presumably a running sum of padding — confirm
//                        against the producer of this tensor).
// seq_lens_output        [in]  number of output tokens of each batch.
// max_seq_len            multiplier applied to bi when forming the token index.
__global__ void SpeculateGetOutputPaddingOffsetKernel(
    int* output_padding_offset,
    int* output_cum_offsets,
    const int* output_cum_offsets_tmp,
    const int* seq_lens_output,
    const int max_seq_len) {
  const int bi = blockIdx.x;
  const int ti = threadIdx.x;

  // Batch 0 has no predecessor, so its offset is 0; any other batch takes
  // the value stored for the batch before it.
  const int cum_offset = bi == 0 ? 0 : output_cum_offsets_tmp[bi - 1];

  for (int i = ti; i < seq_lens_output[bi]; i += blockDim.x) {
    output_padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
  }

  // A single thread per block publishes the batch-level offset.
  if (ti == 0) {
    output_cum_offsets[bi] = cum_offset;
  }
}
|
||||
|
||||
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
|
||||
const paddle::Tensor& output_cum_offsets_tmp,
|
||||
const paddle::Tensor& out_token_num,
|
||||
const paddle::Tensor& seq_lens_output,
|
||||
const int max_seq_len) {
|
||||
auto cu_stream = output_cum_offsets_tmp.stream();
|
||||
std::vector<int64_t> output_cum_offsets_tmp_shape =
|
||||
output_cum_offsets_tmp.shape();
|
||||
const int bsz = output_cum_offsets_tmp_shape[0];
|
||||
auto cpu_out_token_num = out_token_num.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
auto output_padding_offset = paddle::full({cpu_out_token_num},
|
||||
0,
|
||||
paddle::DataType::INT32,
|
||||
output_cum_offsets_tmp.place());
|
||||
auto output_cum_offsets =
|
||||
output_cum_offsets_tmp.copy_to(output_cum_offsets_tmp.place(), false);
|
||||
|
||||
SpeculateGetOutputPaddingOffsetKernel<<<bsz, 256, 0, cu_stream>>>(
|
||||
output_padding_offset.data<int>(),
|
||||
output_cum_offsets.data<int>(),
|
||||
output_cum_offsets_tmp.data<int>(),
|
||||
seq_lens_output.data<int>(),
|
||||
max_seq_len);
|
||||
|
||||
return {output_padding_offset, output_cum_offsets};
|
||||
}
|
||||
|
||||
// Shape inference for speculate_get_output_padding_offset.
//
// Returns two shapes, in output order:
//   1. {-1}   for output_padding_offset (length not derived from the input
//             shapes here),
//   2. {bsz}  for output_cum_offsets, where bsz is the leading dimension of
//             output_cum_offsets_tmp.
//
// The other two input shapes are accepted to match the op's input list but
// are not consulted.
std::vector<std::vector<int64_t>> SpeculateGetOutputPaddingOffsetInferShape(
    const std::vector<int64_t>& output_cum_offsets_tmp_shape,
    const std::vector<int64_t>& out_token_num_shape,
    const std::vector<int64_t>& seq_lens_output_shape) {
  const int64_t bsz = output_cum_offsets_tmp_shape[0];
  return {{-1}, {bsz}};
}
|
||||
|
||||
std::vector<paddle::DataType> SpeculateGetOutputPaddingOffsetInferDtype(
|
||||
const paddle::DataType& output_cum_offsets_tmp_dtype,
|
||||
const paddle::DataType& out_token_num_dtype,
|
||||
const paddle::DataType& seq_lens_output_dtype) {
|
||||
return {output_cum_offsets_tmp_dtype, output_cum_offsets_tmp_dtype};
|
||||
}
|
||||
|
||||
// Registers the speculate_get_output_padding_offset op: three tensor inputs,
// two tensor outputs, one integer attribute (max_seq_len), and the kernel,
// shape-inference and dtype-inference functions defined in this file.
PD_BUILD_STATIC_OP(speculate_get_output_padding_offset)
    .Inputs({"output_cum_offsets_tmp", "out_token_num", "seq_lens_output"})
    .Outputs({"output_padding_offset", "output_cum_offsets"})
    .Attrs({"max_seq_len: int"})
    .SetKernelFn(PD_KERNEL(SpeculateGetOutputPaddingOffset))
    .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetOutputPaddingOffsetInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetOutputPaddingOffsetInferDtype));
|
||||
@@ -20,8 +20,8 @@ __global__ inline void min_length_logits_process(
|
||||
const int64_t *cur_len,
|
||||
const int64_t *min_len,
|
||||
const int64_t *eos_token_id,
|
||||
const int *output_padding_offset,
|
||||
const int *output_cum_offsets,
|
||||
const int *batch_id_per_token_output,
|
||||
const int *cu_seqlens_q_output,
|
||||
const int64_t token_num,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
@@ -29,9 +29,9 @@ __global__ inline void min_length_logits_process(
|
||||
const int max_seq_len) {
|
||||
const int token_idx = threadIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
if (bi >= bs) return;
|
||||
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
|
||||
const int query_start_token_idx = cu_seqlens_q_output[bi];
|
||||
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
@@ -49,8 +49,8 @@ __global__ inline void min_length_logits_process<half>(
|
||||
const int64_t *cur_len,
|
||||
const int64_t *min_len,
|
||||
const int64_t *eos_token_id,
|
||||
const int *output_padding_offset,
|
||||
const int *output_cum_offsets,
|
||||
const int *batch_id_per_token_output,
|
||||
const int *cu_seqlens_q_output,
|
||||
const int64_t token_num,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
@@ -58,9 +58,9 @@ __global__ inline void min_length_logits_process<half>(
|
||||
const int max_seq_len) {
|
||||
const int token_idx = threadIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
if (bi >= bs) return;
|
||||
const int query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
|
||||
const int query_start_token_idx = cu_seqlens_q_output[bi];
|
||||
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
@@ -75,7 +75,7 @@ __global__ inline void min_length_logits_process<half>(
|
||||
__global__ void update_repeat_times(const int64_t *pre_ids,
|
||||
const int64_t *cur_len,
|
||||
int *repeat_times,
|
||||
const int *output_padding_offset,
|
||||
const int *batch_id_per_token_output,
|
||||
const int64_t token_num,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
@@ -83,7 +83,7 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
|
||||
const int max_seq_len) {
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
if (bi >= bs) return;
|
||||
if (cur_len[bi] < 0) {
|
||||
return;
|
||||
@@ -99,20 +99,21 @@ __global__ void update_repeat_times(const int64_t *pre_ids,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void update_value_by_repeat_times(const int *repeat_times,
|
||||
const T *penalty_scores,
|
||||
const T *frequency_score,
|
||||
const T *presence_score,
|
||||
const float *temperatures,
|
||||
T *logits,
|
||||
const int *output_padding_offset,
|
||||
const int64_t token_num,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int max_seq_len) {
|
||||
__global__ void update_value_by_repeat_times(
|
||||
const int *repeat_times,
|
||||
const T *penalty_scores,
|
||||
const T *frequency_score,
|
||||
const T *presence_score,
|
||||
const float *temperatures,
|
||||
T *logits,
|
||||
const int *batch_id_per_token_output,
|
||||
const int64_t token_num,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int max_seq_len) {
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
if (bi >= bs) return;
|
||||
int tid = threadIdx.x;
|
||||
T *logits_now = logits + token_idx * length;
|
||||
@@ -135,7 +136,7 @@ template <typename T>
|
||||
__global__ void ban_bad_words(T *logits,
|
||||
const int64_t *bad_tokens,
|
||||
const int64_t *bad_tokens_len,
|
||||
const int *output_padding_offset,
|
||||
const int *batch_id_per_token_output,
|
||||
const int64_t token_num,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
@@ -143,7 +144,7 @@ __global__ void ban_bad_words(T *logits,
|
||||
const int max_seq_len) {
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= token_num) return;
|
||||
const int bi = (token_idx + output_padding_offset[token_idx]) / max_seq_len;
|
||||
const int bi = batch_id_per_token_output[token_idx];
|
||||
|
||||
if (bi >= bs) return;
|
||||
int tid = threadIdx.x;
|
||||
@@ -172,8 +173,8 @@ void token_penalty_multi_scores_kernel(
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &output_padding_offset,
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const paddle::Tensor &batch_id_per_token_output,
|
||||
const paddle::Tensor &cu_seqlens_q_output,
|
||||
const int max_seq_len) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
@@ -196,8 +197,8 @@ void token_penalty_multi_scores_kernel(
|
||||
cur_len.data<int64_t>(),
|
||||
min_len.data<int64_t>(),
|
||||
eos_token_id.data<int64_t>(),
|
||||
output_padding_offset.data<int>(),
|
||||
output_cum_offsets.data<int>(),
|
||||
batch_id_per_token_output.data<int>(),
|
||||
cu_seqlens_q_output.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
@@ -210,7 +211,7 @@ void token_penalty_multi_scores_kernel(
|
||||
pre_ids.data<int64_t>(),
|
||||
cur_len.data<int64_t>(),
|
||||
repeat_times.data<int>(),
|
||||
output_padding_offset.data<int>(),
|
||||
batch_id_per_token_output.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
@@ -231,7 +232,7 @@ void token_penalty_multi_scores_kernel(
|
||||
temperatures.data<float>(),
|
||||
reinterpret_cast<DataType_ *>(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
output_padding_offset.data<int>(),
|
||||
batch_id_per_token_output.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
@@ -243,7 +244,7 @@ void token_penalty_multi_scores_kernel(
|
||||
const_cast<data_t *>(logits.data<data_t>())),
|
||||
bad_tokens.data<int64_t>(),
|
||||
bad_tokens_len.data<int64_t>(),
|
||||
output_padding_offset.data<int>(),
|
||||
batch_id_per_token_output.data<int>(),
|
||||
token_num,
|
||||
bs,
|
||||
length,
|
||||
@@ -251,21 +252,22 @@ void token_penalty_multi_scores_kernel(
|
||||
max_seq_len);
|
||||
}
|
||||
|
||||
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &logits,
|
||||
const paddle::Tensor &penalty_scores,
|
||||
const paddle::Tensor &frequency_scores,
|
||||
const paddle::Tensor &presence_scores,
|
||||
const paddle::Tensor &temperatures,
|
||||
const paddle::Tensor &bad_tokens,
|
||||
const paddle::Tensor &bad_tokens_len,
|
||||
const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &output_padding_offset,
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const int max_seq_len) {
|
||||
void SpecTokenPenaltyMultiScores(
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &logits,
|
||||
const paddle::Tensor &penalty_scores,
|
||||
const paddle::Tensor &frequency_scores,
|
||||
const paddle::Tensor &presence_scores,
|
||||
const paddle::Tensor &temperatures,
|
||||
const paddle::Tensor &bad_tokens,
|
||||
const paddle::Tensor &bad_tokens_len,
|
||||
const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &batch_id_per_token_output,
|
||||
const paddle::Tensor &cu_seqlens_q_output,
|
||||
const int max_seq_len) {
|
||||
switch (logits.type()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return token_penalty_multi_scores_kernel<paddle::DataType::BFLOAT16>(
|
||||
@@ -281,8 +283,8 @@ void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
max_seq_len);
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
@@ -299,8 +301,8 @@ void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
max_seq_len);
|
||||
}
|
||||
case paddle::DataType::FLOAT32: {
|
||||
@@ -317,8 +319,8 @@ void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||
min_len,
|
||||
eos_token_id,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
batch_id_per_token_output,
|
||||
cu_seqlens_q_output,
|
||||
max_seq_len);
|
||||
}
|
||||
default: {
|
||||
@@ -343,8 +345,8 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
|
||||
"min_len",
|
||||
"eos_token_id",
|
||||
"seq_lens_this_time",
|
||||
"output_padding_offset",
|
||||
"output_cum_offsets"})
|
||||
"batch_id_per_token_output",
|
||||
"cu_seqlens_q_output"})
|
||||
.Outputs({"logits_out"})
|
||||
.Attrs({"max_seq_len: int"})
|
||||
.SetInplaceMap({{"logits", "logits_out"}})
|
||||
|
||||
@@ -82,7 +82,7 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids,
|
||||
const int64_t *max_dec_len,
|
||||
const int64_t *end_tokens,
|
||||
const bool *is_block_step,
|
||||
const int *output_cum_offsets,
|
||||
const int *cu_seqlens_q_output,
|
||||
const int *actual_candidate_len,
|
||||
const int *reasoning_status,
|
||||
const int real_bsz,
|
||||
@@ -101,7 +101,7 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids,
|
||||
int stop_flag_now_int = 0;
|
||||
|
||||
if (!(is_block_step[bid] || bid >= real_bsz)) {
|
||||
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
|
||||
const int start_token_id = cu_seqlens_q_output[bid];
|
||||
|
||||
if (stop_flags[bid]) {
|
||||
stop_flag_now_int = 1;
|
||||
@@ -312,7 +312,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const paddle::Tensor &end_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q_output,
|
||||
const paddle::Tensor &actual_candidate_len,
|
||||
const paddle::Tensor &actual_draft_token_nums,
|
||||
const paddle::Tensor &topp,
|
||||
@@ -374,7 +374,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
max_dec_len.data<int64_t>(),
|
||||
end_tokens.data<int64_t>(),
|
||||
is_block_step.data<bool>(),
|
||||
output_cum_offsets.data<int>(),
|
||||
cu_seqlens_q_output.data<int>(),
|
||||
actual_candidate_len.data<int>(),
|
||||
reasoning_status.data<int>(),
|
||||
real_bsz,
|
||||
@@ -407,7 +407,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
max_dec_len.data<int64_t>(),
|
||||
end_tokens.data<int64_t>(),
|
||||
is_block_step.data<bool>(),
|
||||
output_cum_offsets.data<int>(),
|
||||
cu_seqlens_q_output.data<int>(),
|
||||
actual_candidate_len.data<int>(),
|
||||
reasoning_status.data<int>(),
|
||||
real_bsz,
|
||||
@@ -442,7 +442,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
max_dec_len.data<int64_t>(),
|
||||
end_tokens.data<int64_t>(),
|
||||
is_block_step.data<bool>(),
|
||||
output_cum_offsets.data<int>(),
|
||||
cu_seqlens_q_output.data<int>(),
|
||||
actual_candidate_len.data<int>(),
|
||||
reasoning_status.data<int>(),
|
||||
real_bsz,
|
||||
@@ -475,7 +475,7 @@ void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||
max_dec_len.data<int64_t>(),
|
||||
end_tokens.data<int64_t>(),
|
||||
is_block_step.data<bool>(),
|
||||
output_cum_offsets.data<int>(),
|
||||
cu_seqlens_q_output.data<int>(),
|
||||
actual_candidate_len.data<int>(),
|
||||
reasoning_status.data<int>(),
|
||||
real_bsz,
|
||||
@@ -509,7 +509,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
|
||||
"max_dec_len",
|
||||
"end_tokens",
|
||||
"is_block_step",
|
||||
"output_cum_offsets",
|
||||
"cu_seqlens_q_output",
|
||||
"actual_candidate_len",
|
||||
"actual_draft_token_nums",
|
||||
"topp",
|
||||
|
||||
@@ -407,7 +407,7 @@ template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
|
||||
__global__ void KeMatrixTopPBeamTopKFt(
|
||||
const T* src,
|
||||
const T* top_ps,
|
||||
const int* output_padding_offset,
|
||||
const int* batch_id_per_token_output,
|
||||
int64_t* out_id, // [max_candidate_len, 1]
|
||||
T* out_val, // [max_candidate_len, 1]
|
||||
int* actual_candidates_lens,
|
||||
@@ -418,8 +418,7 @@ __global__ void KeMatrixTopPBeamTopKFt(
|
||||
const int wid = tid / 32;
|
||||
const int lane = tid % 32;
|
||||
const int token_id = blockIdx.x;
|
||||
const int ori_token_id = token_id + output_padding_offset[token_id];
|
||||
const int bid = ori_token_id / max_seq_len;
|
||||
const int bid = batch_id_per_token_output[token_id];
|
||||
|
||||
int top_num = TopPBeamTopK;
|
||||
float top_p_value = static_cast<float>(top_ps[bid]);
|
||||
@@ -479,7 +478,7 @@ __global__ void KeMatrixTopPBeamTopKFt(
|
||||
template <typename T, int TopKMaxLength>
|
||||
void DispatchTopK(const T* src,
|
||||
const T* top_ps,
|
||||
const int* output_padding_offset,
|
||||
const int* batch_id_per_token_output,
|
||||
int64_t* out_id, // topk id
|
||||
T* out_val, // topk val
|
||||
int* actual_candidates_lens_data,
|
||||
@@ -495,7 +494,7 @@ void DispatchTopK(const T* src,
|
||||
KeMatrixTopPBeamTopKFt<T, TopKMaxLength, kTopK, kBlockDim>
|
||||
<<<token_num, kBlockDim, 0, stream>>>(src,
|
||||
top_ps,
|
||||
output_padding_offset,
|
||||
batch_id_per_token_output,
|
||||
out_id,
|
||||
out_val,
|
||||
actual_candidates_lens_data,
|
||||
@@ -516,7 +515,7 @@ template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor> LaunchTopPCandidates(
|
||||
const paddle::Tensor& probs, // [token_num, vocab_size]
|
||||
const paddle::Tensor& top_p, // [token_num]
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const int candidates_len,
|
||||
const int max_seq_len) {
|
||||
typedef PDTraits<D> traits_;
|
||||
@@ -540,7 +539,7 @@ std::vector<paddle::Tensor> LaunchTopPCandidates(
|
||||
DispatchTopK<DataType_, TopKMaxLength>(
|
||||
reinterpret_cast<const DataType_*>(probs.data<data_t>()),
|
||||
reinterpret_cast<const DataType_*>(top_p.data<data_t>()),
|
||||
output_padding_offset.data<int>(),
|
||||
batch_id_per_token_output.data<int>(),
|
||||
verify_tokens.data<int64_t>(),
|
||||
reinterpret_cast<DataType_*>(verify_scores.data<data_t>()),
|
||||
actual_candidate_lens.data<int>(),
|
||||
@@ -556,21 +555,21 @@ std::vector<paddle::Tensor> LaunchTopPCandidates(
|
||||
std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
|
||||
const paddle::Tensor& probs,
|
||||
const paddle::Tensor& top_p,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
int candidates_len,
|
||||
int max_seq_len) {
|
||||
switch (probs.type()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
return LaunchTopPCandidates<paddle::DataType::BFLOAT16>(
|
||||
probs, top_p, output_padding_offset, candidates_len, max_seq_len);
|
||||
probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
return LaunchTopPCandidates<paddle::DataType::FLOAT16>(
|
||||
probs, top_p, output_padding_offset, candidates_len, max_seq_len);
|
||||
probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
|
||||
break;
|
||||
case paddle::DataType::FLOAT32:
|
||||
return LaunchTopPCandidates<paddle::DataType::FLOAT32>(
|
||||
probs, top_p, output_padding_offset, candidates_len, max_seq_len);
|
||||
probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
|
||||
break;
|
||||
default:
|
||||
PD_THROW(
|
||||
@@ -583,17 +582,17 @@ std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
|
||||
std::vector<paddle::Tensor> TopPCandidates(
|
||||
const paddle::Tensor& probs,
|
||||
const paddle::Tensor& top_p,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
int candidates_len,
|
||||
int max_seq_len) {
|
||||
return DispatchTopPCandidatesWithDtype(
|
||||
probs, top_p, output_padding_offset, candidates_len, max_seq_len);
|
||||
probs, top_p, batch_id_per_token_output, candidates_len, max_seq_len);
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
|
||||
const std::vector<int64_t>& probs_shape,
|
||||
const std::vector<int64_t>& top_p_shape,
|
||||
const std::vector<int64_t>& output_padding_offset_shape,
|
||||
const std::vector<int64_t>& batch_id_per_token_output_shape,
|
||||
int max_candidates_len) {
|
||||
int token_num = probs_shape[0];
|
||||
return {{token_num, max_candidates_len},
|
||||
@@ -604,12 +603,12 @@ std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
|
||||
std::vector<paddle::DataType> TopPCandidatesInferDtype(
|
||||
const paddle::DataType& probs_dtype,
|
||||
const paddle::DataType& top_p_dtype,
|
||||
const paddle::DataType& output_padding_offset_dtype) {
|
||||
const paddle::DataType& batch_id_per_token_output_dtype) {
|
||||
return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(top_p_candidates)
|
||||
.Inputs({"probs", "top_p", "output_padding_offset"})
|
||||
.Inputs({"probs", "top_p", "batch_id_per_token_output"})
|
||||
.Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
|
||||
.Attrs({"candidates_len: int", "max_seq_len: int"})
|
||||
.SetKernelFn(PD_KERNEL(TopPCandidates))
|
||||
|
||||
Reference in New Issue
Block a user