From 650d1e49aaf2f6fedf7354655e11480fe7aeba6d Mon Sep 17 00:00:00 2001 From: GoldPancake <56388518+Deleter-D@users.noreply.github.com> Date: Fri, 17 Apr 2026 21:37:42 +0800 Subject: [PATCH] [Cherry-Pick][Speculative Decoding] Add MTP logprob support for PD disaggregation (#7442) (#7464) * support mtp logprob in pd * fix * fix * fix * fix xpu bugs --- .../mtp_save_first_token_with_topk.cc | 218 ++++++++++++++++++ .../speculate_get_output_with_topk.cc | 45 ++-- .../speculate_logprob_msg.h | 39 ++++ .../speculate_save_output_with_topk.cc | 44 ++-- .../model_executor/pre_and_post_process.py | 75 +++++- fastdeploy/spec_decode/mtp.py | 38 +-- fastdeploy/worker/gpu_model_runner.py | 8 +- 7 files changed, 389 insertions(+), 78 deletions(-) create mode 100644 custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc create mode 100644 custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc new file mode 100644 index 0000000000..02203a51cf --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc @@ -0,0 +1,218 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
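+
+// Writes the base model's first sampled token for each request, together
+// with its top-K logprobs, into the SysV message queue shared with
+// SpeculateGetOutMmsgTopK, so the prefill instance in a PD-disaggregated
+// deployment can return the first token (and its logprobs) to the client.
+// The queue message layout is defined in ../speculate_logprob_msg.h.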
+
+#include <sys/ipc.h>
+#include <sys/msg.h>
+#include <sys/types.h>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+#include "paddle/extension.h"
+#include "../../custom_ftok.h"
+#include "../speculate_logprob_msg.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids,
+                               const paddle::Tensor& logprob_token_ids,
+                               const paddle::Tensor& logprob_scores,
+                               const paddle::Tensor& logprob_ranks,
+                               const paddle::Tensor& token_num_per_batch,
+                               const paddle::Tensor& cu_batch_token_offset,
+                               const paddle::Tensor& not_need_stop,
+                               const paddle::Tensor& seq_lens_decoder,
+                               const paddle::Tensor& prompt_lens,
+                               const paddle::Tensor& preempted_idx,
+                               int message_flag,  // Target: 3, Draft: 4
+                               int64_t rank_id,
+                               bool save_each_rank) {
+  if (!save_each_rank && rank_id > 0) {
+    return;
+  }
+
+  int max_draft_tokens = sampled_token_ids.shape()[1];
+  int bsz = token_num_per_batch.shape()[0];
+
+  auto sampled_token_ids_cpu =
+      sampled_token_ids.copy_to(paddle::CPUPlace(), false);
+  auto logprob_token_ids_cpu =
+      logprob_token_ids.copy_to(paddle::CPUPlace(), false);
+  auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
+  auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false);
+  auto token_num_per_batch_cpu =
+      token_num_per_batch.copy_to(paddle::CPUPlace(), false);
+  auto cu_batch_token_offset_cpu =
+      cu_batch_token_offset.copy_to(paddle::CPUPlace(), false);
+  auto seq_lens_decoder_cpu =
+      seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
+  auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
+  int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data<int64_t>();
+  int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
+  float* logprob_scores_data = logprob_scores_cpu.data<float>();
+  int64_t* logprob_ranks_data = logprob_ranks_cpu.data<int64_t>();
+  int* token_num_per_batch_data = token_num_per_batch_cpu.data<int>();
+  int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data<int>();
+  int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
+  int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
+  const int32_t* preempted_idx_data = preempted_idx.data<int32_t>();
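+
+  // Queue selection mirrors the other speculative save/get ops: the queue id
+  // comes from INFERENCE_MSG_QUEUE_ID and the stop-flag value from
+  // INFERENCE_MSG_ID (both optional, defaulting to 1).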
+  static struct msgdata msg_sed;
+  int msg_queue_id = 1;
+  if (const char* inference_msg_queue_id_env_p =
+          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
+    int inference_msg_queue_id_from_env =
+        std::stoi(inference_msg_queue_id_env_str);
+    msg_queue_id = inference_msg_queue_id_from_env;
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
+              << inference_msg_queue_id_from_env << std::endl;
+#endif
+  } else {
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Failed to get INFERENCE_MSG_QUEUE_ID from env, using default."
+              << std::endl;
+#endif
+  }
+  int inference_msg_id_from_env = 1;
+  if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
+    std::string inference_msg_id_env_str(inference_msg_id_env_p);
+    inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
+    if (inference_msg_id_from_env == 2) {
+      // 2 and -2 are reserved for the no-output indication.
+      throw std::runtime_error(
+          " INFERENCE_MSG_ID cannot be 2, please use another number.");
+    }
+    if (inference_msg_id_from_env < 0) {
+      throw std::runtime_error(
+          " INFERENCE_MSG_ID cannot be negative, please use another "
+          "number.");
+    }
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
+              << std::endl;
+#endif
+  } else {
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as "
+                 "default."
+              << std::endl;
+#endif
+  }
+  static key_t key = custom_ftok("/dev/shm", msg_queue_id);
+  static int msgid = msgget(key, IPC_CREAT | 0666);
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+  std::cout << "save_output_key: " << key << std::endl;
+  std::cout << "save msgid: " << msgid << std::endl;
+#endif
+  msg_sed.mtype = 1;
+  msg_sed.meta[0] = not_need_stop.data<bool>()[0] ? inference_msg_id_from_env
+                                                  : -inference_msg_id_from_env;
+  msg_sed.meta[1] = message_flag;
+  msg_sed.meta[2] = bsz;
+  int max_num_logprobs = logprob_token_ids.shape()[1];
+  for (int i = 0; i < bsz; i++) {
+    int cur_token_num;
+    if (seq_lens_decoder_data[i] < prompt_lens_data[i] ||
+        token_num_per_batch_data[i] == 0) {
+      // chunk prefill or stopped slots
+      cur_token_num = 0;
+    } else {
+      cur_token_num = token_num_per_batch_data[i] + 1;
+    }
+    msg_sed.meta[3 + i] = cur_token_num;
+    if (preempted_idx_data[i] == 1) {
+      msg_sed.meta[3 + i] = -9;
+    }
+
+    auto* cur_batch_msg_sed = &msg_sed.mtext[i];
+    int token_offset = cu_batch_token_offset_data[i];
+    for (int j = 0; j < cur_token_num; j++) {
+      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
+      auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
+      if (j == 0) {
+        // The first token carries the full top-K logprobs.
+        for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
+          if (k == 0) {
+            cur_tokens[k] =
+                (int)sampled_token_ids_data[i * max_draft_tokens + j];
+            cur_scores[k] =
+                logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
+                                    k];
+          } else if (k < max_num_logprobs) {
+            // only for the first token
+            cur_tokens[k] =
+                (int)logprob_token_ids_data[(token_offset + j) *
+                                                (SPEC_LOGPROB_K + 1) +
+                                            k];
+            cur_scores[k] =
+                logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
+                                    k];
+          } else {
+            cur_tokens[k] = -1;
+            cur_scores[k] = 0.0;
+          }
+        }
+        cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
+      } else {
+        // Draft tokens carry only the token id.
+        cur_tokens[0] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
+      }
+    }
+  }
+#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+  std::cout << "msg data: " << std::endl;
+  std::cout << "stop_flag: " << msg_sed.meta[0]
+            << ", message_flag: " << msg_sed.meta[1]
+            << ", bsz: " << msg_sed.meta[2] << std::endl;
+  for (int i = 0; i < bsz; i++) {
+    int cur_token_num = msg_sed.meta[3 + i];
+    auto* cur_batch_msg_sed = &msg_sed.mtext[i];
+    std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl;
+    for (int j = 0; j < cur_token_num; j++) {
+      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
+      auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
+      std::cout << "tokens: ";
+      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
+        std::cout << cur_tokens[k] << " ";
+      }
+      std::cout << std::endl;
+      std::cout << "scores: ";
+      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
+        std::cout << cur_scores[k] << " ";
+      }
+      std::cout << std::endl;
+      std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl;
+    }
+  }
+  std::cout << std::endl;
+#endif
+  if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) {
+    printf("full msg buffer\n");
+  }
+}
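+
+// Registered as a static paddle op; on the Python side it is exposed as
+// fastdeploy.model_executor.ops.gpu.mtp_save_first_token_with_topk
+// (see the import added in pre_and_post_process.py below).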
+
+PD_BUILD_STATIC_OP(mtp_save_first_token_with_topk)
+    .Inputs({"sampled_token_ids",
+             "logprob_token_ids",
+             "logprob_scores",
+             "logprob_ranks",
+             "token_num_per_batch",
+             "cu_batch_token_offset",
+             "not_need_stop",
+             "seq_lens_decoder",
+             "prompt_lens",
+             "preempted_idx"})
+    .Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"})
+    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenWithTopK));
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc
index 76ff5e190d..4fd7d4103c 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc
@@ -19,27 +19,12 @@
 #include <sys/types.h>
 #include "paddle/extension.h"
 #include "../custom_ftok.h"
+#include "speculate_logprob_msg.h"
 
 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif
 
-#define MAX_BSZ 512
-#define K 20
-#define MAX_DRAFT_TOKEN_NUM 6
-
-struct batch_msgdata {
-  int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
-  float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
-  int ranks[MAX_DRAFT_TOKEN_NUM];
-};
-
-struct msgdata {
-  long mtype;
-  int meta[3 + MAX_BSZ];  // stop_flag, message_flag, bsz, batch_token_nums
-  batch_msgdata mtext[MAX_BSZ];
-};
-
 void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
                              const paddle::Tensor& output_scores,
                              const paddle::Tensor& output_ranks,
@@ -93,22 +78,22 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
   output_tokens_data[1] = (int64_t)msg_rcv.meta[1];
   output_tokens_data[2] = (int64_t)msg_rcv.meta[2];
 
-  int output_tokens_offset = 3 + MAX_BSZ;
+  int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ;
   for (int i = 0; i < bsz; i++) {
     int cur_token_num = msg_rcv.meta[3 + i];
     output_tokens_data[3 + i] = (int64_t)cur_token_num;  // batch_token_nums
     auto* cur_output_token = output_tokens_data + output_tokens_offset +
-                             i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
+                             i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
     auto* cur_output_score =
-        output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1));
+        output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1));
     auto* cur_batch_msg_rcv = &msg_rcv.mtext[i];
     for (int j = 0; j < cur_token_num; j++) {
       for (int k = 0; k < real_k + 1; k++) {
-        cur_output_token[j * (K + 1) + k] =
-            (int64_t)cur_batch_msg_rcv->tokens[j * (K + 1) + k];
-        cur_output_score[j * (K + 1) + k] =
-            cur_batch_msg_rcv->scores[j * (K + 1) + k];
+        cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] =
+            (int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k];
+        cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] =
+            cur_batch_msg_rcv->scores[j * (SPEC_LOGPROB_K + 1) + k];
       }
       output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] =
           (int64_t)cur_batch_msg_rcv->ranks[j];
@@ -124,17 +109,19 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens,
     std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl;
     for (int j = 0; j < cur_token_num; j++) {
       std::cout << "tokens: ";
-      for (int k = 0; k < K + 1; k++) {
+      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
         std::cout << output_tokens_data[output_tokens_offset +
-                                        i * MAX_DRAFT_TOKEN_NUM * (K + 1) +
-                                        j * (K + 1) + k]
+                                        i * MAX_DRAFT_TOKEN_NUM *
+                                            (SPEC_LOGPROB_K + 1) +
+                                        j * (SPEC_LOGPROB_K + 1) + k]
                   << " ";
       }
       std::cout << std::endl;
       std::cout << "scores: ";
-      for (int k = 0; k < K + 1; k++) {
-        std::cout << output_scores_data[i * MAX_DRAFT_TOKEN_NUM * (K + 1) +
-                                        j * (K + 1) + k]
+      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
+        std::cout << output_scores_data[i * MAX_DRAFT_TOKEN_NUM *
+                                            (SPEC_LOGPROB_K + 1) +
+                                        j * (SPEC_LOGPROB_K + 1) + k]
                   << " ";
       }
       std::cout << std::endl;
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h
new file mode 100644
index 0000000000..dc2c6f399f
--- /dev/null
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <sys/ipc.h>
+#include <sys/msg.h>
+#include <sys/types.h>
+#include <cstdio>
+#include <cstdlib>
+#include "paddle/extension.h"
+
+#define SPEC_LOGPROB_MAX_BSZ 512
+#define SPEC_LOGPROB_K 20
+#define MAX_DRAFT_TOKEN_NUM 6
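+
+// Shared message layout for the speculative-decoding logprob path. Both
+// speculate_save_output_with_topk.cc and speculate_get_output_with_topk.cc
+// previously duplicated these definitions; keeping them in one header
+// guarantees the sender and receiver agree on the buffer sizes. Each of the
+// up-to SPEC_LOGPROB_MAX_BSZ batch slots holds at most MAX_DRAFT_TOKEN_NUM
+// tokens, each with SPEC_LOGPROB_K + 1 (token, score) pairs plus a rank.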
+
+struct batch_msgdata {
+  int tokens[MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)];
+  float scores[MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)];
+  int ranks[MAX_DRAFT_TOKEN_NUM];
+};
+
+struct msgdata {
+  long mtype;
+  // stop_flag, message_flag, bsz, batch_token_nums
+  int meta[3 + SPEC_LOGPROB_MAX_BSZ];
+  batch_msgdata mtext[SPEC_LOGPROB_MAX_BSZ];
+};
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
index 3d75886bd2..0b3de384ce 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc
@@ -19,27 +19,12 @@
 #include <sys/types.h>
 #include "paddle/extension.h"
 #include "../custom_ftok.h"
+#include "speculate_logprob_msg.h"
 
 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif
 
-#define MAX_BSZ 512
-#define K 20
-#define MAX_DRAFT_TOKEN_NUM 6
-
-struct batch_msgdata {
-  int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
-  float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)];
-  int ranks[MAX_DRAFT_TOKEN_NUM];
-};
-
-struct msgdata {
-  long mtype;
-  int meta[3 + MAX_BSZ];  // stop_flag, message_flag, bsz, batch_token_nums
-  batch_msgdata mtext[MAX_BSZ];
-};
-
 void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
                               const paddle::Tensor& logprob_token_ids,
                               const paddle::Tensor& logprob_scores,
@@ -154,16 +139,21 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
     auto* cur_batch_msg_sed = &msg_sed.mtext[i];
     int token_offset = cu_batch_token_offset_data[i];
     for (int j = 0; j < cur_token_num; j++) {
-      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
-      auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
-      for (int k = 0; k < K + 1; k++) {
+      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
+      auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
+      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
         if (k == 0) {
           cur_tokens[k] =
               (int)sampled_token_ids_data[i * max_draft_tokens + j];
-          cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k];
+          cur_scores[k] =
+              logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
+                                  k];
         } else if (k < max_num_logprobs) {
-          cur_tokens[k] =
-              (int)logprob_token_ids_data[(token_offset + j) * (K + 1) + k];
-          cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k];
+          cur_tokens[k] = (int)
+              logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
+                                     k];
+          cur_scores[k] =
+              logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) +
+                                  k];
         } else {
           cur_tokens[k] = -1;
           cur_scores[k] = 0.0;
@@ -182,15 +172,15 @@
     auto* cur_batch_msg_sed = &msg_sed.mtext[i];
     std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl;
     for (int j = 0; j < cur_token_num; j++) {
-      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
-      auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
+      auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
+      auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
       std::cout << "tokens: ";
-      for (int k = 0; k < K + 1; k++) {
+      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
         std::cout << cur_tokens[k] << " ";
       }
       std::cout << std::endl;
       std::cout << "scores: ";
-      for (int k = 0; k < K + 1; k++) {
+      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
         std::cout << cur_scores[k] << " ";
       }
       std::cout << std::endl;
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index 0fc6bfde5d..29fc423538 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -22,9 +22,14 @@ import paddle
 
 from fastdeploy import envs
 from fastdeploy.config import SpeculativeConfig
+from fastdeploy.model_executor.ops.gpu import (
+    mtp_save_first_token,
+    mtp_save_first_token_with_topk,
+)
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.input_batch import (
     InputBatch,
+    ProposerInputBatch,
     recover_batch_index_for_output,
     recover_batch_index_for_sampler_output,
 )
@@ -525,10 +530,76 @@ def save_output_specualate(
     sampler_output: SamplerOutput,
     model_output: ModelOutputData,
     share_inputs: InputBatch,
+    proposer_share_inputs: ProposerInputBatch,
+    local_rank: int,
+    tensor_parallel_rank: int,
     save_each_rank: bool = False,
-    skip_save_output: bool = False,
+    is_mtp_prefill: bool = False,
 ):
-    if not skip_save_output:
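+    # On an MTP prefill instance in PD disaggregation, the first generated
+    # token is pushed to the msg queue here: mtp_save_first_token when no
+    # logprobs were requested, mtp_save_first_token_with_topk otherwise.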
"last_preempted_idx", + ], + ) + recover_batch_index_for_sampler_output( + sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder + ) + recover_proposer_share_inputs_map = recover_batch_index_for_output( + proposer_share_inputs, + proposer_share_inputs.index_to_batch_id, + proposer_share_inputs.enable_pd_reorder, + ["base_model_draft_tokens"], + ) + mtp_save_first_token_with_topk( + recover_proposer_share_inputs_map["base_model_draft_tokens"], + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + recover_share_inputs_map["accept_num_cpu"], + sampler_output.cu_batch_token_offset, + model_output.not_need_stop, + recover_share_inputs_map["seq_lens_decoder_cpu"], + recover_share_inputs_map["prompt_lens_cpu"], + recover_share_inputs_map["last_preempted_idx"], + 3, # mtype + model_output.mp_rank, + save_each_rank, + ) + else: if sampler_output.logprobs_tensors is None: recover_share_inputs = recover_batch_index_for_output( share_inputs, diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 9ca5f535ab..acf7bee27a 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -62,7 +62,6 @@ else: eagle_get_self_hidden_states, eagle_gather_hidden_states, hybrid_mtp_ngram, - mtp_save_first_token, mtp_step_paddle, share_external_data, speculate_get_logits, @@ -835,23 +834,26 @@ class MTPProposer(Proposer): ) if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0: - skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) - recover_model_output_map = recover_batch_index_for_output( - self.model_inputs, - self.model_inputs.index_to_batch_id, - self.model_inputs.enable_pd_reorder, - ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], - ) - mtp_save_first_token( - recover_model_output_map["base_model_draft_tokens"], - self.model_inputs["not_need_stop"], - recover_model_output_map["seq_lens_decoder"], - recover_model_output_map["prompt_lens"], - recover_model_output_map["step_idx"], - self.local_rank, - self.parallel_config.use_ep, - skip_save, - ) + if current_platform.is_xpu(): + # Note(wangyanpeng): mtp_save_first_token for GPU platforms has been moved to model_runner. + # Only XPU platform is retained here. + skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) + recover_model_output_map = recover_batch_index_for_output( + self.model_inputs, + self.model_inputs.index_to_batch_id, + self.model_inputs.enable_pd_reorder, + ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], + ) + mtp_save_first_token( + recover_model_output_map["base_model_draft_tokens"], + self.model_inputs["not_need_stop"], + recover_model_output_map["seq_lens_decoder"], + recover_model_output_map["prompt_lens"], + recover_model_output_map["step_idx"], + self.local_rank, + self.parallel_config.use_ep, + skip_save, + ) # Ensure only save first token once. 
paddle.assign( paddle.where( diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 2bdbdb345b..43478e1a81 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2478,13 +2478,17 @@ class GPUModelRunner(ModelRunnerBase): sampler_output, ): if self.speculative_decoding: - skip_save_output = self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" save_output_specualate( sampler_output=sampler_output, model_output=model_output_data, share_inputs=self.share_inputs, + proposer_share_inputs=self.proposer.model_inputs, + local_rank=self.local_rank, + tensor_parallel_rank=self.parallel_config.tensor_parallel_rank, save_each_rank=self.parallel_config.use_ep, - skip_save_output=skip_save_output, + is_mtp_prefill=( + self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" + ), ) else: save_output_normal(