[XPU] Unify Spec and non-spec branch.(#6947) (#7180)

* [XPU] cherry-pick PR-6947

* [XPU] use unified_update_model_status.

* refactor xpu_model_runner.

* refactor sampler.

* fix codestyle.

* Fix XPU speculative decoding: rename output tensors to cu_seqlens_q_output/batch_id_per_token_output, correct
  WRAPPER_CHECK_PTR types, and fix dynamic gather shape in verify_draft_tokens path.

* fix codestyle.

* replace output_padding_offset with is_speculative flag in gather_next_token.

* rename misspelled hiddden_states to hidden_states.

* unify cu_seqlens_q_output and batch_id_per_token_output init.

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
This commit is contained in:
Jiajun Ji
2026-04-16 14:58:38 +08:00
committed by GitHub
parent 17002edc47
commit 29495b2cf1
9 changed files with 226 additions and 149 deletions
@@ -32,7 +32,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
bool is_speculative,
int max_bsz) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -73,7 +73,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};
paddle::Tensor out;
if (output_padding_offset) {
if (is_speculative) {
int need_delete_token_num = 0;
if (enc_batch > 0) {
need_delete_token_num =
@@ -88,7 +88,7 @@ std::vector<paddle::Tensor> GatherNextToken(
return {out};
}
if (output_padding_offset) {
if (is_speculative) {
int r = fastdeploy::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()),
@@ -124,14 +124,10 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
const std::vector<int64_t>& encoder_batch_map_cpu_shape,
const std::vector<int64_t>& decoder_batch_map_cpu_shape,
const std::vector<int64_t>& len_info_cpu_shape,
const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
// if (output_padding_offset_shape) {
// PD_THROW("speculative decoding is not supported in XPU.");
// }
// int64_t bsz = cum_offsets_shape[0];
bool is_speculative) {
int64_t bsz = 0;
int64_t dim_embed = x_shape[1];
if (output_padding_offset_shape) {
if (is_speculative) {
return {{-1, dim_embed}};
} else {
return {{bsz, dim_embed}};
@@ -148,8 +144,7 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
const paddle::DataType& decoder_seq_lod_cpu_dtype,
const paddle::DataType& encoder_batch_map_cpu_dtype,
const paddle::DataType& decoder_batch_map_cpu_dtype,
const paddle::DataType& len_info_cpu_dtype,
const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
const paddle::DataType& len_info_cpu_dtype) {
return {x_dtype};
}
@@ -163,10 +158,9 @@ PD_BUILD_STATIC_OP(gather_next_token)
"decoder_seq_lod_cpu",
"encoder_batch_map_cpu",
"decoder_batch_map_cpu",
"len_info_cpu",
paddle::Optional("output_padding_offset")})
"len_info_cpu"})
.Outputs({"out"})
.Attrs({"max_bsz: int"})
.Attrs({"is_speculative: bool", "max_bsz: int"})
.SetKernelFn(PD_KERNEL(GatherNextToken))
.SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
+28 -2
View File
@@ -465,7 +465,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
bool is_speculative,
int max_bsz);
std::vector<paddle::Tensor> GetImgBoundaries(
@@ -1035,7 +1035,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("len_info_cpu"),
py::arg("output_padding_offset"),
py::arg("is_speculative"),
py::arg("max_bsz"),
"Gather next token for XPU");
@@ -1164,6 +1164,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_draft_tokens"),
"Unified update model status");
m.def("verify_draft_tokens",
&VerifyDraftTokens,
py::arg("step_output_ids"),
py::arg("step_output_len"),
py::arg("step_input_ids"),
py::arg("target_tokens"),
py::arg("candidate_ids"),
py::arg("candidate_scores"),
py::arg("candidate_lens"),
py::arg("topp"),
py::arg("stop_flags"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_this_time"),
py::arg("end_tokens"),
py::arg("is_block_step"),
py::arg("cu_seqlens_q_output"),
py::arg("reasoning_status"),
py::arg("max_dec_len"),
py::arg("step_idx"),
py::arg("max_seq_len"),
py::arg("verify_window"),
py::arg("verify_strategy"),
py::arg("reject_all"),
py::arg("accept_all"),
"Perform speculative verification for decoding v2");
m.def("mtp_step_paddle",
&MTPStepPaddle,
py::arg("base_model_stop_flags"),
@@ -766,7 +766,6 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
const int eos_token_id_len,
const int inject_len,
const bool splitwise_role_is_decode);
DLL_EXPORT int verify_draft_tokens(
api::Context* ctx,
// Core I/O
@@ -24,7 +24,7 @@ from fastdeploy.model_executor.ops.xpu import (
)
def _run_test_base(seq_lens_this_time_data, output_padding_offset):
def _run_test_base(seq_lens_this_time_data, is_speculative):
"""
通用的基础测试执行函数,包含了两个场景共有的逻辑。
"""
@@ -120,7 +120,7 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
encoder_batch_map_cpu,
decoder_batch_map_cpu,
len_info_cpu,
output_padding_offset,
is_speculative,
-1,
)
@@ -136,14 +136,14 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
encoder_batch_map_cpu,
decoder_batch_map_cpu,
len_info_cpu,
output_padding_offset,
is_speculative,
-1,
)
gather_out_np = gather_out.astype("float32").cpu().numpy()
gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy()
if output_padding_offset is not None:
if is_speculative:
np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!")
else:
for i in range(gather_out_cpu.shape[0]):
@@ -160,19 +160,14 @@ class TestXPUOps(unittest.TestCase): # 继承 unittest.TestCase
"""测试混合批次处理中的 MTP (Multi-Token Prediction) 场景"""
print("\nRunning test: test_mix_with_mtp")
seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3]
bsz = len(seq_lens_this_time_data)
output_padding_offset = paddle.zeros(bsz, dtype="int32")
_run_test_base(seq_lens_this_time_data, output_padding_offset)
_run_test_base(seq_lens_this_time_data, True)
print("Test passed for scenario: With MTP")
def test_mix_without_mtp(self):
"""测试非 MTP (Single-Token Prediction) 场景下的功能"""
print("\nRunning test: test_mix_without_mtp")
seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1]
output_padding_offset = None # 非 MTP 场景下,此参数为 None
_run_test_base(seq_lens_this_time_data, output_padding_offset)
_run_test_base(seq_lens_this_time_data, False)
print("Test passed for scenario: Without MTP")