diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
index 4e547d2977..9ef563c62c 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
@@ -54,6 +54,10 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
     if (!save_each_rank && rank_id > 0) {
         return;
     }
+
+    int max_draft_tokens = sampled_token_ids.shape()[1];
+    int bsz = token_num_per_batch.shape()[0];
+
     auto sampled_token_ids_cpu =
         sampled_token_ids.copy_to(paddle::CPUPlace(), false);
     auto logprob_token_ids_cpu =
@@ -128,7 +132,6 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
     msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env
                                               : -inference_msg_id_from_env;
     msg_sed.meta[1] = message_flag;
-    int bsz = token_num_per_batch.shape()[0];
     msg_sed.meta[2] = bsz;
     int max_num_logprobs = logprob_token_ids.shape()[1];
     for (int i = 0; i < bsz; i++) {
@@ -146,7 +149,7 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
             auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
             for (int k = 0; k < K + 1; k++) {
                 if (k == 0) {
-                    cur_tokens[k] = (int)sampled_token_ids_data[token_offset + j];
+                    cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
                     cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k];
                 } else if (k < max_num_logprobs) {
                     cur_tokens[k] =
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 96edebdda5..1cf8d8bc1b 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -839,15 +839,15 @@ class SpeculativeSampler(nn.Layer):
         raw_logprobs = target_logits.clone()
 
         logprobs_tensors = None
-        token_ids = share_inputs["accept_tokens"]
         if num_logprobs is not None:
+            token_ids = share_inputs["accept_tokens"]
             idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
             mask = idx < share_inputs["accept_num"].unsqueeze(1)
             token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
             logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
 
         sampler_output = SamplerOutput(
-            sampled_token_ids=token_ids,
+            sampled_token_ids=share_inputs["accept_tokens"],
             logprobs_tensors=logprobs_tensors,
             token_num_per_batch=share_inputs["accept_num"],
             cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
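
The two hunks are coupled: after the Python change, `sampled_token_ids` passed to the C++ op is the padded `accept_tokens` tensor of shape `[bsz, max_draft_tokens]` rather than a flattened variable-length tensor, so the C++ side must index it row-major as `i * max_draft_tokens + j` instead of using the packed `token_offset`. The logprob buffers remain packed, which is why `token_offset` is still correct for `logprob_scores_data`. Below is a minimal sketch of the two layouts (not part of the patch; tensor names mirror the diff, but the values and the `token_offset` list standing in for `cu_batch_token_offset` are made up for illustration):

```python
import paddle

bsz, max_draft_tokens = 2, 4
accept_tokens = paddle.to_tensor([[11, 12, 0, 0],    # batch 0: 2 accepted tokens
                                  [21, 22, 23, 0]])  # batch 1: 3 accepted tokens
accept_num = paddle.to_tensor([2, 3])

# Flattened view used only for gather_logprobs, as in the Python hunk:
idx = paddle.arange(max_draft_tokens, dtype="int32")
mask = idx < accept_num.unsqueeze(1)
flat = paddle.masked_select(accept_tokens, mask)  # [11, 12, 21, 22, 23]

# Packed per-batch offsets (illustrative stand-in for cu_batch_token_offset):
token_offset = [0, 2]

# What the C++ side sees after copy_to: the padded buffer, row-major.
data = accept_tokens.flatten()
for i in range(bsz):
    for j in range(int(accept_num[i])):
        # Padded index (new code) and packed index (old code) must agree
        # on the accepted tokens; the old packed index would be wrong here
        # because `data` contains the padding as well.
        assert int(data[i * max_draft_tokens + j]) == int(flat[token_offset[i] + j])
```

This also explains the hoisting of `max_draft_tokens` and `bsz` to the top of `SpeculateSaveOutMmsgTopK`, and why `SamplerOutput.sampled_token_ids` switches back to the padded `share_inputs["accept_tokens"]` while the masked-select result is kept only for logprob gathering.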