[XPU] Unify Spec and non-spec branch.(#6947) (#7180)

* [XPU] cherry-pick PR-6947

* [XPU] use unified_update_model_status.

* refactor xpu_model_runner.

* refactor sampler.

* fix codestyle.

* Fix XPU speculative decoding: rename output tensors to cu_seqlens_q_output/batch_id_per_token_output, correct
  WRAPPER_CHECK_PTR types, and fix dynamic gather shape in verify_draft_tokens path.

* fix codestyle.

* replace output_padding_offset with is_speculative flag in gather_next_token.

* rename hiddden_states.

* unify cu_seqlens_q_output and batch_id_per_token_output init.

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
This commit is contained in:
Jiajun Ji
2026-04-16 14:58:38 +08:00
committed by GitHub
parent 17002edc47
commit 29495b2cf1
9 changed files with 226 additions and 149 deletions
+11 -19
View File
@@ -286,20 +286,12 @@ class InputBatch:
fill_value=max_draft_token_num,
dtype="int32",
)
if current_platform.is_cuda():
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
self.batch_id_per_token_output = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0,
dtype="int32",
)
else:
self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.output_padding_offset = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0,
dtype="int32",
)
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
self.batch_id_per_token_output = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0,
dtype="int32",
)
# For V1_KVCACHE_SCHEDULER
self.step_draft_tokens = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
@@ -437,7 +429,7 @@ class InputBatch:
if current_platform.is_cuda():
swap_data(self.cu_seqlens_q_output, i1, i2)
else:
swap_data(self.output_cum_offsets, i1, i2)
swap_data(self.cu_seqlens_q_output, i1, i2)
swap_data(self.step_draft_tokens, i1, i2)
swap_data(self.step_seq_lens_this_time, i1, i2)
swap_data(self.draft_logits, i1, i2)
@@ -628,8 +620,8 @@ class InputBatch:
fill_paddle_tensor(self, "accept_num", 0)
fill_paddle_tensor(self, "draft_tokens", -1)
fill_paddle_tensor(self, "actual_draft_token_num", max_draft_token_num)
fill_paddle_tensor(self, "output_cum_offsets", 0)
fill_paddle_tensor(self, "output_padding_offset", 0)
fill_paddle_tensor(self, "cu_seqlens_q_output", 0)
fill_paddle_tensor(self, "batch_id_per_token_output", 0)
fill_paddle_tensor(self, "step_draft_tokens", 0)
fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
fill_paddle_tensor(self, "draft_logits", -1)
@@ -742,8 +734,8 @@ class ProposerInputBatch(InputBatch):
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
self.token_ids_all = None
else:
self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"])
self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"])
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])