[XPU] Speculative Decoding with PD (#5856)

* [XPU] Speculative Decoding with PD

* fix post process

* share kv cache sender

* support speculate decoding step system cache

* support speculate decoding step system cache

---------

Co-authored-by: root <root@gajl-bbc-onlinec-com-1512108.gajl.baidu.com>
This commit is contained in:
cmcamdy
2026-01-05 17:31:03 +08:00
committed by GitHub
parent ac39c0f887
commit 690d4bcdb0
5 changed files with 193 additions and 74 deletions
@@ -18,29 +18,35 @@
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#define MAX_BSZ 256
// #define SAVE_WITH_OUTPUT_DEBUG
#define MAX_DRAFT_TOKENS 6
// Payload struct for the System V message queue (see the ftok/msgget calls
// below) used to hand first-draft tokens to a reader process.
// Layout of mtext, per the indexing used later in this file
// (mtext[i + 2] and mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ]):
//   [0..1]               header -- stop_flag, token_num
//   [2 .. 2+MAX_BSZ)     per-request token count (one int per batch slot)
//   [2+MAX_BSZ .. end)   up to MAX_DRAFT_TOKENS draft tokens per batch slot
// NOTE(review): this is the pre-change definition; this commit replaces it
// with speculate_msgdata from "speculate_msg.h".
struct msgdata {
long mtype; // NOLINT -- SysV mandates the first member be a long message-type tag
int mtext[2 + MAX_BSZ +
MAX_BSZ * MAX_DRAFT_TOKENS]; // stop_flag, token_num, tokens
};
#include "speculate_msg.h"
void MTPSaveFirstToken(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
bool save_each_rank,
bool skip_chunk_prefill) {
if (!save_each_rank && rank_id > 0) {
return;
}
int x_dim = x.shape()[1];
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t* x_data = x_cpu.data<int64_t>();
static struct msgdata msg_sed;
auto seq_lens_decoder_cpu =
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
auto step_idx_cpu = step_idx.copy_to(paddle::CPUPlace(), true);
int64_t* step_idx_data = step_idx_cpu.data<int64_t>();
static struct speculate_msgdata msg_sed;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
@@ -53,6 +59,7 @@ void MTPSaveFirstToken(const paddle::Tensor& x,
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("./", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
@@ -98,11 +105,25 @@ void MTPSaveFirstToken(const paddle::Tensor& x,
static_cast<int>(x_data[i * x_dim]),
static_cast<int>(x_data[i * x_dim + 1]));
#endif
msg_sed.mtext[i + 2] = 2;
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
static_cast<int>(x_data[i * x_dim]);
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
static_cast<int>(x_data[i * x_dim + 1]);
if ((skip_chunk_prefill &&
seq_lens_decoder_data[i] < prompt_lens_data[i]) ||
step_idx_data[i] == 0) {
msg_sed.mtext[i + 2] = 0;
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("bid[%d] skip save mtp output \n", i);
#endif
continue;
} else if (step_idx_data[i] == 1) {
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("bid[%d] save mtp tokens \n", i);
#endif
msg_sed.mtext[i + 2] = 2;
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
static_cast<int>(x_data[i * x_dim]);
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
static_cast<int>(x_data[i * x_dim + 1]);
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
printf("mtext[%d]:%d. mtext[%d]:%d. \n",
i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
@@ -131,29 +152,60 @@ void MTPSaveFirstToken(const paddle::Tensor& x,
void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
int64_t rank_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank);
bool save_each_rank,
bool skip_chunk_prefill) {
MTPSaveFirstToken(x,
not_need_stop,
seq_lens_decoder,
prompt_lens,
step_idx,
rank_id,
1,
save_each_rank,
skip_chunk_prefill);
}
void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
bool save_each_rank,
bool skip_chunk_prefill) {
MTPSaveFirstToken(x,
not_need_stop,
seq_lens_decoder,
prompt_lens,
step_idx,
rank_id,
msg_queue_id,
save_each_rank,
skip_chunk_prefill);
}
PD_BUILD_OP(mtp_save_first_token)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
.Inputs(
{"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
.Attrs({"rank_id: int64_t",
"save_each_rank: bool",
"skip_chunk_prefill: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));
PD_BUILD_OP(mtp_save_first_token_dynamic)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Inputs(
{"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
.Attrs({"rank_id: int64_t",
"msg_queue_id: int",
"save_each_rank: bool",
"skip_chunk_prefill: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));
@@ -21,7 +21,7 @@
#include <sys/types.h>
#include "paddle/extension.h"
#define MAX_BSZ 256
#define MAX_BSZ 512
#define MAX_DRAFT_TOKENS 6
struct speculate_msgdata {