From 29495b2cf13035b7e1fb1aa77c9ae57e2030a4bd Mon Sep 17 00:00:00 2001
From: Jiajun Ji
Date: Thu, 16 Apr 2026 14:58:38 +0800
Subject: [PATCH] [XPU] Unify Spec and non-spec branch.(#6947) (#7180)

* [XPU] cherry-pick PR-6947
* [XPU] use unified_update_model_status.
* refactor xpu_model_runner.
* refactor sampler.
* fix codestyle.
* Fix XPU speculative decoding: rename output tensors to cu_seqlens_q_output/batch_id_per_token_output, correct WRAPPER_CHECK_PTR types, and fix dynamic gather shape in verify_draft_tokens path.
* fix codestyle.
* replace output_padding_offset with is_speculative flag in gather_next_token.
* rename hiddden_states.
* unify cu_seqlens_q_output and batch_id_per_token_output init.

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
---
 .../xpu_ops/src/ops/gather_next_token.cc      |  22 +--
 custom_ops/xpu_ops/src/ops/pybind/pybind.cc   |  30 ++-
 .../xpu_ops/src/plugin/include/xpu/plugin.h   |   1 -
 ...test_adjust_batch_and_gather_next_token.py |  17 +-
 fastdeploy/model_executor/forward_meta.py     |   1 +
 .../model_executor/layers/sample/sampler.py   | 186 ++++++++++++------
 .../xpu_pre_and_post_process.py               |  34 ++--
 fastdeploy/worker/input_batch.py              |  30 ++-
 fastdeploy/worker/xpu_model_runner.py         |  54 ++---
 9 files changed, 226 insertions(+), 149 deletions(-)

diff --git a/custom_ops/xpu_ops/src/ops/gather_next_token.cc b/custom_ops/xpu_ops/src/ops/gather_next_token.cc
index 31c2142ca0..ee261965d6 100644
--- a/custom_ops/xpu_ops/src/ops/gather_next_token.cc
+++ b/custom_ops/xpu_ops/src/ops/gather_next_token.cc
@@ -32,7 +32,7 @@ std::vector<paddle::Tensor> GatherNextToken(
     const paddle::Tensor& encoder_batch_map_cpu,
     const paddle::Tensor& decoder_batch_map_cpu,
     const paddle::Tensor& len_info_cpu,
-    const paddle::optional<paddle::Tensor>& output_padding_offset,
+    bool is_speculative,
     int max_bsz) {
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -73,7 +73,7 @@ std::vector<paddle::Tensor> GatherNextToken(
       const_cast(decoder_batch_map.data())};

   paddle::Tensor out;
-  if (output_padding_offset) {
+  if (is_speculative) {
     int need_delete_token_num = 0;
     if (enc_batch > 0) {
       need_delete_token_num =
@@ -88,7 +88,7 @@ std::vector<paddle::Tensor> GatherNextToken(
     return {out};
   }

-  if (output_padding_offset) {
+  if (is_speculative) {
     int r = fastdeploy::plugin::eb_mtp_gather_next_token(
         ctx,
         reinterpret_cast(x.data()),
@@ -124,14 +124,10 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
     const std::vector<int64_t>& encoder_batch_map_cpu_shape,
     const std::vector<int64_t>& decoder_batch_map_cpu_shape,
     const std::vector<int64_t>& len_info_cpu_shape,
-    const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
-  // if (output_padding_offset_shape) {
-  //   PD_THROW("speculative decoding is not supported in XPU.");
-  // }
-  // int64_t bsz = cum_offsets_shape[0];
+    bool is_speculative) {
   int64_t bsz = 0;
   int64_t dim_embed = x_shape[1];
-  if (output_padding_offset_shape) {
+  if (is_speculative) {
     return {{-1, dim_embed}};
   } else {
     return {{bsz, dim_embed}};
@@ -148,8 +144,7 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
     const paddle::DataType& decoder_seq_lod_cpu_dtype,
     const paddle::DataType& encoder_batch_map_cpu_dtype,
     const paddle::DataType& decoder_batch_map_cpu_dtype,
-    const paddle::DataType& len_info_cpu_dtype,
-    const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
+    const paddle::DataType& len_info_cpu_dtype) {
   return {x_dtype};
 }

@@ -163,10 +158,9 @@ PD_BUILD_STATIC_OP(gather_next_token)
           "decoder_seq_lod_cpu",
           "encoder_batch_map_cpu",
           "decoder_batch_map_cpu",
-          "len_info_cpu",
-          paddle::Optional("output_padding_offset")})
+          "len_info_cpu"})
       .Outputs({"out"})
-      .Attrs({"max_bsz: int"})
+      .Attrs({"is_speculative: bool", "max_bsz: int"})
       .SetKernelFn(PD_KERNEL(GatherNextToken))
       .SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
       .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
index e8a66c750a..887a9ac95a 100644
--- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
+++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -465,7 +465,7 @@ std::vector<paddle::Tensor> GatherNextToken(
     const paddle::Tensor& encoder_batch_map_cpu,
     const paddle::Tensor& decoder_batch_map_cpu,
     const paddle::Tensor& len_info_cpu,
-    const paddle::optional<paddle::Tensor>& output_padding_offset,
+    bool is_speculative,
     int max_bsz);

 std::vector<paddle::Tensor> GetImgBoundaries(
@@ -1035,7 +1035,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         py::arg("encoder_batch_map_cpu"),
         py::arg("decoder_batch_map_cpu"),
         py::arg("len_info_cpu"),
-        py::arg("output_padding_offset"),
+        py::arg("is_speculative"),
         py::arg("max_bsz"),
         "Gather next token for XPU");

@@ -1164,6 +1164,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         py::arg("max_draft_tokens"),
         "Unified update model status");

+  m.def("verify_draft_tokens",
+        &VerifyDraftTokens,
+        py::arg("step_output_ids"),
+        py::arg("step_output_len"),
+        py::arg("step_input_ids"),
+        py::arg("target_tokens"),
+        py::arg("candidate_ids"),
+        py::arg("candidate_scores"),
+        py::arg("candidate_lens"),
+        py::arg("topp"),
+        py::arg("stop_flags"),
+        py::arg("seq_lens_encoder"),
+        py::arg("seq_lens_this_time"),
+        py::arg("end_tokens"),
+        py::arg("is_block_step"),
+        py::arg("cu_seqlens_q_output"),
+        py::arg("reasoning_status"),
+        py::arg("max_dec_len"),
+        py::arg("step_idx"),
+        py::arg("max_seq_len"),
+        py::arg("verify_window"),
+        py::arg("verify_strategy"),
+        py::arg("reject_all"),
+        py::arg("accept_all"),
+        "Perform speculative verification for decoding v2");
+
   m.def("mtp_step_paddle",
         &MTPStepPaddle,
         py::arg("base_model_stop_flags"),
diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
index 76e684dc49..f7bcc2042d 100644
--- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
+++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
@@ -766,7 +766,6 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
     const int eos_token_id_len,
     const int inject_len,
     const bool splitwise_role_is_decode);
-
 DLL_EXPORT int verify_draft_tokens(
     api::Context* ctx,
     // Core I/O
diff --git a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py
index 758dff17e5..bc074242b4 100644
--- a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py
+++ b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py
@@ -24,7 +24,7 @@ from fastdeploy.model_executor.ops.xpu import (
 )


-def _run_test_base(seq_lens_this_time_data, output_padding_offset):
+def _run_test_base(seq_lens_this_time_data, is_speculative):
     """
     Shared base test routine containing the logic common to both scenarios.
     """
@@ -120,7 +120,7 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
         encoder_batch_map_cpu,
         decoder_batch_map_cpu,
         len_info_cpu,
-        output_padding_offset,
+        is_speculative,
         -1,
     )

@@ -136,14 +136,14 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
         encoder_batch_map_cpu,
         decoder_batch_map_cpu,
         len_info_cpu,
-        output_padding_offset,
+        is_speculative,
         -1,
     )

     gather_out_np = gather_out.astype("float32").cpu().numpy()
     gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy()

-    if output_padding_offset is not None:
+    if is_speculative:
         np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!")
     else:
         for i in range(gather_out_cpu.shape[0]):
@@ -160,19 +160,14 @@ class TestXPUOps(unittest.TestCase):  # inherits from unittest.TestCase
         """Test the MTP (Multi-Token Prediction) scenario in mixed-batch processing."""
         print("\nRunning test: test_mix_with_mtp")
         seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3]
-        bsz = len(seq_lens_this_time_data)
-        output_padding_offset = paddle.zeros(bsz, dtype="int32")
-
-        _run_test_base(seq_lens_this_time_data, output_padding_offset)
+        _run_test_base(seq_lens_this_time_data, True)
         print("Test passed for scenario: With MTP")

     def test_mix_without_mtp(self):
         """Test behavior in the non-MTP (Single-Token Prediction) scenario."""
         print("\nRunning test: test_mix_without_mtp")
         seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1]
-        output_padding_offset = None  # in the non-MTP scenario this argument is None
-
-        _run_test_base(seq_lens_this_time_data, output_padding_offset)
+        _run_test_base(seq_lens_this_time_data, False)
         print("Test passed for scenario: Without MTP")

diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py
index 9b36556a6e..491ba0732a 100644
--- a/fastdeploy/model_executor/forward_meta.py
+++ b/fastdeploy/model_executor/forward_meta.py
@@ -275,6 +275,7 @@ class XPUForwardMeta(ForwardMeta):

     hidden_states: Optional[paddle.Tensor] = None
     is_draft: bool = False
+    is_speculative: bool = False

     # max bs
     max_num_seqs: int = 0
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index c08395c964..507e88199c 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -1045,6 +1045,120 @@ class SpeculativeSampler(nn.Layer):
         sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
         return sampler_output

+    def _normal_sample_xpu(
+        self,
+        logits: paddle.Tensor,
+        probs: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+        share_inputs: List[paddle.Tensor],
+    ) -> SamplerOutput:
+        """Normal sampling for NAIVE mode on XPU."""
+        top_p, top_k, topp_seed = padding_sampling_params(
+            sampling_metadata.top_p,
+            sampling_metadata.top_k,
+            sampling_metadata.seed,
+            paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
+            paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
+        )
+        _, next_tokens = top_k_top_p_sampling(
+            probs,
+            top_p=top_p,
+            top_k=top_k,
+            top_k_list=sampling_metadata.top_k_list,
+            topp_seed=topp_seed,
+        )
+        real_bsz = share_inputs["seq_lens_this_time"].shape[0]
+        running_mask = (paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]) > 0).cast("int32")
+        share_inputs["accept_tokens"][:real_bsz, 0] = next_tokens.squeeze(-1)
+        share_inputs["accept_num"][:real_bsz] = running_mask
+        return SamplerOutput(
+            sampled_token_ids=share_inputs["accept_tokens"],
+            logprobs_tensors=None,
+            token_num_per_batch=share_inputs["accept_num"],
+            logits=logits,
+        )
+
+    def _verify_and_sample_xpu(
+        self,
+        logits: paddle.Tensor,
+        probs: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+        max_model_len: int,
+        share_inputs: List[paddle.Tensor],
+        accept_all_drafts: bool = False,
+        reject_all_drafts: bool = False,
+    ) -> SamplerOutput:
+        """Verify draft tokens (MTP/Ngram mode) on XPU using verify_draft_tokens."""
+        from fastdeploy.model_executor.ops.xpu import (
+            top_p_candidates,
+            verify_draft_tokens,
+        )
+
+        target_tokens = None
+        candidate_ids, candidate_scores, candidate_lens = None, None, None
+
+        if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
+            top_p, top_k, topp_seed = padding_sampling_params(
+                sampling_metadata.top_p,
+                sampling_metadata.top_k,
+                sampling_metadata.seed,
+                paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
+                paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
+            )
+            _, target_tokens = top_k_top_p_sampling(
+                probs,
+                top_p=top_p,
+                top_k=top_k,
+                top_k_list=sampling_metadata.top_k_list,
+                topp_seed=topp_seed,
+            )
+        elif self.verify_strategy == VerifyStrategy.GREEDY:
+            target_tokens = paddle.argmax(probs, axis=-1)
+        elif self.verify_strategy == VerifyStrategy.TOPP:
+            candidate_scores, candidate_ids, candidate_lens = top_p_candidates(
+                probs,
+                sampling_metadata.top_p,
+                share_inputs["batch_id_per_token_output"],
+                self.speculative_max_candidate_len,
+                max_model_len,
+            )
+        else:
+            raise ValueError(f"Unknown verify strategy: {self.verify_strategy}")
+
+        final_accept_all = self.config_accept_all or accept_all_drafts
+        final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode
+
+        verify_draft_tokens(
+            share_inputs["accept_tokens"],
+            share_inputs["accept_num"],
+            share_inputs["draft_tokens"],
+            target_tokens,
+            candidate_ids,
+            candidate_scores,
+            candidate_lens,
+            sampling_metadata.top_p,
+            share_inputs["stop_flags"],
+            share_inputs["seq_lens_encoder"],
+            share_inputs["seq_lens_this_time"],
+            sampling_metadata.eos_token_ids,
+            share_inputs["is_block_step"],
+            share_inputs["cu_seqlens_q_output"],
+            share_inputs["reasoning_status"],
+            share_inputs["max_dec_len"],
+            share_inputs["step_idx"],
+            max_model_len,
+            self.speculative_verify_window,
+            self.verify_strategy.value,
+            final_reject_all,
+            final_accept_all,
+        )
+        return SamplerOutput(
+            sampled_token_ids=share_inputs["accept_tokens"],
+            logprobs_tensors=None,
+            token_num_per_batch=share_inputs["accept_num"],
+            logits=logits,
+        )
+
     def forward_xpu(
         self,
         logits: paddle.Tensor,
@@ -1053,9 +1167,7 @@
         share_inputs: List[paddle.Tensor],
         accept_all_drafts: bool = False,
         reject_all_drafts: bool = False,
-    ) -> paddle.Tensor:
-        from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
-
+    ) -> SamplerOutput:
         logits = apply_speculative_penalty_multi_scores(
             sampling_metadata.token_ids_all,
             sampling_metadata.prompt_lens,
@@ -1078,61 +1190,19 @@

         probs = F.softmax(logits)

-        top_p, top_k, topp_seed = padding_sampling_params(
-            sampling_metadata.top_p,
-            sampling_metadata.top_k,
-            sampling_metadata.seed,
-            paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
-            paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
-        )
-        _, sampled_token_ids = top_k_top_p_sampling(
-            probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
-        )
-
-        verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
-            probs,
-            sampling_metadata.top_p,
-            share_inputs["batch_id_per_token_output"],
-            self.speculative_max_candidate_len,
-            max_model_len,
-        )
-
-        speculate_verify(
-            sampled_token_ids,
-            share_inputs["accept_tokens"],
-            share_inputs["accept_num"],
-            share_inputs["step_idx"],
-            share_inputs["stop_flags"],
-            share_inputs["seq_lens_encoder"],
-            share_inputs["seq_lens_decoder"],
-            share_inputs[
-                "draft_tokens"
-            ],  # Both input and output; the last accepted token must be written to
position 0.
-            share_inputs["seq_lens_this_time"],
-            verify_tokens,
-            verify_scores,
-            share_inputs["max_dec_len"],
-            sampling_metadata.eos_token_ids,
-            share_inputs["is_block_step"],
-            share_inputs["cu_seqlens_q_output"],
-            actual_candidate_len,
-            share_inputs["actual_draft_token_num"],
-            sampling_metadata.top_p,
-            max_model_len,
-            self.speculative_verify_window,
-            True,  # enable_topp
-            (self.speculative_benchmark_mode or reject_all_drafts),
-            accept_all_drafts,
-        )
-
-        # TODO(chenhuan09): support returning logprobs
-        token_ids = share_inputs["accept_tokens"]
-        sampler_output = SamplerOutput(
-            sampled_token_ids=token_ids,
-            logprobs_tensors=None,
-            token_num_per_batch=share_inputs["accept_num"],
-            cu_batch_token_offset=None,
-        )
-        return sampler_output
+        is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
+        if is_naive:
+            return self._normal_sample_xpu(logits, probs, sampling_metadata, share_inputs)
+        else:
+            return self._verify_and_sample_xpu(
+                logits,
+                probs,
+                sampling_metadata,
+                max_model_len,
+                share_inputs,
+                accept_all_drafts,
+                reject_all_drafts,
+            )


 class MTPSampler(nn.Layer):
diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py
index b8c97529dc..523a08d0c1 100644
--- a/fastdeploy/model_executor/xpu_pre_and_post_process.py
+++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py
@@ -43,12 +43,11 @@ if current_platform.is_xpu():
         speculate_pre_process,
         speculate_save_output,
         speculate_set_stop_value_multi_seqs,
-        speculate_set_value_by_flags_and_idx,
         speculate_step_paddle,
         speculate_step_reschedule,
         speculate_step_system_cache,
-        speculate_update,
         step_paddle,
+        unified_update_model_status,
         update_inputs,
         update_inputs_v1,
     )
@@ -172,6 +171,7 @@ def xpu_pre_process(
         block_tables=share_inputs["block_tables"],
         caches=share_inputs["caches"],
         max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
+        is_speculative=use_speculate_method,
     )

     (
@@ -249,11 +249,6 @@ def xpu_process_output(
 ) -> paddle.Tensor:
     """ """

-    if isinstance(share_inputs, dict):
-        output_padding_offset = share_inputs.get("output_padding_offset", None)
-    else:
-        output_padding_offset = getattr(share_inputs, "output_padding_offset", None)
-
     hidden_states = gather_next_token(
         forward_output,
         xpu_forward_meta.encoder_seq_lod,
@@ -265,7 +260,7 @@ def xpu_process_output(
         xpu_forward_meta.encoder_batch_map_cpu,
         xpu_forward_meta.decoder_batch_map_cpu,
         xpu_forward_meta.len_info_cpu,
-        output_padding_offset,  # output_padding_offset
+        xpu_forward_meta.is_speculative,
         xpu_forward_meta.max_num_seqs,
     )
     return hidden_states
@@ -416,6 +411,8 @@ def xpu_post_process_specualate(
     share_inputs: Dict[str, paddle.Tensor],
     save_each_rank: bool = False,
     skip_save_output: bool = False,
+    is_naive_mode: bool = False,
+    prefill_one_step_stop: bool = False,
 ):
     """"""

@@ -432,7 +429,7 @@
         model_output.min_tokens,
     )

-    speculate_update(
+    unified_update_model_status(
         model_output.seq_lens_encoder,
         model_output.seq_lens_decoder,
         model_output.not_need_stop,
@@ -444,6 +441,13 @@
         model_output.seq_lens_this_time,
         model_output.is_block_step,
         model_output.mask_rollback,
+        model_output.pre_ids,
+        model_output.prompt_lens,
+        model_output.step_idx,
+        model_output.eos_token_id,
+        model_output.max_dec_len,
+        is_naive_mode,
+        prefill_one_step_stop,
     )
     if not skip_save_output:
         if sampler_output.logprobs_tensors is None:
@@ -464,18 +468,6 @@
         speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)

-        # Update pre_ids with the accepted tokens
-        speculate_set_value_by_flags_and_idx(
-            model_output.pre_ids,
-            model_output.accept_tokens,
-            model_output.accept_num,
-            model_output.stop_flags,
-            model_output.seq_lens_this_time,
-            model_output.seq_lens_encoder,
-            model_output.seq_lens_decoder,
-            model_output.step_idx,
-        )
-

 def step_xpu(
     share_inputs: Dict[str, paddle.Tensor],
diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py
index 55a3f39a2e..4c31358684 100644
--- a/fastdeploy/worker/input_batch.py
+++ b/fastdeploy/worker/input_batch.py
@@ -286,20 +286,12 @@ class InputBatch:
             fill_value=max_draft_token_num,
             dtype="int32",
         )
-        if current_platform.is_cuda():
-            self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
-            self.batch_id_per_token_output = paddle.full(
-                shape=[max_num_seqs * (max_draft_token_num + 1)],
-                fill_value=0,
-                dtype="int32",
-            )
-        else:
-            self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
-            self.output_padding_offset = paddle.full(
-                shape=[max_num_seqs * (max_draft_token_num + 1)],
-                fill_value=0,
-                dtype="int32",
-            )
+        self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
+        self.batch_id_per_token_output = paddle.full(
+            shape=[max_num_seqs * (max_draft_token_num + 1)],
+            fill_value=0,
+            dtype="int32",
+        )
         # For V1_KVCACHE_SCHEDULER
         self.step_draft_tokens = paddle.full(
             shape=[max_num_seqs, max_draft_token_num + 1],
@@ -437,7 +429,7 @@ class InputBatch:
             if current_platform.is_cuda():
                 swap_data(self.cu_seqlens_q_output, i1, i2)
             else:
-                swap_data(self.output_cum_offsets, i1, i2)
+                swap_data(self.cu_seqlens_q_output, i1, i2)
             swap_data(self.step_draft_tokens, i1, i2)
             swap_data(self.step_seq_lens_this_time, i1, i2)
             swap_data(self.draft_logits, i1, i2)
@@ -628,8 +620,8 @@ class InputBatch:
         fill_paddle_tensor(self, "accept_num", 0)
         fill_paddle_tensor(self, "draft_tokens", -1)
         fill_paddle_tensor(self, "actual_draft_token_num", max_draft_token_num)
-        fill_paddle_tensor(self, "output_cum_offsets", 0)
-        fill_paddle_tensor(self, "output_padding_offset", 0)
+        fill_paddle_tensor(self, "cu_seqlens_q_output", 0)
+        fill_paddle_tensor(self, "batch_id_per_token_output", 0)
         fill_paddle_tensor(self, "step_draft_tokens", 0)
         fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
         fill_paddle_tensor(self, "draft_logits", -1)
@@ -742,8 +734,8 @@ class ProposerInputBatch(InputBatch):
             self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
             self.token_ids_all = None
         else:
-            self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"])
-            self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"])
+            self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
+            self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
             self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
             self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
             self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index 201d503649..0e3e6b4a3e 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -135,8 +135,8 @@ class XPUModelRunner(ModelRunnerBase):
         self.encoder_cache = None

         self.device_id = device_id
-        self.speculative_method = self.fd_config.speculative_config.method
-        self.speculative_decoding = self.speculative_method is not None
+        self.spec_method = self.fd_config.speculative_config.method
+        self.speculative_decoding = self.spec_method is not None

         # used by SamplingMetadata
         self.enable_logprob = fd_config.model_config.enable_logprob
@@ -728,7 +728,7 @@ class XPUModelRunner(ModelRunnerBase):
         if has_prefill_task or has_decode_task:
             self.share_inputs["not_need_stop"][0] = True

-        if self.speculative_method == SpecMethod.MTP:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.insert_tasks_v1(req_dicts, num_running_requests)

     def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
@@ -877,7 +877,7 @@ class XPUModelRunner(ModelRunnerBase):

         self.share_inputs["not_need_stop"][0] = True

-        if self.speculative_method == SpecMethod.MTP:
+        if self.spec_method == SpecMethod.MTP:
             self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
                 request, "temp_scaled_logprobs", False
             )
@@ -1070,12 +1070,18 @@ class XPUModelRunner(ModelRunnerBase):
             fill_value=max_draft_token_num,
             dtype="int32",
         )
-        self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
-        self.share_inputs["output_padding_offset"] = paddle.full(
+        self.share_inputs["cu_seqlens_q_output"] = paddle.full(
+            shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32"
+        )
+        self.share_inputs["batch_id_per_token_output"] = paddle.full(
             shape=[max_num_seqs * (max_draft_token_num + 1)],
             fill_value=0,
             dtype="int32",
         )
+        # reasoning_status: per-sequence reasoning phase indicator
+        # 0=thinking, 1=emitting boundary, 2=response, 3=end
+        # verify_draft_tokens rejects all draft tokens while reasoning_status == 1
+        self.share_inputs["reasoning_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
         # For V1_KVCACHE_SCHEDULER
         self.share_inputs["step_draft_tokens"] = paddle.full(
             shape=[max_num_seqs, max_draft_token_num + 1],
@@ -1439,7 +1445,7 @@ class XPUModelRunner(ModelRunnerBase):
             block_num=block_num,
         )

-        if self.speculative_method == SpecMethod.MTP:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.dummy_prefill_inputs(
                 num_tokens=num_tokens,
                 batch_size=batch_size,
@@ -1456,19 +1462,16 @@ class XPUModelRunner(ModelRunnerBase):
         """
         Init speculative proposer
         """
-        if self.speculative_method == SpecMethod.NGRAM:
-            # XPU does not support the ngram proposer yet
-            self.proposer = None
-        elif self.speculative_method == SpecMethod.MTP:
-            self.proposer = self.speculative_method.create_proposer(
-                self.fd_config,
-                main_model=self.get_model(),
-                local_rank=self.local_rank,
-                device_id=self.device_id,
-                share_inputs=self.share_inputs,
-            )
-        else:
+        if self.spec_method is None:
             self.proposer = None
+            return
+        self.proposer = self.spec_method.create_proposer(
+            self.fd_config,
+            main_model=self.get_model(),
+            local_rank=self.local_rank,
+            device_id=self.device_id,
+            share_inputs=self.share_inputs,
+        )

     def _set_debug_level(
         self, debug_level: int = 0x1, model_forward_batch: Optional[List[Request]] = None, is_dummy_run: bool = False
@@ -1670,6 +1673,8 @@ class XPUModelRunner(ModelRunnerBase):
                 self.share_inputs,
                 self.parallel_config.data_parallel_size > 1,
                 skip_save_output,
+                is_naive_mode=(self.speculative_decoding and self.proposer is None),
+                prefill_one_step_stop=self.parallel_config.prefill_one_step_stop,
             )
         else:
             xpu_post_process_normal(
@@ -1685,8 +1690,11 @@
             )

         # 6. Draft model propose
-        if self.speculative_method == SpecMethod.MTP:
-            self.proposer.run(full_hidden_states=model_output)
+        if self.speculative_decoding and self.proposer is not None:
+            if self.spec_method == SpecMethod.MTP:
+                self.proposer.run(full_hidden_states=model_output)
+            else:
+                self.proposer.run(share_inputs=self.share_inputs)

         # 7. Update 'infer_seed' and step_paddle()
         self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
@@ -1738,7 +1746,7 @@ class XPUModelRunner(ModelRunnerBase):
         """Execute a forward pass with dummy inputs to profile the memory usage of the model"""

         self.num_gpu_blocks = self.cache_config.total_block_num
-        if self.speculative_method == SpecMethod.MTP:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
         self.initialize_kv_cache(profile=True)
@@ -1760,7 +1768,7 @@ class XPUModelRunner(ModelRunnerBase):
         self.num_gpu_blocks = num_gpu_blocks

         # Reset block table and kv cache with global block num
-        if self.speculative_method == SpecMethod.MTP:
+        if self.spec_method == SpecMethod.MTP:
            self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
         self.initialize_kv_cache()
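
Note on the GREEDY verify path: with VerifyStrategy.GREEDY, the target tokens passed to
verify_draft_tokens are the target model's argmax picks (target_tokens = paddle.argmax(probs,
axis=-1)). The kernel's exact acceptance logic lives in the XPU plugin; the sketch below shows
the standard greedy acceptance rule such a path typically implements (accept the longest
matching prefix, then emit one bonus token from the target so every step makes progress). It is
a plain NumPy reference only; greedy_verify is a hypothetical helper, and the real kernel
additionally handles stop flags, reasoning_status, the verify window, and the TOPP candidate
path.

```python
import numpy as np

def greedy_verify(draft_tokens: np.ndarray, target_tokens: np.ndarray):
    """Accept the longest prefix of draft tokens matching the target model's
    greedy picks, then emit one bonus token from the target.

    draft_tokens:  [num_draft]      tokens proposed by the draft model
    target_tokens: [num_draft + 1]  argmax of the target logits per position
    """
    accept_num = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break
        accept_num += 1
    # The target token at the first mismatch (or at the extra trailing
    # position if every draft matched) is always emitted, so at least one
    # token is produced per verification step.
    accepted = np.append(draft_tokens[:accept_num], target_tokens[accept_num])
    return accepted, accept_num + 1

# Two of the three draft tokens match the target's greedy choices:
accepted, n = greedy_verify(np.array([5, 9, 2]), np.array([5, 9, 7, 4]))
print(accepted, n)  # -> [5 9 7] 3
```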