[XPU] Refactor pre process (#6993)

* [XPU] support speculate_pre_process

* merge develop

* fix codestype

* fix mtp, support cu_seqlens_q_output

* fix mtp, support cu_seqlens_q_output

* fix test

---------

Co-authored-by: lizan1999 <lizan03@baidu.com>
This commit is contained in:
cmcamdy
2026-04-01 20:29:55 +08:00
committed by GitHub
parent fba8a51ad1
commit 7a2e33098f
36 changed files with 2725 additions and 511 deletions
@@ -388,8 +388,8 @@ DLL_EXPORT int speculate_token_penalty_multi_scores(
const int64_t* min_len,
const int64_t* eos_token_id,
const int64_t* bad_words,
const int* output_padding_offset,
const int* output_cum_offsets,
const int* batch_id_per_token_output,
const int* cu_seqlens_q_output,
const int64_t bs,
const int64_t length,
const int64_t length_id,
@@ -432,7 +432,7 @@ DLL_EXPORT int speculate_verify(api::Context* ctx,
const int64_t* max_dec_len,
const int64_t* end_tokens,
const bool* is_block_step,
const int* output_cum_offsets,
const int* cu_seqlens_q_output,
const int* actual_candidate_len,
const int real_bsz,
const int max_draft_tokens,
@@ -465,7 +465,7 @@ DLL_EXPORT int draft_model_update(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
const int* output_cum_offsets,
const int* cu_seqlens_q_output,
bool* stop_flags,
bool* not_need_stop,
const int64_t* max_dec_len,
@@ -574,7 +574,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
DLL_EXPORT int top_p_candidates(api::Context* ctx,
const T* src,
const T* top_ps,
const int* output_padding_offset,
const int* batch_id_per_token_output,
int64_t* out_id,
T* out_val,
int* actual_candidates_lens,
@@ -630,6 +630,24 @@ DLL_EXPORT int speculate_schedule_cache(api::Context* ctx,
const int block_num_per_seq,
const bool prefill_one_step_stop);
DLL_EXPORT int speculate_preprocess(api::Context* ctx,
int64_t* ids_remove_padding,
int* batch_id_per_token,
int* cu_seqlens_q,
int* cu_seqlens_k,
int* seq_lens_output,
int* cu_seq_lens_q_output,
int* batch_id_per_token_output,
int* real_output_token_num,
const int64_t* input_data,
const int* seq_lens,
const int64_t* draft_tokens,
const int* seq_lens_encoder,
const int max_seq_len,
const int max_draft_tokens_per_batch,
const int token_num_data,
const int real_bs);
DLL_EXPORT int speculate_update_v3(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
@@ -662,6 +680,31 @@ DLL_EXPORT int speculate_update(api::Context* ctx,
const int max_bsz,
const int max_draft_tokens);
DLL_EXPORT int unified_update_model_status(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
bool* has_running_seqs,
int* mask_rollback,
int64_t* step_input_ids,
int* adaptive_step_input_len,
int64_t* step_output_ids,
int* step_output_len,
bool* stop_flags,
int* seq_lens_this_time,
const bool* is_paused,
int64_t* token_ids_all,
const int64_t* prompt_lens,
int64_t* step_idx,
const int64_t* end_tokens,
const int64_t* max_dec_len,
int real_bsz,
int max_bsz,
int max_step_tokens,
int max_model_len,
int num_end_tokens,
bool is_naive_mode,
bool prefill_one_step_stop);
template <typename T>
DLL_EXPORT int rebuild_hidden_states(api::Context* ctx,
const T* input,