mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[XPU] Refactor pre process (#6993)
* [XPU] support speculate_pre_process
* merge develop
* fix code style
* fix mtp, support cu_seqlens_q_output
* fix mtp, support cu_seqlens_q_output
* fix test

---------

Co-authored-by: lizan1999 <lizan03@baidu.com>
This commit is contained in:
@@ -388,8 +388,8 @@ DLL_EXPORT int speculate_token_penalty_multi_scores(
|
||||
const int64_t* min_len,
|
||||
const int64_t* eos_token_id,
|
||||
const int64_t* bad_words,
|
||||
const int* output_padding_offset,
|
||||
const int* output_cum_offsets,
|
||||
const int* batch_id_per_token_output,
|
||||
const int* cu_seqlens_q_output,
|
||||
const int64_t bs,
|
||||
const int64_t length,
|
||||
const int64_t length_id,
|
||||
@@ -432,7 +432,7 @@ DLL_EXPORT int speculate_verify(api::Context* ctx,
|
||||
const int64_t* max_dec_len,
|
||||
const int64_t* end_tokens,
|
||||
const bool* is_block_step,
|
||||
const int* output_cum_offsets,
|
||||
const int* cu_seqlens_q_output,
|
||||
const int* actual_candidate_len,
|
||||
const int real_bsz,
|
||||
const int max_draft_tokens,
|
||||
@@ -465,7 +465,7 @@ DLL_EXPORT int draft_model_update(api::Context* ctx,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
const int* output_cum_offsets,
|
||||
const int* cu_seqlens_q_output,
|
||||
bool* stop_flags,
|
||||
bool* not_need_stop,
|
||||
const int64_t* max_dec_len,
|
||||
@@ -574,7 +574,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
|
||||
DLL_EXPORT int top_p_candidates(api::Context* ctx,
|
||||
const T* src,
|
||||
const T* top_ps,
|
||||
const int* output_padding_offset,
|
||||
const int* batch_id_per_token_output,
|
||||
int64_t* out_id,
|
||||
T* out_val,
|
||||
int* actual_candidates_lens,
|
||||
@@ -630,6 +630,24 @@ DLL_EXPORT int speculate_schedule_cache(api::Context* ctx,
|
||||
const int block_num_per_seq,
|
||||
const bool prefill_one_step_stop);
|
||||
|
||||
// Declaration of the exported `speculate_preprocess` entry point.
//
// Parameter layout as declared here:
//   - `ctx`: the api::Context the call is issued on.
//   - Eight non-const pointers (`ids_remove_padding`, `batch_id_per_token`,
//     `cu_seqlens_q`, `cu_seqlens_k`, `seq_lens_output`,
//     `cu_seq_lens_q_output`, `batch_id_per_token_output`,
//     `real_output_token_num`): writable by the callee.
//   - Four const pointers (`input_data`, `seq_lens`, `draft_tokens`,
//     `seq_lens_encoder`): read-only for the callee.
//   - Four by-value `const int` scalars (`max_seq_len`,
//     `max_draft_tokens_per_batch`, `token_num_data`, `real_bs`).
//
// NOTE(review): the meaning of the int return value is not visible in this
// declaration -- presumably a status code; confirm against the definition.
// NOTE(review): buffer sizes / shapes are not stated here; verify them against
// the implementation before relying on any layout.
// NOTE(review): `cu_seq_lens_q_output` is spelled differently from the
// `cu_seqlens_q_output` parameter used by the other declarations in this file
// (speculate_token_penalty_multi_scores, speculate_verify,
// draft_model_update). Consider unifying the spelling.
DLL_EXPORT int speculate_preprocess(api::Context* ctx,
|
||||
int64_t* ids_remove_padding,
|
||||
int* batch_id_per_token,
|
||||
int* cu_seqlens_q,
|
||||
int* cu_seqlens_k,
|
||||
int* seq_lens_output,
|
||||
int* cu_seq_lens_q_output,
|
||||
int* batch_id_per_token_output,
|
||||
int* real_output_token_num,
|
||||
const int64_t* input_data,
|
||||
const int* seq_lens,
|
||||
const int64_t* draft_tokens,
|
||||
const int* seq_lens_encoder,
|
||||
const int max_seq_len,
|
||||
const int max_draft_tokens_per_batch,
|
||||
const int token_num_data,
|
||||
const int real_bs);
|
||||
|
||||
DLL_EXPORT int speculate_update_v3(api::Context* ctx,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
@@ -662,6 +680,31 @@ DLL_EXPORT int speculate_update(api::Context* ctx,
|
||||
const int max_bsz,
|
||||
const int max_draft_tokens);
|
||||
|
||||
// Declaration of the exported `unified_update_model_status` entry point.
//
// Parameter layout as declared here:
//   - `ctx`: the api::Context the call is issued on.
//   - Const (read-only) pointers: `is_paused`, `prompt_lens`, `end_tokens`,
//     `max_dec_len`.
//   - All remaining pointer parameters (`seq_lens_encoder`,
//     `seq_lens_decoder`, `has_running_seqs`, `mask_rollback`,
//     `step_input_ids`, `adaptive_step_input_len`, `step_output_ids`,
//     `step_output_len`, `stop_flags`, `seq_lens_this_time`, `token_ids_all`,
//     `step_idx`) are non-const, i.e. writable by the callee.
//   - By-value scalars: `real_bsz`, `max_bsz`, `max_step_tokens`,
//     `max_model_len`, `num_end_tokens` (int) and `is_naive_mode`,
//     `prefill_one_step_stop` (bool).
//
// NOTE(review): the meaning of the int return value is not visible in this
// declaration -- presumably a status code; confirm against the definition.
// NOTE(review): which of the non-const buffers are actually modified, and
// their sizes, cannot be determined from this declaration; verify against the
// implementation.
// NOTE(review): the by-value scalars here are plain `int` / `bool`, whereas
// neighbouring declarations in this file write them as `const int` /
// `const bool` (e.g. speculate_preprocess, speculate_update). This does not
// affect callers, but the style is inconsistent.
DLL_EXPORT int unified_update_model_status(api::Context* ctx,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
bool* has_running_seqs,
|
||||
int* mask_rollback,
|
||||
int64_t* step_input_ids,
|
||||
int* adaptive_step_input_len,
|
||||
int64_t* step_output_ids,
|
||||
int* step_output_len,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
const bool* is_paused,
|
||||
int64_t* token_ids_all,
|
||||
const int64_t* prompt_lens,
|
||||
int64_t* step_idx,
|
||||
const int64_t* end_tokens,
|
||||
const int64_t* max_dec_len,
|
||||
int real_bsz,
|
||||
|
||||
int max_bsz,
|
||||
int max_step_tokens,
|
||||
int max_model_len,
|
||||
int num_end_tokens,
|
||||
bool is_naive_mode,
|
||||
bool prefill_one_step_stop);
|
||||
|
||||
template <typename T>
|
||||
DLL_EXPORT int rebuild_hidden_states(api::Context* ctx,
|
||||
const T* input,
|
||||
|
||||
Reference in New Issue
Block a user