[Bugfix] Fix MTP logprob hang when the output includes a stop_seq (#5927)
* Fix MTP logprob hang when the output includes a stop_seq
@@ -54,6 +54,10 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
   if (!save_each_rank && rank_id > 0) {
     return;
   }
+
+  int max_draft_tokens = sampled_token_ids.shape()[1];
+  int bsz = token_num_per_batch.shape()[0];
+
   auto sampled_token_ids_cpu =
       sampled_token_ids.copy_to(paddle::CPUPlace(), false);
   auto logprob_token_ids_cpu =
@@ -128,7 +132,6 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
   msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
                                                   : -inference_msg_id_from_env;
   msg_sed.meta[1] = message_flag;
-  int bsz = token_num_per_batch.shape()[0];
   msg_sed.meta[2] = bsz;
   int max_num_logprobs = logprob_token_ids.shape()[1];
   for (int i = 0; i < bsz; i++) {
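Note: the two hunks above are one logical change. The shape reads for max_draft_tokens and bsz are hoisted to the top of SpeculateSaveOutMmsgTopK, just after the early rank return, and the later duplicate bsz declaration is dropped; max_draft_tokens is what the indexing fix in the next hunk needs. A minimal sketch of the layout these reads assume (shapes inferred from the diff, values invented for illustration):

# Hypothetical mock of the two inputs the kernel reads; not FastDeploy code.
bsz, max_draft_tokens = 2, 4      # token_num_per_batch.shape()[0], sampled_token_ids.shape()[1]
sampled_token_ids = [
    [11, 12, 13, -1],             # request 0 accepted 3 draft tokens; -1 = padding
    [21, -1, -1, -1],             # request 1 accepted 1 token
]
token_num_per_batch = [3, 1]      # accepted tokens per request
assert len(sampled_token_ids) == bsz
assert all(len(row) == max_draft_tokens for row in sampled_token_ids)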
@@ -146,7 +149,7 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
       auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
       for (int k = 0; k < K + 1; k++) {
         if (k == 0) {
-          cur_tokens[k] = (int)sampled_token_ids_data[token_offset + j];
+          cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
           cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k];
         } else if (k < max_num_logprobs) {
           cur_tokens[k] =
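Note: this is the core C++ fix. With the Python change below, sampled_token_ids arrives as the padded 2-D accept_tokens tensor, so token (i, j) lives at the row-major offset i * max_draft_tokens + j. The old read at token_offset + j assumed a densely packed buffer; on a padded buffer it drifts as soon as any earlier request accepts fewer than max_draft_tokens tokens, so the consumer can receive padding instead of real tokens (e.g. the stop-sequence token), which presumably is what hung the stream. logprob_scores_data keeps the token_offset + j read, suggesting the scores stay densely packed. A minimal sketch contrasting the two indexings (names and values are illustrative, not from the source):

max_draft_tokens = 4
rows = [[11, 12, 13, -1],               # padded accept_tokens; -1 = padding
        [21, -1, -1, -1]]
flat = [t for row in rows for t in row] # row-major view, padding kept
cu_offsets = [0, 3]                     # cumulative accepted counts (cu_batch_token_offset)

i, j = 1, 0                             # first accepted token of request 1
buggy = flat[cu_offsets[i] + j]         # flat[3] == -1: reads request 0's padding
fixed = flat[i * max_draft_tokens + j]  # flat[4] == 21: the intended token
print(buggy, fixed)                     # -1 21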
@@ -839,15 +839,15 @@ class SpeculativeSampler(nn.Layer):
             raw_logprobs = target_logits.clone()
 
         logprobs_tensors = None
-        token_ids = share_inputs["accept_tokens"]
         if num_logprobs is not None:
+            token_ids = share_inputs["accept_tokens"]
             idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
             mask = idx < share_inputs["accept_num"].unsqueeze(1)
             token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
             logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
 
         sampler_output = SamplerOutput(
-            sampled_token_ids=token_ids,
+            sampled_token_ids=share_inputs["accept_tokens"],
             logprobs_tensors=logprobs_tensors,
             token_num_per_batch=share_inputs["accept_num"],
             cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
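Note: the Python side of the fix. paddle.masked_select flattens the valid tokens into a 1-D tensor, which is the right input for gather_logprobs but the wrong thing to hand to the save kernel: with logprobs enabled, the flattened tensor used to leak into sampled_token_ids, breaking the [bsz, max_draft_tokens] layout the C++ reader indexes. The fix keeps token_ids local to the logprob branch and always passes the padded share_inputs["accept_tokens"] downstream. A small runnable sketch of the shape difference (tensor values invented for illustration):

import paddle

accept_tokens = paddle.to_tensor([[11, 12, 13, -1],
                                  [21, -1, -1, -1]])    # [bsz=2, max_draft=4]
accept_num = paddle.to_tensor([3, 1], dtype="int32")

idx = paddle.arange(accept_tokens.shape[1], dtype="int32")
mask = idx < accept_num.unsqueeze(1)                    # [2, 4] validity mask
token_ids = paddle.masked_select(accept_tokens, mask)   # flattened: shape [4]

print(token_ids.shape)      # [4]    -- fine for gather_logprobs
print(accept_tokens.shape)  # [2, 4] -- what SpeculateSaveOutMmsgTopK expects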