mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
* [XPU] cherry-pick PR-6947 * [XPU] use unified_update_model_status. * refactor xpu_model_runner. * refactor sampler. * fix codestyle. * Fix XPU speculative decoding: rename output tensors to cu_seqlens_q_output/batch_id_per_token_output, correct WRAPPER_CHECK_PTR types, and fix dynamic gather shape in verify_draft_tokens path. * fix codestyle. * replace output_padding_offset with is_speculative flag in gather_next_token. * rename hiddden_states. * unify cu_seqlens_q_output and batch_id_per_token_output init. --------- Co-authored-by: cmcamdy <1027740945@qq.com>
This commit is contained in:
@@ -32,7 +32,7 @@ std::vector<paddle::Tensor> GatherNextToken(
|
|||||||
const paddle::Tensor& encoder_batch_map_cpu,
|
const paddle::Tensor& encoder_batch_map_cpu,
|
||||||
const paddle::Tensor& decoder_batch_map_cpu,
|
const paddle::Tensor& decoder_batch_map_cpu,
|
||||||
const paddle::Tensor& len_info_cpu,
|
const paddle::Tensor& len_info_cpu,
|
||||||
const paddle::optional<paddle::Tensor>& output_padding_offset,
|
bool is_speculative,
|
||||||
int max_bsz) {
|
int max_bsz) {
|
||||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||||
@@ -73,7 +73,7 @@ std::vector<paddle::Tensor> GatherNextToken(
|
|||||||
const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};
|
const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};
|
||||||
|
|
||||||
paddle::Tensor out;
|
paddle::Tensor out;
|
||||||
if (output_padding_offset) {
|
if (is_speculative) {
|
||||||
int need_delete_token_num = 0;
|
int need_delete_token_num = 0;
|
||||||
if (enc_batch > 0) {
|
if (enc_batch > 0) {
|
||||||
need_delete_token_num =
|
need_delete_token_num =
|
||||||
@@ -88,7 +88,7 @@ std::vector<paddle::Tensor> GatherNextToken(
|
|||||||
return {out};
|
return {out};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (output_padding_offset) {
|
if (is_speculative) {
|
||||||
int r = fastdeploy::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
|
int r = fastdeploy::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
|
||||||
ctx,
|
ctx,
|
||||||
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
reinterpret_cast<const XPUType*>(x.data<data_t>()),
|
||||||
@@ -124,14 +124,10 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
|
|||||||
const std::vector<int64_t>& encoder_batch_map_cpu_shape,
|
const std::vector<int64_t>& encoder_batch_map_cpu_shape,
|
||||||
const std::vector<int64_t>& decoder_batch_map_cpu_shape,
|
const std::vector<int64_t>& decoder_batch_map_cpu_shape,
|
||||||
const std::vector<int64_t>& len_info_cpu_shape,
|
const std::vector<int64_t>& len_info_cpu_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
|
bool is_speculative) {
|
||||||
// if (output_padding_offset_shape) {
|
|
||||||
// PD_THROW("speculative decoding is not supported in XPU.");
|
|
||||||
// }
|
|
||||||
// int64_t bsz = cum_offsets_shape[0];
|
|
||||||
int64_t bsz = 0;
|
int64_t bsz = 0;
|
||||||
int64_t dim_embed = x_shape[1];
|
int64_t dim_embed = x_shape[1];
|
||||||
if (output_padding_offset_shape) {
|
if (is_speculative) {
|
||||||
return {{-1, dim_embed}};
|
return {{-1, dim_embed}};
|
||||||
} else {
|
} else {
|
||||||
return {{bsz, dim_embed}};
|
return {{bsz, dim_embed}};
|
||||||
@@ -148,8 +144,7 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
|
|||||||
const paddle::DataType& decoder_seq_lod_cpu_dtype,
|
const paddle::DataType& decoder_seq_lod_cpu_dtype,
|
||||||
const paddle::DataType& encoder_batch_map_cpu_dtype,
|
const paddle::DataType& encoder_batch_map_cpu_dtype,
|
||||||
const paddle::DataType& decoder_batch_map_cpu_dtype,
|
const paddle::DataType& decoder_batch_map_cpu_dtype,
|
||||||
const paddle::DataType& len_info_cpu_dtype,
|
const paddle::DataType& len_info_cpu_dtype) {
|
||||||
const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
|
|
||||||
return {x_dtype};
|
return {x_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,10 +158,9 @@ PD_BUILD_STATIC_OP(gather_next_token)
|
|||||||
"decoder_seq_lod_cpu",
|
"decoder_seq_lod_cpu",
|
||||||
"encoder_batch_map_cpu",
|
"encoder_batch_map_cpu",
|
||||||
"decoder_batch_map_cpu",
|
"decoder_batch_map_cpu",
|
||||||
"len_info_cpu",
|
"len_info_cpu"})
|
||||||
paddle::Optional("output_padding_offset")})
|
|
||||||
.Outputs({"out"})
|
.Outputs({"out"})
|
||||||
.Attrs({"max_bsz: int"})
|
.Attrs({"is_speculative: bool", "max_bsz: int"})
|
||||||
.SetKernelFn(PD_KERNEL(GatherNextToken))
|
.SetKernelFn(PD_KERNEL(GatherNextToken))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
|
||||||
|
|||||||
@@ -465,7 +465,7 @@ std::vector<paddle::Tensor> GatherNextToken(
|
|||||||
const paddle::Tensor& encoder_batch_map_cpu,
|
const paddle::Tensor& encoder_batch_map_cpu,
|
||||||
const paddle::Tensor& decoder_batch_map_cpu,
|
const paddle::Tensor& decoder_batch_map_cpu,
|
||||||
const paddle::Tensor& len_info_cpu,
|
const paddle::Tensor& len_info_cpu,
|
||||||
const paddle::optional<paddle::Tensor>& output_padding_offset,
|
bool is_speculative,
|
||||||
int max_bsz);
|
int max_bsz);
|
||||||
|
|
||||||
std::vector<paddle::Tensor> GetImgBoundaries(
|
std::vector<paddle::Tensor> GetImgBoundaries(
|
||||||
@@ -1035,7 +1035,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("encoder_batch_map_cpu"),
|
py::arg("encoder_batch_map_cpu"),
|
||||||
py::arg("decoder_batch_map_cpu"),
|
py::arg("decoder_batch_map_cpu"),
|
||||||
py::arg("len_info_cpu"),
|
py::arg("len_info_cpu"),
|
||||||
py::arg("output_padding_offset"),
|
py::arg("is_speculative"),
|
||||||
py::arg("max_bsz"),
|
py::arg("max_bsz"),
|
||||||
"Gather next token for XPU");
|
"Gather next token for XPU");
|
||||||
|
|
||||||
@@ -1164,6 +1164,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("max_draft_tokens"),
|
py::arg("max_draft_tokens"),
|
||||||
"Unified update model status");
|
"Unified update model status");
|
||||||
|
|
||||||
|
m.def("verify_draft_tokens",
|
||||||
|
&VerifyDraftTokens,
|
||||||
|
py::arg("step_output_ids"),
|
||||||
|
py::arg("step_output_len"),
|
||||||
|
py::arg("step_input_ids"),
|
||||||
|
py::arg("target_tokens"),
|
||||||
|
py::arg("candidate_ids"),
|
||||||
|
py::arg("candidate_scores"),
|
||||||
|
py::arg("candidate_lens"),
|
||||||
|
py::arg("topp"),
|
||||||
|
py::arg("stop_flags"),
|
||||||
|
py::arg("seq_lens_encoder"),
|
||||||
|
py::arg("seq_lens_this_time"),
|
||||||
|
py::arg("end_tokens"),
|
||||||
|
py::arg("is_block_step"),
|
||||||
|
py::arg("cu_seqlens_q_output"),
|
||||||
|
py::arg("reasoning_status"),
|
||||||
|
py::arg("max_dec_len"),
|
||||||
|
py::arg("step_idx"),
|
||||||
|
py::arg("max_seq_len"),
|
||||||
|
py::arg("verify_window"),
|
||||||
|
py::arg("verify_strategy"),
|
||||||
|
py::arg("reject_all"),
|
||||||
|
py::arg("accept_all"),
|
||||||
|
"Perform speculative verification for decoding v2");
|
||||||
|
|
||||||
m.def("mtp_step_paddle",
|
m.def("mtp_step_paddle",
|
||||||
&MTPStepPaddle,
|
&MTPStepPaddle,
|
||||||
py::arg("base_model_stop_flags"),
|
py::arg("base_model_stop_flags"),
|
||||||
|
|||||||
@@ -766,7 +766,6 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
|
|||||||
const int eos_token_id_len,
|
const int eos_token_id_len,
|
||||||
const int inject_len,
|
const int inject_len,
|
||||||
const bool splitwise_role_is_decode);
|
const bool splitwise_role_is_decode);
|
||||||
|
|
||||||
DLL_EXPORT int verify_draft_tokens(
|
DLL_EXPORT int verify_draft_tokens(
|
||||||
api::Context* ctx,
|
api::Context* ctx,
|
||||||
// Core I/O
|
// Core I/O
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from fastdeploy.model_executor.ops.xpu import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _run_test_base(seq_lens_this_time_data, output_padding_offset):
|
def _run_test_base(seq_lens_this_time_data, is_speculative):
|
||||||
"""
|
"""
|
||||||
通用的基础测试执行函数,包含了两个场景共有的逻辑。
|
通用的基础测试执行函数,包含了两个场景共有的逻辑。
|
||||||
"""
|
"""
|
||||||
@@ -120,7 +120,7 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
|
|||||||
encoder_batch_map_cpu,
|
encoder_batch_map_cpu,
|
||||||
decoder_batch_map_cpu,
|
decoder_batch_map_cpu,
|
||||||
len_info_cpu,
|
len_info_cpu,
|
||||||
output_padding_offset,
|
is_speculative,
|
||||||
-1,
|
-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -136,14 +136,14 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
|
|||||||
encoder_batch_map_cpu,
|
encoder_batch_map_cpu,
|
||||||
decoder_batch_map_cpu,
|
decoder_batch_map_cpu,
|
||||||
len_info_cpu,
|
len_info_cpu,
|
||||||
output_padding_offset,
|
is_speculative,
|
||||||
-1,
|
-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
gather_out_np = gather_out.astype("float32").cpu().numpy()
|
gather_out_np = gather_out.astype("float32").cpu().numpy()
|
||||||
gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy()
|
gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy()
|
||||||
|
|
||||||
if output_padding_offset is not None:
|
if is_speculative:
|
||||||
np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!")
|
np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!")
|
||||||
else:
|
else:
|
||||||
for i in range(gather_out_cpu.shape[0]):
|
for i in range(gather_out_cpu.shape[0]):
|
||||||
@@ -160,19 +160,14 @@ class TestXPUOps(unittest.TestCase): # 继承 unittest.TestCase
|
|||||||
"""测试混合批次处理中的 MTP (Multi-Token Prediction) 场景"""
|
"""测试混合批次处理中的 MTP (Multi-Token Prediction) 场景"""
|
||||||
print("\nRunning test: test_mix_with_mtp")
|
print("\nRunning test: test_mix_with_mtp")
|
||||||
seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3]
|
seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3]
|
||||||
bsz = len(seq_lens_this_time_data)
|
_run_test_base(seq_lens_this_time_data, True)
|
||||||
output_padding_offset = paddle.zeros(bsz, dtype="int32")
|
|
||||||
|
|
||||||
_run_test_base(seq_lens_this_time_data, output_padding_offset)
|
|
||||||
print("Test passed for scenario: With MTP")
|
print("Test passed for scenario: With MTP")
|
||||||
|
|
||||||
def test_mix_without_mtp(self):
|
def test_mix_without_mtp(self):
|
||||||
"""测试非 MTP (Single-Token Prediction) 场景下的功能"""
|
"""测试非 MTP (Single-Token Prediction) 场景下的功能"""
|
||||||
print("\nRunning test: test_mix_without_mtp")
|
print("\nRunning test: test_mix_without_mtp")
|
||||||
seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1]
|
seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1]
|
||||||
output_padding_offset = None # 非 MTP 场景下,此参数为 None
|
_run_test_base(seq_lens_this_time_data, False)
|
||||||
|
|
||||||
_run_test_base(seq_lens_this_time_data, output_padding_offset)
|
|
||||||
print("Test passed for scenario: Without MTP")
|
print("Test passed for scenario: Without MTP")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -275,6 +275,7 @@ class XPUForwardMeta(ForwardMeta):
|
|||||||
hidden_states: Optional[paddle.Tensor] = None
|
hidden_states: Optional[paddle.Tensor] = None
|
||||||
|
|
||||||
is_draft: bool = False
|
is_draft: bool = False
|
||||||
|
is_speculative: bool = False
|
||||||
# max bs
|
# max bs
|
||||||
max_num_seqs: int = 0
|
max_num_seqs: int = 0
|
||||||
|
|
||||||
|
|||||||
@@ -1045,6 +1045,120 @@ class SpeculativeSampler(nn.Layer):
|
|||||||
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
|
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
|
||||||
return sampler_output
|
return sampler_output
|
||||||
|
|
||||||
|
def _normal_sample_xpu(
|
||||||
|
self,
|
||||||
|
logits: paddle.Tensor,
|
||||||
|
probs: paddle.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
share_inputs: List[paddle.Tensor],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
"""Normal sampling for NAIVE mode on XPU."""
|
||||||
|
top_p, top_k, topp_seed = padding_sampling_params(
|
||||||
|
sampling_metadata.top_p,
|
||||||
|
sampling_metadata.top_k,
|
||||||
|
sampling_metadata.seed,
|
||||||
|
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
|
||||||
|
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
|
||||||
|
)
|
||||||
|
_, next_tokens = top_k_top_p_sampling(
|
||||||
|
probs,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
top_k_list=sampling_metadata.top_k_list,
|
||||||
|
topp_seed=topp_seed,
|
||||||
|
)
|
||||||
|
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
|
||||||
|
running_mask = (paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]) > 0).cast("int32")
|
||||||
|
share_inputs["accept_tokens"][:real_bsz, 0] = next_tokens.squeeze(-1)
|
||||||
|
share_inputs["accept_num"][:real_bsz] = running_mask
|
||||||
|
return SamplerOutput(
|
||||||
|
sampled_token_ids=share_inputs["accept_tokens"],
|
||||||
|
logprobs_tensors=None,
|
||||||
|
token_num_per_batch=share_inputs["accept_num"],
|
||||||
|
logits=logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _verify_and_sample_xpu(
|
||||||
|
self,
|
||||||
|
logits: paddle.Tensor,
|
||||||
|
probs: paddle.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
max_model_len: int,
|
||||||
|
share_inputs: List[paddle.Tensor],
|
||||||
|
accept_all_drafts: bool = False,
|
||||||
|
reject_all_drafts: bool = False,
|
||||||
|
) -> SamplerOutput:
|
||||||
|
"""Verify draft tokens (MTP/Ngram mode) on XPU using verify_draft_tokens."""
|
||||||
|
from fastdeploy.model_executor.ops.xpu import (
|
||||||
|
top_p_candidates,
|
||||||
|
verify_draft_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
target_tokens = None
|
||||||
|
candidate_ids, candidate_scores, candidate_lens = None, None, None
|
||||||
|
|
||||||
|
if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
|
||||||
|
top_p, top_k, topp_seed = padding_sampling_params(
|
||||||
|
sampling_metadata.top_p,
|
||||||
|
sampling_metadata.top_k,
|
||||||
|
sampling_metadata.seed,
|
||||||
|
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
|
||||||
|
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
|
||||||
|
)
|
||||||
|
_, target_tokens = top_k_top_p_sampling(
|
||||||
|
probs,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
top_k_list=sampling_metadata.top_k_list,
|
||||||
|
topp_seed=topp_seed,
|
||||||
|
)
|
||||||
|
elif self.verify_strategy == VerifyStrategy.GREEDY:
|
||||||
|
target_tokens = paddle.argmax(probs, axis=-1)
|
||||||
|
elif self.verify_strategy == VerifyStrategy.TOPP:
|
||||||
|
candidate_scores, candidate_ids, candidate_lens = top_p_candidates(
|
||||||
|
probs,
|
||||||
|
sampling_metadata.top_p,
|
||||||
|
share_inputs["batch_id_per_token_output"],
|
||||||
|
self.speculative_max_candidate_len,
|
||||||
|
max_model_len,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown verify strategy: {self.verify_strategy}")
|
||||||
|
|
||||||
|
final_accept_all = self.config_accept_all or accept_all_drafts
|
||||||
|
final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode
|
||||||
|
|
||||||
|
verify_draft_tokens(
|
||||||
|
share_inputs["accept_tokens"],
|
||||||
|
share_inputs["accept_num"],
|
||||||
|
share_inputs["draft_tokens"],
|
||||||
|
target_tokens,
|
||||||
|
candidate_ids,
|
||||||
|
candidate_scores,
|
||||||
|
candidate_lens,
|
||||||
|
sampling_metadata.top_p,
|
||||||
|
share_inputs["stop_flags"],
|
||||||
|
share_inputs["seq_lens_encoder"],
|
||||||
|
share_inputs["seq_lens_this_time"],
|
||||||
|
sampling_metadata.eos_token_ids,
|
||||||
|
share_inputs["is_block_step"],
|
||||||
|
share_inputs["cu_seqlens_q_output"],
|
||||||
|
share_inputs["reasoning_status"],
|
||||||
|
share_inputs["max_dec_len"],
|
||||||
|
share_inputs["step_idx"],
|
||||||
|
max_model_len,
|
||||||
|
self.speculative_verify_window,
|
||||||
|
self.verify_strategy.value,
|
||||||
|
final_reject_all,
|
||||||
|
final_accept_all,
|
||||||
|
)
|
||||||
|
return SamplerOutput(
|
||||||
|
sampled_token_ids=share_inputs["accept_tokens"],
|
||||||
|
logprobs_tensors=None,
|
||||||
|
token_num_per_batch=share_inputs["accept_num"],
|
||||||
|
logits=logits,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_xpu(
|
def forward_xpu(
|
||||||
self,
|
self,
|
||||||
logits: paddle.Tensor,
|
logits: paddle.Tensor,
|
||||||
@@ -1053,9 +1167,7 @@ class SpeculativeSampler(nn.Layer):
|
|||||||
share_inputs: List[paddle.Tensor],
|
share_inputs: List[paddle.Tensor],
|
||||||
accept_all_drafts: bool = False,
|
accept_all_drafts: bool = False,
|
||||||
reject_all_drafts: bool = False,
|
reject_all_drafts: bool = False,
|
||||||
) -> paddle.Tensor:
|
) -> SamplerOutput:
|
||||||
from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
|
|
||||||
|
|
||||||
logits = apply_speculative_penalty_multi_scores(
|
logits = apply_speculative_penalty_multi_scores(
|
||||||
sampling_metadata.token_ids_all,
|
sampling_metadata.token_ids_all,
|
||||||
sampling_metadata.prompt_lens,
|
sampling_metadata.prompt_lens,
|
||||||
@@ -1078,61 +1190,19 @@ class SpeculativeSampler(nn.Layer):
|
|||||||
|
|
||||||
probs = F.softmax(logits)
|
probs = F.softmax(logits)
|
||||||
|
|
||||||
top_p, top_k, topp_seed = padding_sampling_params(
|
is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
|
||||||
sampling_metadata.top_p,
|
if is_naive:
|
||||||
sampling_metadata.top_k,
|
return self._normal_sample_xpu(logits, probs, sampling_metadata, share_inputs)
|
||||||
sampling_metadata.seed,
|
else:
|
||||||
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
|
return self._verify_and_sample_xpu(
|
||||||
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
|
logits,
|
||||||
)
|
probs,
|
||||||
_, sampled_token_ids = top_k_top_p_sampling(
|
sampling_metadata,
|
||||||
probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
|
max_model_len,
|
||||||
)
|
share_inputs,
|
||||||
|
accept_all_drafts,
|
||||||
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
reject_all_drafts,
|
||||||
probs,
|
)
|
||||||
sampling_metadata.top_p,
|
|
||||||
share_inputs["batch_id_per_token_output"],
|
|
||||||
self.speculative_max_candidate_len,
|
|
||||||
max_model_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
speculate_verify(
|
|
||||||
sampled_token_ids,
|
|
||||||
share_inputs["accept_tokens"],
|
|
||||||
share_inputs["accept_num"],
|
|
||||||
share_inputs["step_idx"],
|
|
||||||
share_inputs["stop_flags"],
|
|
||||||
share_inputs["seq_lens_encoder"],
|
|
||||||
share_inputs["seq_lens_decoder"],
|
|
||||||
share_inputs[
|
|
||||||
"draft_tokens"
|
|
||||||
], # Both input and output, need to write the last 1 token accepted to position 0.
|
|
||||||
share_inputs["seq_lens_this_time"],
|
|
||||||
verify_tokens,
|
|
||||||
verify_scores,
|
|
||||||
share_inputs["max_dec_len"],
|
|
||||||
sampling_metadata.eos_token_ids,
|
|
||||||
share_inputs["is_block_step"],
|
|
||||||
share_inputs["cu_seqlens_q_output"],
|
|
||||||
actual_candidate_len,
|
|
||||||
share_inputs["actual_draft_token_num"],
|
|
||||||
sampling_metadata.top_p,
|
|
||||||
max_model_len,
|
|
||||||
self.speculative_verify_window,
|
|
||||||
True, # enable_topp
|
|
||||||
(self.speculative_benchmark_mode or reject_all_drafts),
|
|
||||||
accept_all_drafts,
|
|
||||||
)
|
|
||||||
# TODO(chenhuan09): support return logprobs
|
|
||||||
token_ids = share_inputs["accept_tokens"]
|
|
||||||
sampler_output = SamplerOutput(
|
|
||||||
sampled_token_ids=token_ids,
|
|
||||||
logprobs_tensors=None,
|
|
||||||
token_num_per_batch=share_inputs["accept_num"],
|
|
||||||
cu_batch_token_offset=None,
|
|
||||||
)
|
|
||||||
return sampler_output
|
|
||||||
|
|
||||||
|
|
||||||
class MTPSampler(nn.Layer):
|
class MTPSampler(nn.Layer):
|
||||||
|
|||||||
@@ -43,12 +43,11 @@ if current_platform.is_xpu():
|
|||||||
speculate_pre_process,
|
speculate_pre_process,
|
||||||
speculate_save_output,
|
speculate_save_output,
|
||||||
speculate_set_stop_value_multi_seqs,
|
speculate_set_stop_value_multi_seqs,
|
||||||
speculate_set_value_by_flags_and_idx,
|
|
||||||
speculate_step_paddle,
|
speculate_step_paddle,
|
||||||
speculate_step_reschedule,
|
speculate_step_reschedule,
|
||||||
speculate_step_system_cache,
|
speculate_step_system_cache,
|
||||||
speculate_update,
|
|
||||||
step_paddle,
|
step_paddle,
|
||||||
|
unified_update_model_status,
|
||||||
update_inputs,
|
update_inputs,
|
||||||
update_inputs_v1,
|
update_inputs_v1,
|
||||||
)
|
)
|
||||||
@@ -172,6 +171,7 @@ def xpu_pre_process(
|
|||||||
block_tables=share_inputs["block_tables"],
|
block_tables=share_inputs["block_tables"],
|
||||||
caches=share_inputs["caches"],
|
caches=share_inputs["caches"],
|
||||||
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
|
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
|
||||||
|
is_speculative=use_speculate_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
@@ -249,11 +249,6 @@ def xpu_process_output(
|
|||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
""" """
|
""" """
|
||||||
|
|
||||||
if isinstance(share_inputs, dict):
|
|
||||||
output_padding_offset = share_inputs.get("output_padding_offset", None)
|
|
||||||
else:
|
|
||||||
output_padding_offset = getattr(share_inputs, "output_padding_offset", None)
|
|
||||||
|
|
||||||
hidden_states = gather_next_token(
|
hidden_states = gather_next_token(
|
||||||
forward_output,
|
forward_output,
|
||||||
xpu_forward_meta.encoder_seq_lod,
|
xpu_forward_meta.encoder_seq_lod,
|
||||||
@@ -265,7 +260,7 @@ def xpu_process_output(
|
|||||||
xpu_forward_meta.encoder_batch_map_cpu,
|
xpu_forward_meta.encoder_batch_map_cpu,
|
||||||
xpu_forward_meta.decoder_batch_map_cpu,
|
xpu_forward_meta.decoder_batch_map_cpu,
|
||||||
xpu_forward_meta.len_info_cpu,
|
xpu_forward_meta.len_info_cpu,
|
||||||
output_padding_offset, # output_padding_offset
|
xpu_forward_meta.is_speculative,
|
||||||
xpu_forward_meta.max_num_seqs,
|
xpu_forward_meta.max_num_seqs,
|
||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -416,6 +411,8 @@ def xpu_post_process_specualate(
|
|||||||
share_inputs: Dict[str, paddle.Tensor],
|
share_inputs: Dict[str, paddle.Tensor],
|
||||||
save_each_rank: bool = False,
|
save_each_rank: bool = False,
|
||||||
skip_save_output: bool = False,
|
skip_save_output: bool = False,
|
||||||
|
is_naive_mode: bool = False,
|
||||||
|
prefill_one_step_stop: bool = False,
|
||||||
):
|
):
|
||||||
""""""
|
""""""
|
||||||
|
|
||||||
@@ -432,7 +429,7 @@ def xpu_post_process_specualate(
|
|||||||
model_output.min_tokens,
|
model_output.min_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
speculate_update(
|
unified_update_model_status(
|
||||||
model_output.seq_lens_encoder,
|
model_output.seq_lens_encoder,
|
||||||
model_output.seq_lens_decoder,
|
model_output.seq_lens_decoder,
|
||||||
model_output.not_need_stop,
|
model_output.not_need_stop,
|
||||||
@@ -444,6 +441,13 @@ def xpu_post_process_specualate(
|
|||||||
model_output.seq_lens_this_time,
|
model_output.seq_lens_this_time,
|
||||||
model_output.is_block_step,
|
model_output.is_block_step,
|
||||||
model_output.mask_rollback,
|
model_output.mask_rollback,
|
||||||
|
model_output.pre_ids,
|
||||||
|
model_output.prompt_lens,
|
||||||
|
model_output.step_idx,
|
||||||
|
model_output.eos_token_id,
|
||||||
|
model_output.max_dec_len,
|
||||||
|
is_naive_mode,
|
||||||
|
prefill_one_step_stop,
|
||||||
)
|
)
|
||||||
if not skip_save_output:
|
if not skip_save_output:
|
||||||
if sampler_output.logprobs_tensors is None:
|
if sampler_output.logprobs_tensors is None:
|
||||||
@@ -464,18 +468,6 @@ def xpu_post_process_specualate(
|
|||||||
|
|
||||||
speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)
|
speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)
|
||||||
|
|
||||||
# Update pre_ids through accept tokens
|
|
||||||
speculate_set_value_by_flags_and_idx(
|
|
||||||
model_output.pre_ids,
|
|
||||||
model_output.accept_tokens,
|
|
||||||
model_output.accept_num,
|
|
||||||
model_output.stop_flags,
|
|
||||||
model_output.seq_lens_this_time,
|
|
||||||
model_output.seq_lens_encoder,
|
|
||||||
model_output.seq_lens_decoder,
|
|
||||||
model_output.step_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def step_xpu(
|
def step_xpu(
|
||||||
share_inputs: Dict[str, paddle.Tensor],
|
share_inputs: Dict[str, paddle.Tensor],
|
||||||
|
|||||||
@@ -286,20 +286,12 @@ class InputBatch:
|
|||||||
fill_value=max_draft_token_num,
|
fill_value=max_draft_token_num,
|
||||||
dtype="int32",
|
dtype="int32",
|
||||||
)
|
)
|
||||||
if current_platform.is_cuda():
|
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
|
||||||
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
|
self.batch_id_per_token_output = paddle.full(
|
||||||
self.batch_id_per_token_output = paddle.full(
|
shape=[max_num_seqs * (max_draft_token_num + 1)],
|
||||||
shape=[max_num_seqs * (max_draft_token_num + 1)],
|
fill_value=0,
|
||||||
fill_value=0,
|
dtype="int32",
|
||||||
dtype="int32",
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
|
||||||
self.output_padding_offset = paddle.full(
|
|
||||||
shape=[max_num_seqs * (max_draft_token_num + 1)],
|
|
||||||
fill_value=0,
|
|
||||||
dtype="int32",
|
|
||||||
)
|
|
||||||
# For V1_KVCACHE_SCHEDULER
|
# For V1_KVCACHE_SCHEDULER
|
||||||
self.step_draft_tokens = paddle.full(
|
self.step_draft_tokens = paddle.full(
|
||||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||||
@@ -437,7 +429,7 @@ class InputBatch:
|
|||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
swap_data(self.cu_seqlens_q_output, i1, i2)
|
swap_data(self.cu_seqlens_q_output, i1, i2)
|
||||||
else:
|
else:
|
||||||
swap_data(self.output_cum_offsets, i1, i2)
|
swap_data(self.cu_seqlens_q_output, i1, i2)
|
||||||
swap_data(self.step_draft_tokens, i1, i2)
|
swap_data(self.step_draft_tokens, i1, i2)
|
||||||
swap_data(self.step_seq_lens_this_time, i1, i2)
|
swap_data(self.step_seq_lens_this_time, i1, i2)
|
||||||
swap_data(self.draft_logits, i1, i2)
|
swap_data(self.draft_logits, i1, i2)
|
||||||
@@ -628,8 +620,8 @@ class InputBatch:
|
|||||||
fill_paddle_tensor(self, "accept_num", 0)
|
fill_paddle_tensor(self, "accept_num", 0)
|
||||||
fill_paddle_tensor(self, "draft_tokens", -1)
|
fill_paddle_tensor(self, "draft_tokens", -1)
|
||||||
fill_paddle_tensor(self, "actual_draft_token_num", max_draft_token_num)
|
fill_paddle_tensor(self, "actual_draft_token_num", max_draft_token_num)
|
||||||
fill_paddle_tensor(self, "output_cum_offsets", 0)
|
fill_paddle_tensor(self, "cu_seqlens_q_output", 0)
|
||||||
fill_paddle_tensor(self, "output_padding_offset", 0)
|
fill_paddle_tensor(self, "batch_id_per_token_output", 0)
|
||||||
fill_paddle_tensor(self, "step_draft_tokens", 0)
|
fill_paddle_tensor(self, "step_draft_tokens", 0)
|
||||||
fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
|
fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
|
||||||
fill_paddle_tensor(self, "draft_logits", -1)
|
fill_paddle_tensor(self, "draft_logits", -1)
|
||||||
@@ -742,8 +734,8 @@ class ProposerInputBatch(InputBatch):
|
|||||||
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
|
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
|
||||||
self.token_ids_all = None
|
self.token_ids_all = None
|
||||||
else:
|
else:
|
||||||
self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"])
|
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
|
||||||
self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"])
|
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
|
||||||
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
|
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
|
||||||
self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
|
self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
|
||||||
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
|
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
|
||||||
|
|||||||
@@ -135,8 +135,8 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
self.encoder_cache = None
|
self.encoder_cache = None
|
||||||
|
|
||||||
self.device_id = device_id
|
self.device_id = device_id
|
||||||
self.speculative_method = self.fd_config.speculative_config.method
|
self.spec_method = self.fd_config.speculative_config.method
|
||||||
self.speculative_decoding = self.speculative_method is not None
|
self.speculative_decoding = self.spec_method is not None
|
||||||
|
|
||||||
# used by SamplingMetadata
|
# used by SamplingMetadata
|
||||||
self.enable_logprob = fd_config.model_config.enable_logprob # fd_config.model_config.enable_logprob
|
self.enable_logprob = fd_config.model_config.enable_logprob # fd_config.model_config.enable_logprob
|
||||||
@@ -728,7 +728,7 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
if has_prefill_task or has_decode_task:
|
if has_prefill_task or has_decode_task:
|
||||||
self.share_inputs["not_need_stop"][0] = True
|
self.share_inputs["not_need_stop"][0] = True
|
||||||
|
|
||||||
if self.speculative_method == SpecMethod.MTP:
|
if self.spec_method == SpecMethod.MTP:
|
||||||
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
|
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
|
||||||
|
|
||||||
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
|
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
|
||||||
@@ -877,7 +877,7 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
self.share_inputs["not_need_stop"][0] = True
|
self.share_inputs["not_need_stop"][0] = True
|
||||||
|
|
||||||
if self.speculative_method == SpecMethod.MTP:
|
if self.spec_method == SpecMethod.MTP:
|
||||||
self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
|
self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
|
||||||
request, "temp_scaled_logprobs", False
|
request, "temp_scaled_logprobs", False
|
||||||
)
|
)
|
||||||
@@ -1070,12 +1070,18 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
fill_value=max_draft_token_num,
|
fill_value=max_draft_token_num,
|
||||||
dtype="int32",
|
dtype="int32",
|
||||||
)
|
)
|
||||||
self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
self.share_inputs["cu_seqlens_q_output"] = paddle.full(
|
||||||
self.share_inputs["output_padding_offset"] = paddle.full(
|
shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32"
|
||||||
|
)
|
||||||
|
self.share_inputs["batch_id_per_token_output"] = paddle.full(
|
||||||
shape=[max_num_seqs * (max_draft_token_num + 1)],
|
shape=[max_num_seqs * (max_draft_token_num + 1)],
|
||||||
fill_value=0,
|
fill_value=0,
|
||||||
dtype="int32",
|
dtype="int32",
|
||||||
)
|
)
|
||||||
|
# reasoning_status: per-sequence reasoning phase indicator
|
||||||
|
# 0=thinking, 1=emitting boundary, 2=response, 3=end
|
||||||
|
# verify_draft_tokens 在 reasoning_status==1 时强制拒绝所有 draft token
|
||||||
|
self.share_inputs["reasoning_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||||
# For V1_KVCACHE_SCHEDULER
|
# For V1_KVCACHE_SCHEDULER
|
||||||
self.share_inputs["step_draft_tokens"] = paddle.full(
|
self.share_inputs["step_draft_tokens"] = paddle.full(
|
||||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||||
@@ -1439,7 +1445,7 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
block_num=block_num,
|
block_num=block_num,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.speculative_method == SpecMethod.MTP:
|
if self.spec_method == SpecMethod.MTP:
|
||||||
self.proposer.dummy_prefill_inputs(
|
self.proposer.dummy_prefill_inputs(
|
||||||
num_tokens=num_tokens,
|
num_tokens=num_tokens,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@@ -1456,19 +1462,16 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
"""
|
"""
|
||||||
Init speculative proposer
|
Init speculative proposer
|
||||||
"""
|
"""
|
||||||
if self.speculative_method == SpecMethod.NGRAM:
|
if self.spec_method is None:
|
||||||
# xpu not support ngram proposer now
|
|
||||||
self.proposer = None
|
|
||||||
elif self.speculative_method == SpecMethod.MTP:
|
|
||||||
self.proposer = self.speculative_method.create_proposer(
|
|
||||||
self.fd_config,
|
|
||||||
main_model=self.get_model(),
|
|
||||||
local_rank=self.local_rank,
|
|
||||||
device_id=self.device_id,
|
|
||||||
share_inputs=self.share_inputs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.proposer = None
|
self.proposer = None
|
||||||
|
return
|
||||||
|
self.proposer = self.spec_method.create_proposer(
|
||||||
|
self.fd_config,
|
||||||
|
main_model=self.get_model(),
|
||||||
|
local_rank=self.local_rank,
|
||||||
|
device_id=self.device_id,
|
||||||
|
share_inputs=self.share_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
def _set_debug_level(
|
def _set_debug_level(
|
||||||
self, debug_level: int = 0x1, model_forward_batch: Optional[List[Request]] = None, is_dummy_run: bool = False
|
self, debug_level: int = 0x1, model_forward_batch: Optional[List[Request]] = None, is_dummy_run: bool = False
|
||||||
@@ -1670,6 +1673,8 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs,
|
self.share_inputs,
|
||||||
self.parallel_config.data_parallel_size > 1,
|
self.parallel_config.data_parallel_size > 1,
|
||||||
skip_save_output,
|
skip_save_output,
|
||||||
|
is_naive_mode=(self.speculative_decoding and self.proposer is None),
|
||||||
|
prefill_one_step_stop=self.parallel_config.prefill_one_step_stop,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
xpu_post_process_normal(
|
xpu_post_process_normal(
|
||||||
@@ -1685,8 +1690,11 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 6. Draft model propose
|
# 6. Draft model propose
|
||||||
if self.speculative_method == SpecMethod.MTP:
|
if self.speculative_decoding and self.proposer is not None:
|
||||||
self.proposer.run(full_hidden_states=model_output)
|
if self.spec_method == SpecMethod.MTP:
|
||||||
|
self.proposer.run(full_hidden_states=model_output)
|
||||||
|
else:
|
||||||
|
self.proposer.run(share_inputs=self.share_inputs)
|
||||||
|
|
||||||
# 7. Updata 'infer_seed' and step_paddle()
|
# 7. Updata 'infer_seed' and step_paddle()
|
||||||
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
|
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
|
||||||
@@ -1738,7 +1746,7 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
"""Execute a forward pass with dummy inputs to profile the memory usage of the model"""
|
"""Execute a forward pass with dummy inputs to profile the memory usage of the model"""
|
||||||
|
|
||||||
self.num_gpu_blocks = self.cache_config.total_block_num
|
self.num_gpu_blocks = self.cache_config.total_block_num
|
||||||
if self.speculative_method == SpecMethod.MTP:
|
if self.spec_method == SpecMethod.MTP:
|
||||||
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
|
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
|
||||||
self.initialize_kv_cache(profile=True)
|
self.initialize_kv_cache(profile=True)
|
||||||
|
|
||||||
@@ -1760,7 +1768,7 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
self.num_gpu_blocks = num_gpu_blocks
|
self.num_gpu_blocks = num_gpu_blocks
|
||||||
|
|
||||||
# Reset block table and kv cache with global block num
|
# Reset block table and kv cache with global block num
|
||||||
if self.speculative_method == SpecMethod.MTP:
|
if self.spec_method == SpecMethod.MTP:
|
||||||
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
|
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
|
||||||
self.initialize_kv_cache()
|
self.initialize_kv_cache()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user