mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 09:44:10 +08:00
[MTP] refactor MTP pre_process (#6358)
This commit is contained in:
@@ -428,9 +428,9 @@ paddle::Tensor RebuildPaddingFunc(
|
||||
const paddle::Tensor& seq_len_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor>& output_padding_offset,
|
||||
const paddle::optional<paddle::Tensor>& batch_id_per_token_output,
|
||||
const paddle::optional<paddle::Tensor>& cu_seqlens_q_output,
|
||||
const paddle::optional<paddle::Tensor>& first_token_out,
|
||||
int max_input_length,
|
||||
bool enable_logprob);
|
||||
|
||||
void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
|
||||
@@ -747,28 +747,23 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
|
||||
const paddle::Tensor& output_cum_offsets_tmp,
|
||||
const paddle::Tensor& out_token_num,
|
||||
const paddle::Tensor& seq_lens_output,
|
||||
void SpecTokenPenaltyMultiScores(
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& logits,
|
||||
const paddle::Tensor& penalty_scores,
|
||||
const paddle::Tensor& frequency_scores,
|
||||
const paddle::Tensor& presence_scores,
|
||||
const paddle::Tensor& temperatures,
|
||||
const paddle::Tensor& bad_tokens,
|
||||
const paddle::Tensor& bad_tokens_len,
|
||||
const paddle::Tensor& cur_len,
|
||||
const paddle::Tensor& min_len,
|
||||
const paddle::Tensor& eos_token_id,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const int max_seq_len);
|
||||
|
||||
void SpecTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& logits,
|
||||
const paddle::Tensor& penalty_scores,
|
||||
const paddle::Tensor& frequency_scores,
|
||||
const paddle::Tensor& presence_scores,
|
||||
const paddle::Tensor& temperatures,
|
||||
const paddle::Tensor& bad_tokens,
|
||||
const paddle::Tensor& bad_tokens_len,
|
||||
const paddle::Tensor& cur_len,
|
||||
const paddle::Tensor& min_len,
|
||||
const paddle::Tensor& eos_token_id,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const int max_seq_len);
|
||||
|
||||
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& pre_ids,
|
||||
@@ -794,7 +789,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
const paddle::Tensor& end_tokens,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& actual_candidate_len,
|
||||
const paddle::Tensor& actual_draft_token_nums,
|
||||
const paddle::Tensor& topp,
|
||||
@@ -922,7 +917,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
@@ -1102,19 +1097,20 @@ std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
|
||||
|
||||
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
|
||||
|
||||
void ReasoningPhaseTokenConstraint(const paddle::Tensor& logits,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& allowed_tokens,
|
||||
const paddle::Tensor& reasoning_status,
|
||||
const paddle::Tensor& output_padding_offset,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& enable_thinking,
|
||||
int64_t think_end_id,
|
||||
int64_t line_break_id);
|
||||
void ReasoningPhaseTokenConstraint(
|
||||
const paddle::Tensor& logits,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& allowed_tokens,
|
||||
const paddle::Tensor& reasoning_status,
|
||||
const paddle::Tensor& batch_id_per_token_output,
|
||||
const paddle::Tensor& cu_seqlens_q_output,
|
||||
const paddle::Tensor& enable_thinking,
|
||||
int64_t think_end_id,
|
||||
int64_t line_break_id);
|
||||
|
||||
std::vector<paddle::Tensor> get_attn_mask_q(
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
@@ -1612,10 +1608,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
&SpeculateGetSeqLensOutput,
|
||||
"speculate_get_seq_lens_output function");
|
||||
|
||||
m.def("speculate_get_output_padding_offset",
|
||||
&SpeculateGetOutputPaddingOffset,
|
||||
"speculate_get_output_padding_offset function");
|
||||
|
||||
m.def("speculate_get_token_penalty_multi_scores",
|
||||
&SpecTokenPenaltyMultiScores,
|
||||
"speculate_get_token_penalty_multi_scores function");
|
||||
|
||||
Reference in New Issue
Block a user