mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
* fix mtp logprob hang when include stop_seq
This commit is contained in:
@@ -54,6 +54,10 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
|
||||
if (!save_each_rank && rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int max_draft_tokens = sampled_token_ids.shape()[1];
|
||||
int bsz = token_num_per_batch.shape()[0];
|
||||
|
||||
auto sampled_token_ids_cpu =
|
||||
sampled_token_ids.copy_to(paddle::CPUPlace(), false);
|
||||
auto logprob_token_ids_cpu =
|
||||
@@ -128,7 +132,6 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
|
||||
msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
|
||||
: -inference_msg_id_from_env;
|
||||
msg_sed.meta[1] = message_flag;
|
||||
int bsz = token_num_per_batch.shape()[0];
|
||||
msg_sed.meta[2] = bsz;
|
||||
int max_num_logprobs = logprob_token_ids.shape()[1];
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
@@ -146,7 +149,7 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
|
||||
auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
|
||||
for (int k = 0; k < K + 1; k++) {
|
||||
if (k == 0) {
|
||||
cur_tokens[k] = (int)sampled_token_ids_data[token_offset + j];
|
||||
cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
|
||||
cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k];
|
||||
} else if (k < max_num_logprobs) {
|
||||
cur_tokens[k] =
|
||||
|
||||
@@ -839,15 +839,15 @@ class SpeculativeSampler(nn.Layer):
|
||||
raw_logprobs = target_logits.clone()
|
||||
|
||||
logprobs_tensors = None
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
if num_logprobs is not None:
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
|
||||
mask = idx < share_inputs["accept_num"].unsqueeze(1)
|
||||
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
|
||||
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=token_ids,
|
||||
sampled_token_ids=share_inputs["accept_tokens"],
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
token_num_per_batch=share_inputs["accept_num"],
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
|
||||
Reference in New Issue
Block a user