[XPU] Unify Spec and non-spec branch. (#6947) (#7180)

* [XPU] cherry-pick PR-6947

* [XPU] use unified_update_model_status.

* refactor xpu_model_runner.

* refactor sampler.

* fix codestyle.

* Fix XPU speculative decoding: rename output tensors to cu_seqlens_q_output/batch_id_per_token_output, correct
  WRAPPER_CHECK_PTR types, and fix dynamic gather shape in verify_draft_tokens path.

* fix codestyle.

* replace output_padding_offset with is_speculative flag in gather_next_token.

* rename hiddden_states.

* unify cu_seqlens_q_output and batch_id_per_token_output init.

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
This commit is contained in:
Jiajun Ji
2026-04-16 14:58:38 +08:00
committed by GitHub
parent 17002edc47
commit 29495b2cf1
9 changed files with 226 additions and 149 deletions
@@ -1045,6 +1045,120 @@ class SpeculativeSampler(nn.Layer):
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
return sampler_output
def _normal_sample_xpu(
    self,
    logits: paddle.Tensor,
    probs: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
    share_inputs: List[paddle.Tensor],
) -> SamplerOutput:
    """Plain (non-speculative) sampling for NAIVE mode on XPU.

    Samples one token per request via top-k/top-p, writes it into slot 0
    of ``share_inputs["accept_tokens"]``, and sets each request's accept
    count to 1 when it is still running this step (``seq_lens_this_time
    > 0``) and 0 otherwise.
    """
    # Flatten once; both padded-params and the running mask use these.
    seq_lens_this_time = paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1])
    seq_lens_encoder = paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1])

    padded_top_p, padded_top_k, padded_seed = padding_sampling_params(
        sampling_metadata.top_p,
        sampling_metadata.top_k,
        sampling_metadata.seed,
        seq_lens_this_time,
        seq_lens_encoder,
    )
    _, sampled = top_k_top_p_sampling(
        probs,
        top_p=padded_top_p,
        top_k=padded_top_k,
        top_k_list=sampling_metadata.top_k_list,
        topp_seed=padded_seed,
    )

    batch_size = share_inputs["seq_lens_this_time"].shape[0]
    # 1 for requests with work this step, 0 for idle/finished slots.
    active_mask = (seq_lens_this_time > 0).cast("int32")
    share_inputs["accept_tokens"][:batch_size, 0] = sampled.squeeze(-1)
    share_inputs["accept_num"][:batch_size] = active_mask

    return SamplerOutput(
        sampled_token_ids=share_inputs["accept_tokens"],
        logprobs_tensors=None,
        token_num_per_batch=share_inputs["accept_num"],
        logits=logits,
    )
def _verify_and_sample_xpu(
    self,
    logits: paddle.Tensor,
    probs: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
    max_model_len: int,
    share_inputs: List[paddle.Tensor],
    accept_all_drafts: bool = False,
    reject_all_drafts: bool = False,
) -> SamplerOutput:
    """Verify draft tokens (MTP/Ngram mode) on XPU via ``verify_draft_tokens``.

    Depending on ``self.verify_strategy``, either samples/argmaxes target
    tokens (TARGET_MATCH / GREEDY) or builds top-p candidate sets (TOPP),
    then invokes the XPU ``verify_draft_tokens`` op, which fills
    ``share_inputs["accept_tokens"]`` / ``["accept_num"]`` in place.
    """
    from fastdeploy.model_executor.ops.xpu import (
        top_p_candidates,
        verify_draft_tokens,
    )

    target_tokens = None
    candidate_ids = None
    candidate_scores = None
    candidate_lens = None

    strategy = self.verify_strategy
    if strategy == VerifyStrategy.TARGET_MATCH:
        # Sample target-model tokens with the request's own top-k/top-p.
        padded_top_p, padded_top_k, padded_seed = padding_sampling_params(
            sampling_metadata.top_p,
            sampling_metadata.top_k,
            sampling_metadata.seed,
            paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
            paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
        )
        _, target_tokens = top_k_top_p_sampling(
            probs,
            top_p=padded_top_p,
            top_k=padded_top_k,
            top_k_list=sampling_metadata.top_k_list,
            topp_seed=padded_seed,
        )
    elif strategy == VerifyStrategy.GREEDY:
        target_tokens = paddle.argmax(probs, axis=-1)
    elif strategy == VerifyStrategy.TOPP:
        candidate_scores, candidate_ids, candidate_lens = top_p_candidates(
            probs,
            sampling_metadata.top_p,
            share_inputs["batch_id_per_token_output"],
            self.speculative_max_candidate_len,
            max_model_len,
        )
    else:
        raise ValueError(f"Unknown verify strategy: {self.verify_strategy}")

    # Instance-level config flags combine with per-call overrides; benchmark
    # mode forces every draft token to be rejected.
    final_accept_all = self.config_accept_all or accept_all_drafts
    final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode

    # NOTE: argument order must match the XPU op's positional signature.
    verify_draft_tokens(
        share_inputs["accept_tokens"],
        share_inputs["accept_num"],
        share_inputs["draft_tokens"],
        target_tokens,
        candidate_ids,
        candidate_scores,
        candidate_lens,
        sampling_metadata.top_p,
        share_inputs["stop_flags"],
        share_inputs["seq_lens_encoder"],
        share_inputs["seq_lens_this_time"],
        sampling_metadata.eos_token_ids,
        share_inputs["is_block_step"],
        share_inputs["cu_seqlens_q_output"],
        share_inputs["reasoning_status"],
        share_inputs["max_dec_len"],
        share_inputs["step_idx"],
        max_model_len,
        self.speculative_verify_window,
        strategy.value,
        final_reject_all,
        final_accept_all,
    )

    return SamplerOutput(
        sampled_token_ids=share_inputs["accept_tokens"],
        logprobs_tensors=None,
        token_num_per_batch=share_inputs["accept_num"],
        logits=logits,
    )
def forward_xpu(
self,
logits: paddle.Tensor,
@@ -1053,9 +1167,7 @@ class SpeculativeSampler(nn.Layer):
share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False,
reject_all_drafts: bool = False,
) -> paddle.Tensor:
from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
) -> SamplerOutput:
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.token_ids_all,
sampling_metadata.prompt_lens,
@@ -1078,61 +1190,19 @@ class SpeculativeSampler(nn.Layer):
probs = F.softmax(logits)
top_p, top_k, topp_seed = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
)
_, sampled_token_ids = top_k_top_p_sampling(
probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
)
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
sampling_metadata.top_p,
share_inputs["batch_id_per_token_output"],
self.speculative_max_candidate_len,
max_model_len,
)
speculate_verify(
sampled_token_ids,
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["step_idx"],
share_inputs["stop_flags"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs[
"draft_tokens"
], # Both input and output, need to write the last 1 token accepted to position 0.
share_inputs["seq_lens_this_time"],
verify_tokens,
verify_scores,
share_inputs["max_dec_len"],
sampling_metadata.eos_token_ids,
share_inputs["is_block_step"],
share_inputs["cu_seqlens_q_output"],
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
max_model_len,
self.speculative_verify_window,
True, # enable_topp
(self.speculative_benchmark_mode or reject_all_drafts),
accept_all_drafts,
)
# TODO(chenhuan09): support return logprobs
token_ids = share_inputs["accept_tokens"]
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=None,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=None,
)
return sampler_output
is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
if is_naive:
return self._normal_sample_xpu(logits, probs, sampling_metadata, share_inputs)
else:
return self._verify_and_sample_xpu(
logits,
probs,
sampling_metadata,
max_model_len,
share_inputs,
accept_all_drafts,
reject_all_drafts,
)
class MTPSampler(nn.Layer):