From d05f5f087732ae28d0323cf608dedd53d766cb02 Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Thu, 8 Jan 2026 14:21:33 +0800 Subject: [PATCH] [Cherry-Pick][Bugfix] Fix mtp logprob hang problem when stop_seq is included (#5927) (#5928) * fix mtp logprob hang when stop_seq is included --- .../speculate_decoding/speculate_save_output_with_topk.cc | 7 +++++-- fastdeploy/model_executor/layers/sample/sampler.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc index 4e547d2977..9ef563c62c 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -54,6 +54,10 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, if (!save_each_rank && rank_id > 0) { return; } + + int max_draft_tokens = sampled_token_ids.shape()[1]; + int bsz = token_num_per_batch.shape()[0]; + auto sampled_token_ids_cpu = sampled_token_ids.copy_to(paddle::CPUPlace(), false); auto logprob_token_ids_cpu = @@ -128,7 +132,6 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, msg_sed.meta[0] = not_need_stop.data()[0] ? 
inference_msg_id_from_env : -inference_msg_id_from_env; msg_sed.meta[1] = message_flag; - int bsz = token_num_per_batch.shape()[0]; msg_sed.meta[2] = bsz; int max_num_logprobs = logprob_token_ids.shape()[1]; for (int i = 0; i < bsz; i++) { @@ -146,7 +149,7 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; for (int k = 0; k < K + 1; k++) { if (k == 0) { - cur_tokens[k] = (int)sampled_token_ids_data[token_offset + j]; + cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k]; } else if (k < max_num_logprobs) { cur_tokens[k] = diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index afc8b725ce..84abe02d45 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -839,15 +839,15 @@ class SpeculativeSampler(nn.Layer): raw_logprobs = target_logits.clone() logprobs_tensors = None - token_ids = share_inputs["accept_tokens"] if num_logprobs is not None: + token_ids = share_inputs["accept_tokens"] idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32") mask = idx < share_inputs["accept_num"].unsqueeze(1) token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) sampler_output = SamplerOutput( - sampled_token_ids=token_ids, + sampled_token_ids=share_inputs["accept_tokens"], logprobs_tensors=logprobs_tensors, token_num_per_batch=share_inputs["accept_num"], cu_batch_token_offset=share_inputs["cu_batch_token_offset"],