mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Bugfix] Fix mtp logprob hang problem when include stop_seq (#5927)
* fix mtp logprob hang when include stop_seq
This commit is contained in:
@@ -54,6 +54,10 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
|
||||
if (!save_each_rank && rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int max_draft_tokens = sampled_token_ids.shape()[1];
|
||||
int bsz = token_num_per_batch.shape()[0];
|
||||
|
||||
auto sampled_token_ids_cpu =
|
||||
sampled_token_ids.copy_to(paddle::CPUPlace(), false);
|
||||
auto logprob_token_ids_cpu =
|
||||
@@ -128,7 +132,6 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
|
||||
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
|
||||
: -inference_msg_id_from_env;
|
||||
msg_sed.meta[1] = message_flag;
|
||||
int bsz = token_num_per_batch.shape()[0];
|
||||
msg_sed.meta[2] = bsz;
|
||||
int max_num_logprobs = logprob_token_ids.shape()[1];
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
@@ -146,7 +149,7 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
|
||||
auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
|
||||
for (int k = 0; k < K + 1; k++) {
|
||||
if (k == 0) {
|
||||
cur_tokens[k] = (int)sampled_token_ids_data[token_offset + j];
|
||||
cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
|
||||
cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k];
|
||||
} else if (k < max_num_logprobs) {
|
||||
cur_tokens[k] =
|
||||
|
||||
Reference in New Issue
Block a user