[XPU] Unify Spec and non-spec branch.(#6947) (#7180)

* [XPU] cherry-pick PR-6947

* [XPU] use unified_update_model_status.

* refactor xpu_model_runner.

* refactor sampler.

* fix codestyle.

* Fix XPU speculative decoding: rename output tensors to cu_seqlens_q_output/batch_id_per_token_output, correct
  WRAPPER_CHECK_PTR types, and fix dynamic gather shape in verify_draft_tokens path.

* fix codestyle.

* replace output_padding_offset with is_speculative flag in gather_next_token.

* rename hiddden_states.

* unify cu_seqlens_q_output and batch_id_per_token_output init.

---------

Co-authored-by: cmcamdy <1027740945@qq.com>
This commit is contained in:
Jiajun Ji
2026-04-16 14:58:38 +08:00
committed by GitHub
parent 17002edc47
commit 29495b2cf1
9 changed files with 226 additions and 149 deletions
@@ -32,7 +32,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& encoder_batch_map_cpu, const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu, const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& len_info_cpu, const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset, bool is_speculative,
int max_bsz) { int max_bsz) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -73,7 +73,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const_cast<int32_t*>(decoder_batch_map.data<int32_t>())}; const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};
paddle::Tensor out; paddle::Tensor out;
if (output_padding_offset) { if (is_speculative) {
int need_delete_token_num = 0; int need_delete_token_num = 0;
if (enc_batch > 0) { if (enc_batch > 0) {
need_delete_token_num = need_delete_token_num =
@@ -88,7 +88,7 @@ std::vector<paddle::Tensor> GatherNextToken(
return {out}; return {out};
} }
if (output_padding_offset) { if (is_speculative) {
int r = fastdeploy::plugin::eb_mtp_gather_next_token<XPUType, XPUType>( int r = fastdeploy::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
ctx, ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()), reinterpret_cast<const XPUType*>(x.data<data_t>()),
@@ -124,14 +124,10 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
const std::vector<int64_t>& encoder_batch_map_cpu_shape, const std::vector<int64_t>& encoder_batch_map_cpu_shape,
const std::vector<int64_t>& decoder_batch_map_cpu_shape, const std::vector<int64_t>& decoder_batch_map_cpu_shape,
const std::vector<int64_t>& len_info_cpu_shape, const std::vector<int64_t>& len_info_cpu_shape,
const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) { bool is_speculative) {
// if (output_padding_offset_shape) {
// PD_THROW("speculative decoding is not supported in XPU.");
// }
// int64_t bsz = cum_offsets_shape[0];
int64_t bsz = 0; int64_t bsz = 0;
int64_t dim_embed = x_shape[1]; int64_t dim_embed = x_shape[1];
if (output_padding_offset_shape) { if (is_speculative) {
return {{-1, dim_embed}}; return {{-1, dim_embed}};
} else { } else {
return {{bsz, dim_embed}}; return {{bsz, dim_embed}};
@@ -148,8 +144,7 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
const paddle::DataType& decoder_seq_lod_cpu_dtype, const paddle::DataType& decoder_seq_lod_cpu_dtype,
const paddle::DataType& encoder_batch_map_cpu_dtype, const paddle::DataType& encoder_batch_map_cpu_dtype,
const paddle::DataType& decoder_batch_map_cpu_dtype, const paddle::DataType& decoder_batch_map_cpu_dtype,
const paddle::DataType& len_info_cpu_dtype, const paddle::DataType& len_info_cpu_dtype) {
const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
return {x_dtype}; return {x_dtype};
} }
@@ -163,10 +158,9 @@ PD_BUILD_STATIC_OP(gather_next_token)
"decoder_seq_lod_cpu", "decoder_seq_lod_cpu",
"encoder_batch_map_cpu", "encoder_batch_map_cpu",
"decoder_batch_map_cpu", "decoder_batch_map_cpu",
"len_info_cpu", "len_info_cpu"})
paddle::Optional("output_padding_offset")})
.Outputs({"out"}) .Outputs({"out"})
.Attrs({"max_bsz: int"}) .Attrs({"is_speculative: bool", "max_bsz: int"})
.SetKernelFn(PD_KERNEL(GatherNextToken)) .SetKernelFn(PD_KERNEL(GatherNextToken))
.SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape)) .SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype)); .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
+28 -2
View File
@@ -465,7 +465,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& encoder_batch_map_cpu, const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu, const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& len_info_cpu, const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset, bool is_speculative,
int max_bsz); int max_bsz);
std::vector<paddle::Tensor> GetImgBoundaries( std::vector<paddle::Tensor> GetImgBoundaries(
@@ -1035,7 +1035,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("encoder_batch_map_cpu"), py::arg("encoder_batch_map_cpu"),
py::arg("decoder_batch_map_cpu"), py::arg("decoder_batch_map_cpu"),
py::arg("len_info_cpu"), py::arg("len_info_cpu"),
py::arg("output_padding_offset"), py::arg("is_speculative"),
py::arg("max_bsz"), py::arg("max_bsz"),
"Gather next token for XPU"); "Gather next token for XPU");
@@ -1164,6 +1164,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_draft_tokens"), py::arg("max_draft_tokens"),
"Unified update model status"); "Unified update model status");
m.def("verify_draft_tokens",
&VerifyDraftTokens,
py::arg("step_output_ids"),
py::arg("step_output_len"),
py::arg("step_input_ids"),
py::arg("target_tokens"),
py::arg("candidate_ids"),
py::arg("candidate_scores"),
py::arg("candidate_lens"),
py::arg("topp"),
py::arg("stop_flags"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_this_time"),
py::arg("end_tokens"),
py::arg("is_block_step"),
py::arg("cu_seqlens_q_output"),
py::arg("reasoning_status"),
py::arg("max_dec_len"),
py::arg("step_idx"),
py::arg("max_seq_len"),
py::arg("verify_window"),
py::arg("verify_strategy"),
py::arg("reject_all"),
py::arg("accept_all"),
"Perform speculative verification for decoding v2");
m.def("mtp_step_paddle", m.def("mtp_step_paddle",
&MTPStepPaddle, &MTPStepPaddle,
py::arg("base_model_stop_flags"), py::arg("base_model_stop_flags"),
@@ -766,7 +766,6 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
const int eos_token_id_len, const int eos_token_id_len,
const int inject_len, const int inject_len,
const bool splitwise_role_is_decode); const bool splitwise_role_is_decode);
DLL_EXPORT int verify_draft_tokens( DLL_EXPORT int verify_draft_tokens(
api::Context* ctx, api::Context* ctx,
// Core I/O // Core I/O
@@ -24,7 +24,7 @@ from fastdeploy.model_executor.ops.xpu import (
) )
def _run_test_base(seq_lens_this_time_data, output_padding_offset): def _run_test_base(seq_lens_this_time_data, is_speculative):
""" """
通用的基础测试执行函数,包含了两个场景共有的逻辑。 通用的基础测试执行函数,包含了两个场景共有的逻辑。
""" """
@@ -120,7 +120,7 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
encoder_batch_map_cpu, encoder_batch_map_cpu,
decoder_batch_map_cpu, decoder_batch_map_cpu,
len_info_cpu, len_info_cpu,
output_padding_offset, is_speculative,
-1, -1,
) )
@@ -136,14 +136,14 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
encoder_batch_map_cpu, encoder_batch_map_cpu,
decoder_batch_map_cpu, decoder_batch_map_cpu,
len_info_cpu, len_info_cpu,
output_padding_offset, is_speculative,
-1, -1,
) )
gather_out_np = gather_out.astype("float32").cpu().numpy() gather_out_np = gather_out.astype("float32").cpu().numpy()
gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy() gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy()
if output_padding_offset is not None: if is_speculative:
np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!") np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!")
else: else:
for i in range(gather_out_cpu.shape[0]): for i in range(gather_out_cpu.shape[0]):
@@ -160,19 +160,14 @@ class TestXPUOps(unittest.TestCase): # 继承 unittest.TestCase
"""测试混合批次处理中的 MTP (Multi-Token Prediction) 场景""" """测试混合批次处理中的 MTP (Multi-Token Prediction) 场景"""
print("\nRunning test: test_mix_with_mtp") print("\nRunning test: test_mix_with_mtp")
seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3] seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3]
bsz = len(seq_lens_this_time_data) _run_test_base(seq_lens_this_time_data, True)
output_padding_offset = paddle.zeros(bsz, dtype="int32")
_run_test_base(seq_lens_this_time_data, output_padding_offset)
print("Test passed for scenario: With MTP") print("Test passed for scenario: With MTP")
def test_mix_without_mtp(self): def test_mix_without_mtp(self):
"""测试非 MTP (Single-Token Prediction) 场景下的功能""" """测试非 MTP (Single-Token Prediction) 场景下的功能"""
print("\nRunning test: test_mix_without_mtp") print("\nRunning test: test_mix_without_mtp")
seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1] seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1]
output_padding_offset = None # 非 MTP 场景下,此参数为 None _run_test_base(seq_lens_this_time_data, False)
_run_test_base(seq_lens_this_time_data, output_padding_offset)
print("Test passed for scenario: Without MTP") print("Test passed for scenario: Without MTP")
@@ -275,6 +275,7 @@ class XPUForwardMeta(ForwardMeta):
hidden_states: Optional[paddle.Tensor] = None hidden_states: Optional[paddle.Tensor] = None
is_draft: bool = False is_draft: bool = False
is_speculative: bool = False
# max bs # max bs
max_num_seqs: int = 0 max_num_seqs: int = 0
@@ -1045,6 +1045,120 @@ class SpeculativeSampler(nn.Layer):
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu() sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
return sampler_output return sampler_output
def _normal_sample_xpu(
self,
logits: paddle.Tensor,
probs: paddle.Tensor,
sampling_metadata: SamplingMetadata,
share_inputs: List[paddle.Tensor],
) -> SamplerOutput:
"""Normal sampling for NAIVE mode on XPU."""
top_p, top_k, topp_seed = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
)
_, next_tokens = top_k_top_p_sampling(
probs,
top_p=top_p,
top_k=top_k,
top_k_list=sampling_metadata.top_k_list,
topp_seed=topp_seed,
)
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
running_mask = (paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]) > 0).cast("int32")
share_inputs["accept_tokens"][:real_bsz, 0] = next_tokens.squeeze(-1)
share_inputs["accept_num"][:real_bsz] = running_mask
return SamplerOutput(
sampled_token_ids=share_inputs["accept_tokens"],
logprobs_tensors=None,
token_num_per_batch=share_inputs["accept_num"],
logits=logits,
)
def _verify_and_sample_xpu(
self,
logits: paddle.Tensor,
probs: paddle.Tensor,
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False,
reject_all_drafts: bool = False,
) -> SamplerOutput:
"""Verify draft tokens (MTP/Ngram mode) on XPU using verify_draft_tokens."""
from fastdeploy.model_executor.ops.xpu import (
top_p_candidates,
verify_draft_tokens,
)
target_tokens = None
candidate_ids, candidate_scores, candidate_lens = None, None, None
if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
top_p, top_k, topp_seed = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
)
_, target_tokens = top_k_top_p_sampling(
probs,
top_p=top_p,
top_k=top_k,
top_k_list=sampling_metadata.top_k_list,
topp_seed=topp_seed,
)
elif self.verify_strategy == VerifyStrategy.GREEDY:
target_tokens = paddle.argmax(probs, axis=-1)
elif self.verify_strategy == VerifyStrategy.TOPP:
candidate_scores, candidate_ids, candidate_lens = top_p_candidates(
probs,
sampling_metadata.top_p,
share_inputs["batch_id_per_token_output"],
self.speculative_max_candidate_len,
max_model_len,
)
else:
raise ValueError(f"Unknown verify strategy: {self.verify_strategy}")
final_accept_all = self.config_accept_all or accept_all_drafts
final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode
verify_draft_tokens(
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["draft_tokens"],
target_tokens,
candidate_ids,
candidate_scores,
candidate_lens,
sampling_metadata.top_p,
share_inputs["stop_flags"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_this_time"],
sampling_metadata.eos_token_ids,
share_inputs["is_block_step"],
share_inputs["cu_seqlens_q_output"],
share_inputs["reasoning_status"],
share_inputs["max_dec_len"],
share_inputs["step_idx"],
max_model_len,
self.speculative_verify_window,
self.verify_strategy.value,
final_reject_all,
final_accept_all,
)
return SamplerOutput(
sampled_token_ids=share_inputs["accept_tokens"],
logprobs_tensors=None,
token_num_per_batch=share_inputs["accept_num"],
logits=logits,
)
def forward_xpu( def forward_xpu(
self, self,
logits: paddle.Tensor, logits: paddle.Tensor,
@@ -1053,9 +1167,7 @@ class SpeculativeSampler(nn.Layer):
share_inputs: List[paddle.Tensor], share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False, accept_all_drafts: bool = False,
reject_all_drafts: bool = False, reject_all_drafts: bool = False,
) -> paddle.Tensor: ) -> SamplerOutput:
from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
logits = apply_speculative_penalty_multi_scores( logits = apply_speculative_penalty_multi_scores(
sampling_metadata.token_ids_all, sampling_metadata.token_ids_all,
sampling_metadata.prompt_lens, sampling_metadata.prompt_lens,
@@ -1078,61 +1190,19 @@ class SpeculativeSampler(nn.Layer):
probs = F.softmax(logits) probs = F.softmax(logits)
top_p, top_k, topp_seed = padding_sampling_params( is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
sampling_metadata.top_p, if is_naive:
sampling_metadata.top_k, return self._normal_sample_xpu(logits, probs, sampling_metadata, share_inputs)
sampling_metadata.seed, else:
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]), return self._verify_and_sample_xpu(
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]), logits,
) probs,
_, sampled_token_ids = top_k_top_p_sampling( sampling_metadata,
probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed max_model_len,
) share_inputs,
accept_all_drafts,
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( reject_all_drafts,
probs, )
sampling_metadata.top_p,
share_inputs["batch_id_per_token_output"],
self.speculative_max_candidate_len,
max_model_len,
)
speculate_verify(
sampled_token_ids,
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["step_idx"],
share_inputs["stop_flags"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs[
"draft_tokens"
], # Both input and output, need to write the last 1 token accepted to position 0.
share_inputs["seq_lens_this_time"],
verify_tokens,
verify_scores,
share_inputs["max_dec_len"],
sampling_metadata.eos_token_ids,
share_inputs["is_block_step"],
share_inputs["cu_seqlens_q_output"],
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
max_model_len,
self.speculative_verify_window,
True, # enable_topp
(self.speculative_benchmark_mode or reject_all_drafts),
accept_all_drafts,
)
# TODO(chenhuan09): support return logprobs
token_ids = share_inputs["accept_tokens"]
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=None,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=None,
)
return sampler_output
class MTPSampler(nn.Layer): class MTPSampler(nn.Layer):
@@ -43,12 +43,11 @@ if current_platform.is_xpu():
speculate_pre_process, speculate_pre_process,
speculate_save_output, speculate_save_output,
speculate_set_stop_value_multi_seqs, speculate_set_stop_value_multi_seqs,
speculate_set_value_by_flags_and_idx,
speculate_step_paddle, speculate_step_paddle,
speculate_step_reschedule, speculate_step_reschedule,
speculate_step_system_cache, speculate_step_system_cache,
speculate_update,
step_paddle, step_paddle,
unified_update_model_status,
update_inputs, update_inputs,
update_inputs_v1, update_inputs_v1,
) )
@@ -172,6 +171,7 @@ def xpu_pre_process(
block_tables=share_inputs["block_tables"], block_tables=share_inputs["block_tables"],
caches=share_inputs["caches"], caches=share_inputs["caches"],
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0], max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
is_speculative=use_speculate_method,
) )
( (
@@ -249,11 +249,6 @@ def xpu_process_output(
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """ """ """
if isinstance(share_inputs, dict):
output_padding_offset = share_inputs.get("output_padding_offset", None)
else:
output_padding_offset = getattr(share_inputs, "output_padding_offset", None)
hidden_states = gather_next_token( hidden_states = gather_next_token(
forward_output, forward_output,
xpu_forward_meta.encoder_seq_lod, xpu_forward_meta.encoder_seq_lod,
@@ -265,7 +260,7 @@ def xpu_process_output(
xpu_forward_meta.encoder_batch_map_cpu, xpu_forward_meta.encoder_batch_map_cpu,
xpu_forward_meta.decoder_batch_map_cpu, xpu_forward_meta.decoder_batch_map_cpu,
xpu_forward_meta.len_info_cpu, xpu_forward_meta.len_info_cpu,
output_padding_offset, # output_padding_offset xpu_forward_meta.is_speculative,
xpu_forward_meta.max_num_seqs, xpu_forward_meta.max_num_seqs,
) )
return hidden_states return hidden_states
@@ -416,6 +411,8 @@ def xpu_post_process_specualate(
share_inputs: Dict[str, paddle.Tensor], share_inputs: Dict[str, paddle.Tensor],
save_each_rank: bool = False, save_each_rank: bool = False,
skip_save_output: bool = False, skip_save_output: bool = False,
is_naive_mode: bool = False,
prefill_one_step_stop: bool = False,
): ):
"""""" """"""
@@ -432,7 +429,7 @@ def xpu_post_process_specualate(
model_output.min_tokens, model_output.min_tokens,
) )
speculate_update( unified_update_model_status(
model_output.seq_lens_encoder, model_output.seq_lens_encoder,
model_output.seq_lens_decoder, model_output.seq_lens_decoder,
model_output.not_need_stop, model_output.not_need_stop,
@@ -444,6 +441,13 @@ def xpu_post_process_specualate(
model_output.seq_lens_this_time, model_output.seq_lens_this_time,
model_output.is_block_step, model_output.is_block_step,
model_output.mask_rollback, model_output.mask_rollback,
model_output.pre_ids,
model_output.prompt_lens,
model_output.step_idx,
model_output.eos_token_id,
model_output.max_dec_len,
is_naive_mode,
prefill_one_step_stop,
) )
if not skip_save_output: if not skip_save_output:
if sampler_output.logprobs_tensors is None: if sampler_output.logprobs_tensors is None:
@@ -464,18 +468,6 @@ def xpu_post_process_specualate(
speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder) speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)
# Update pre_ids through accept tokens
speculate_set_value_by_flags_and_idx(
model_output.pre_ids,
model_output.accept_tokens,
model_output.accept_num,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.step_idx,
)
def step_xpu( def step_xpu(
share_inputs: Dict[str, paddle.Tensor], share_inputs: Dict[str, paddle.Tensor],
+11 -19
View File
@@ -286,20 +286,12 @@ class InputBatch:
fill_value=max_draft_token_num, fill_value=max_draft_token_num,
dtype="int32", dtype="int32",
) )
if current_platform.is_cuda(): self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32") self.batch_id_per_token_output = paddle.full(
self.batch_id_per_token_output = paddle.full( shape=[max_num_seqs * (max_draft_token_num + 1)],
shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0,
fill_value=0, dtype="int32",
dtype="int32", )
)
else:
self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.output_padding_offset = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0,
dtype="int32",
)
# For V1_KVCACHE_SCHEDULER # For V1_KVCACHE_SCHEDULER
self.step_draft_tokens = paddle.full( self.step_draft_tokens = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1], shape=[max_num_seqs, max_draft_token_num + 1],
@@ -437,7 +429,7 @@ class InputBatch:
if current_platform.is_cuda(): if current_platform.is_cuda():
swap_data(self.cu_seqlens_q_output, i1, i2) swap_data(self.cu_seqlens_q_output, i1, i2)
else: else:
swap_data(self.output_cum_offsets, i1, i2) swap_data(self.cu_seqlens_q_output, i1, i2)
swap_data(self.step_draft_tokens, i1, i2) swap_data(self.step_draft_tokens, i1, i2)
swap_data(self.step_seq_lens_this_time, i1, i2) swap_data(self.step_seq_lens_this_time, i1, i2)
swap_data(self.draft_logits, i1, i2) swap_data(self.draft_logits, i1, i2)
@@ -628,8 +620,8 @@ class InputBatch:
fill_paddle_tensor(self, "accept_num", 0) fill_paddle_tensor(self, "accept_num", 0)
fill_paddle_tensor(self, "draft_tokens", -1) fill_paddle_tensor(self, "draft_tokens", -1)
fill_paddle_tensor(self, "actual_draft_token_num", max_draft_token_num) fill_paddle_tensor(self, "actual_draft_token_num", max_draft_token_num)
fill_paddle_tensor(self, "output_cum_offsets", 0) fill_paddle_tensor(self, "cu_seqlens_q_output", 0)
fill_paddle_tensor(self, "output_padding_offset", 0) fill_paddle_tensor(self, "batch_id_per_token_output", 0)
fill_paddle_tensor(self, "step_draft_tokens", 0) fill_paddle_tensor(self, "step_draft_tokens", 0)
fill_paddle_tensor(self, "step_seq_lens_this_time", 0) fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
fill_paddle_tensor(self, "draft_logits", -1) fill_paddle_tensor(self, "draft_logits", -1)
@@ -742,8 +734,8 @@ class ProposerInputBatch(InputBatch):
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"]) self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
self.token_ids_all = None self.token_ids_all = None
else: else:
self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"]) self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"]) self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"]) self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"]) self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"]) self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
+31 -23
View File
@@ -135,8 +135,8 @@ class XPUModelRunner(ModelRunnerBase):
self.encoder_cache = None self.encoder_cache = None
self.device_id = device_id self.device_id = device_id
self.speculative_method = self.fd_config.speculative_config.method self.spec_method = self.fd_config.speculative_config.method
self.speculative_decoding = self.speculative_method is not None self.speculative_decoding = self.spec_method is not None
# used by SamplingMetadata # used by SamplingMetadata
self.enable_logprob = fd_config.model_config.enable_logprob # fd_config.model_config.enable_logprob self.enable_logprob = fd_config.model_config.enable_logprob # fd_config.model_config.enable_logprob
@@ -728,7 +728,7 @@ class XPUModelRunner(ModelRunnerBase):
if has_prefill_task or has_decode_task: if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True self.share_inputs["not_need_stop"][0] = True
if self.speculative_method == SpecMethod.MTP: if self.spec_method == SpecMethod.MTP:
self.proposer.insert_tasks_v1(req_dicts, num_running_requests) self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
@@ -877,7 +877,7 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["not_need_stop"][0] = True self.share_inputs["not_need_stop"][0] = True
if self.speculative_method == SpecMethod.MTP: if self.spec_method == SpecMethod.MTP:
self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request( self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
request, "temp_scaled_logprobs", False request, "temp_scaled_logprobs", False
) )
@@ -1070,12 +1070,18 @@ class XPUModelRunner(ModelRunnerBase):
fill_value=max_draft_token_num, fill_value=max_draft_token_num,
dtype="int32", dtype="int32",
) )
self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") self.share_inputs["cu_seqlens_q_output"] = paddle.full(
self.share_inputs["output_padding_offset"] = paddle.full( shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32"
)
self.share_inputs["batch_id_per_token_output"] = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)], shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0, fill_value=0,
dtype="int32", dtype="int32",
) )
# reasoning_status: per-sequence reasoning phase indicator
# 0=thinking, 1=emitting boundary, 2=response, 3=end
# verify_draft_tokens 在 reasoning_status==1 时强制拒绝所有 draft token
self.share_inputs["reasoning_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
# For V1_KVCACHE_SCHEDULER # For V1_KVCACHE_SCHEDULER
self.share_inputs["step_draft_tokens"] = paddle.full( self.share_inputs["step_draft_tokens"] = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1], shape=[max_num_seqs, max_draft_token_num + 1],
@@ -1439,7 +1445,7 @@ class XPUModelRunner(ModelRunnerBase):
block_num=block_num, block_num=block_num,
) )
if self.speculative_method == SpecMethod.MTP: if self.spec_method == SpecMethod.MTP:
self.proposer.dummy_prefill_inputs( self.proposer.dummy_prefill_inputs(
num_tokens=num_tokens, num_tokens=num_tokens,
batch_size=batch_size, batch_size=batch_size,
@@ -1456,19 +1462,16 @@ class XPUModelRunner(ModelRunnerBase):
""" """
Init speculative proposer Init speculative proposer
""" """
if self.speculative_method == SpecMethod.NGRAM: if self.spec_method is None:
# xpu not support ngram proposer now
self.proposer = None
elif self.speculative_method == SpecMethod.MTP:
self.proposer = self.speculative_method.create_proposer(
self.fd_config,
main_model=self.get_model(),
local_rank=self.local_rank,
device_id=self.device_id,
share_inputs=self.share_inputs,
)
else:
self.proposer = None self.proposer = None
return
self.proposer = self.spec_method.create_proposer(
self.fd_config,
main_model=self.get_model(),
local_rank=self.local_rank,
device_id=self.device_id,
share_inputs=self.share_inputs,
)
def _set_debug_level( def _set_debug_level(
self, debug_level: int = 0x1, model_forward_batch: Optional[List[Request]] = None, is_dummy_run: bool = False self, debug_level: int = 0x1, model_forward_batch: Optional[List[Request]] = None, is_dummy_run: bool = False
@@ -1670,6 +1673,8 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs, self.share_inputs,
self.parallel_config.data_parallel_size > 1, self.parallel_config.data_parallel_size > 1,
skip_save_output, skip_save_output,
is_naive_mode=(self.speculative_decoding and self.proposer is None),
prefill_one_step_stop=self.parallel_config.prefill_one_step_stop,
) )
else: else:
xpu_post_process_normal( xpu_post_process_normal(
@@ -1685,8 +1690,11 @@ class XPUModelRunner(ModelRunnerBase):
) )
# 6. Draft model propose # 6. Draft model propose
if self.speculative_method == SpecMethod.MTP: if self.speculative_decoding and self.proposer is not None:
self.proposer.run(full_hidden_states=model_output) if self.spec_method == SpecMethod.MTP:
self.proposer.run(full_hidden_states=model_output)
else:
self.proposer.run(share_inputs=self.share_inputs)
# 7. Updata 'infer_seed' and step_paddle() # 7. Updata 'infer_seed' and step_paddle()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment) self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
@@ -1738,7 +1746,7 @@ class XPUModelRunner(ModelRunnerBase):
"""Execute a forward pass with dummy inputs to profile the memory usage of the model""" """Execute a forward pass with dummy inputs to profile the memory usage of the model"""
self.num_gpu_blocks = self.cache_config.total_block_num self.num_gpu_blocks = self.cache_config.total_block_num
if self.speculative_method == SpecMethod.MTP: if self.spec_method == SpecMethod.MTP:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True) self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
self.initialize_kv_cache(profile=True) self.initialize_kv_cache(profile=True)
@@ -1760,7 +1768,7 @@ class XPUModelRunner(ModelRunnerBase):
self.num_gpu_blocks = num_gpu_blocks self.num_gpu_blocks = num_gpu_blocks
# Reset block table and kv cache with global block num # Reset block table and kv cache with global block num
if self.speculative_method == SpecMethod.MTP: if self.spec_method == SpecMethod.MTP:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
self.initialize_kv_cache() self.initialize_kv_cache()