[XPU] Unify spec and non-spec branches. (#6947) (#7180)

* [XPU] cherry-pick PR-6947

* [XPU] use unified_update_model_status.

* refactor xpu_model_runner.

* refactor sampler.

* fix codestyle.

* Fix XPU speculative decoding: rename output tensors to cu_seqlens_q_output/batch_id_per_token_output, correct
  WRAPPER_CHECK_PTR types, and fix dynamic gather shape in verify_draft_tokens path.

* fix codestyle.

* replace output_padding_offset with is_speculative flag in gather_next_token.

* rename hidden_states.

* unify cu_seqlens_q_output and batch_id_per_token_output init.
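
Taken together, the intent is to drive both the speculative and the non-speculative decode path through a single post-processing branch. Below is a minimal, self-contained Python sketch of that shape; it is a toy stand-in, not FastDeploy's real `unified_update_model_status` op, and every field name in the state dict is an illustrative assumption.

```python
# Illustrative only: a toy stand-in for the fused unified_update_model_status op.
# Real signatures, tensor layouts and stop conditions live in the XPU custom ops.
def toy_unified_update_model_status(
    state: dict,
    is_speculative: bool,
    prefill_one_step_stop: bool = False,
) -> None:
    """Advance per-sequence decode state once, for spec and non-spec alike."""
    for i, stopped in enumerate(state["stop_flags"]):
        if stopped:
            continue
        # Speculative decoding may accept several draft tokens in one step.
        advance = state["accept_num"][i] if is_speculative else 1
        state["step_idx"][i] += advance
        state["seq_lens_decoder"][i] += advance
        # Stop once the decode budget is exhausted (or right after prefill if requested).
        if prefill_one_step_stop or state["step_idx"][i] >= state["max_dec_len"][i]:
            state["stop_flags"][i] = True
    # Keep scheduling as long as at least one sequence is still running.
    state["not_need_stop"] = not all(state["stop_flags"])


# Example: two sequences on the speculative branch, accepting 3 and 2 draft tokens.
state = {
    "stop_flags": [False, False],
    "accept_num": [3, 2],
    "step_idx": [0, 6],
    "seq_lens_decoder": [5, 7],
    "max_dec_len": [16, 8],
}
toy_unified_update_model_status(state, is_speculative=True)
print(state["step_idx"], state["stop_flags"], state["not_need_stop"])
# -> [3, 8] [False, True] True
```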

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
Author:    Jiajun Ji
Date:      2026-04-16 14:58:38 +08:00
Committed: GitHub
Parent:    17002edc47
Commit:    29495b2cf1

9 changed files with 226 additions and 149 deletions
@@ -43,12 +43,11 @@ if current_platform.is_xpu():
         speculate_pre_process,
         speculate_save_output,
         speculate_set_stop_value_multi_seqs,
-        speculate_set_value_by_flags_and_idx,
         speculate_step_paddle,
         speculate_step_reschedule,
         speculate_step_system_cache,
-        speculate_update,
         step_paddle,
+        unified_update_model_status,
         update_inputs,
         update_inputs_v1,
     )
@@ -172,6 +171,7 @@ def xpu_pre_process(
         block_tables=share_inputs["block_tables"],
         caches=share_inputs["caches"],
         max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
+        is_speculative=use_speculate_method,
     )
     (
@@ -249,11 +249,6 @@ def xpu_process_output(
 ) -> paddle.Tensor:
     """ """
-    if isinstance(share_inputs, dict):
-        output_padding_offset = share_inputs.get("output_padding_offset", None)
-    else:
-        output_padding_offset = getattr(share_inputs, "output_padding_offset", None)
     hidden_states = gather_next_token(
         forward_output,
         xpu_forward_meta.encoder_seq_lod,
@@ -265,7 +260,7 @@ def xpu_process_output(
         xpu_forward_meta.encoder_batch_map_cpu,
         xpu_forward_meta.decoder_batch_map_cpu,
         xpu_forward_meta.len_info_cpu,
-        output_padding_offset,  # output_padding_offset
+        xpu_forward_meta.is_speculative,
         xpu_forward_meta.max_num_seqs,
     )
     return hidden_states
@@ -416,6 +411,8 @@ def xpu_post_process_specualate(
     share_inputs: Dict[str, paddle.Tensor],
     save_each_rank: bool = False,
     skip_save_output: bool = False,
+    is_naive_mode: bool = False,
+    prefill_one_step_stop: bool = False,
 ):
     """"""
@@ -432,7 +429,7 @@ def xpu_post_process_specualate(
         model_output.min_tokens,
     )
-    speculate_update(
+    unified_update_model_status(
         model_output.seq_lens_encoder,
         model_output.seq_lens_decoder,
         model_output.not_need_stop,
@@ -444,6 +441,13 @@ def xpu_post_process_specualate(
         model_output.seq_lens_this_time,
         model_output.is_block_step,
         model_output.mask_rollback,
+        model_output.pre_ids,
+        model_output.prompt_lens,
+        model_output.step_idx,
+        model_output.eos_token_id,
+        model_output.max_dec_len,
+        is_naive_mode,
+        prefill_one_step_stop,
     )
     if not skip_save_output:
         if sampler_output.logprobs_tensors is None:
@@ -464,18 +468,6 @@ def xpu_post_process_specualate(
     speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)
-    # Update pre_ids through accept tokens
-    speculate_set_value_by_flags_and_idx(
-        model_output.pre_ids,
-        model_output.accept_tokens,
-        model_output.accept_num,
-        model_output.stop_flags,
-        model_output.seq_lens_this_time,
-        model_output.seq_lens_encoder,
-        model_output.seq_lens_decoder,
-        model_output.step_idx,
-    )
 def step_xpu(
     share_inputs: Dict[str, paddle.Tensor],