[Bugfix] Fix MTP logprob hang problem when stop_seq is included (#5927)

* fix MTP logprob hang when stop_seq is included
This commit is contained in:
GoldPancake
2026-01-08 14:21:24 +08:00
committed by GitHub
parent dc170e3005
commit a1fc4e249e
2 changed files with 7 additions and 4 deletions
@@ -54,6 +54,10 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
if (!save_each_rank && rank_id > 0) {
return;
}
int max_draft_tokens = sampled_token_ids.shape()[1];
int bsz = token_num_per_batch.shape()[0];
auto sampled_token_ids_cpu =
sampled_token_ids.copy_to(paddle::CPUPlace(), false);
auto logprob_token_ids_cpu =
@@ -128,7 +132,6 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
: -inference_msg_id_from_env;
msg_sed.meta[1] = message_flag;
int bsz = token_num_per_batch.shape()[0];
msg_sed.meta[2] = bsz;
int max_num_logprobs = logprob_token_ids.shape()[1];
for (int i = 0; i < bsz; i++) {
@@ -146,7 +149,7 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
for (int k = 0; k < K + 1; k++) {
if (k == 0) {
cur_tokens[k] = (int)sampled_token_ids_data[token_offset + j];
cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k];
} else if (k < max_num_logprobs) {
cur_tokens[k] =