[MTP] refactor MTP pre_process (#6358)

This commit is contained in:
周周周
2026-02-09 10:47:15 +08:00
committed by GitHub
parent 18e79dd660
commit 2b4748de4f
24 changed files with 411 additions and 533 deletions
+33 -41
View File
@@ -428,9 +428,9 @@ paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor& seq_len_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::optional<paddle::Tensor>& output_padding_offset,
const paddle::optional<paddle::Tensor>& batch_id_per_token_output,
const paddle::optional<paddle::Tensor>& cu_seqlens_q_output,
const paddle::optional<paddle::Tensor>& first_token_out,
int max_input_length,
bool enable_logprob);
void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
@@ -747,28 +747,23 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
const paddle::Tensor& seq_lens_output,
void SpecTokenPenaltyMultiScores(
const paddle::Tensor& pre_ids,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& bad_tokens_len,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const int max_seq_len);
void SpecTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& bad_tokens_len,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const int max_seq_len);
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& pre_ids,
@@ -794,7 +789,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_tokens,
const paddle::Tensor& is_block_step,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& actual_candidate_len,
const paddle::Tensor& actual_draft_token_nums,
const paddle::Tensor& topp,
@@ -922,7 +917,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
@@ -1102,19 +1097,20 @@ std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
const paddle::Tensor& pre_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& output_padding_offset,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& enable_thinking,
int64_t think_end_id,
int64_t line_break_id);
void ReasoningPhaseTokenConstraint(
const paddle::Tensor& logits,
const paddle::Tensor& pre_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& allowed_tokens,
const paddle::Tensor& reasoning_status,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& enable_thinking,
int64_t think_end_id,
int64_t line_break_id);
std::vector<paddle::Tensor> get_attn_mask_q(
const paddle::Tensor& cu_seqlens_q,
@@ -1612,10 +1608,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&SpeculateGetSeqLensOutput,
"speculate_get_seq_lens_output function");
m.def("speculate_get_output_padding_offset",
&SpeculateGetOutputPaddingOffset,
"speculate_get_output_padding_offset function");
m.def("speculate_get_token_penalty_multi_scores",
&SpecTokenPenaltyMultiScores,
"speculate_get_token_penalty_multi_scores function");