[XPU] Refactor pre process (#6993)

* [XPU] support speculate_pre_process

* merge develop

* fix code style

* fix mtp, support cu_seqlens_q_output

* fix mtp, support cu_seqlens_q_output

* fix test

---------

Co-authored-by: lizan1999 <lizan03@baidu.com>
This commit is contained in:
cmcamdy
2026-04-01 20:29:55 +08:00
committed by GitHub
parent fba8a51ad1
commit 7a2e33098f
36 changed files with 2725 additions and 511 deletions
@@ -1069,8 +1069,8 @@ class SpeculativeSampler(nn.Layer):
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
share_inputs["batch_id_per_token_output"],
share_inputs["cu_seqlens_q_output"],
max_model_len,
sampling_metadata.pre_token_ids,
)
@@ -1091,7 +1091,7 @@ class SpeculativeSampler(nn.Layer):
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
sampling_metadata.top_p,
share_inputs["output_padding_offset"],
share_inputs["batch_id_per_token_output"],
self.speculative_max_candidate_len,
max_model_len,
)
@@ -1113,7 +1113,7 @@ class SpeculativeSampler(nn.Layer):
share_inputs["max_dec_len"],
sampling_metadata.eos_token_ids,
share_inputs["is_block_step"],
share_inputs["output_cum_offsets"],
share_inputs["cu_seqlens_q_output"],
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
@@ -1338,8 +1338,8 @@ class MTPSampler(nn.Layer):
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
share_inputs["batch_id_per_token_output"],
share_inputs["cu_seqlens_q_output"],
max_model_len,
sampling_metadata.pre_token_ids,
)