mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Speculative Decoding with PD (#5856)
* [XPU] Speculative Decoding with PD
* fix post process
* share kv cache sender
* support speculate decoding step system cache
* support speculate decoding step system cache

---------

Co-authored-by: root <root@gajl-bbc-onlinec-com-1512108.gajl.baidu.com>
This commit is contained in:
@@ -18,29 +18,35 @@
|
||||
#include <sys/msg.h>
|
||||
#include <sys/types.h>
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#define MAX_BSZ 256
|
||||
|
||||
// #define SAVE_WITH_OUTPUT_DEBUG
|
||||
#define MAX_DRAFT_TOKENS 6
|
||||
// Message record shaped for a System V message queue: a long type field
// followed by an int payload.
// Payload size is 2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS ints. The writer
// code visible in this file stores a per-slot value at mtext[i + 2] and the
// slot's tokens starting at mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ].
// NOTE(review): this text is a rendered diff; the struct sits right before the
// `#include "speculate_msg.h"` line, so it is presumably the removed side of
// the change -- confirm against the real file.
struct msgdata {
|
||||
long mtype; // NOLINT
|
||||
int mtext[2 + MAX_BSZ +
|
||||
MAX_BSZ * MAX_DRAFT_TOKENS]; // stop_flag, token_num, tokens
|
||||
};
|
||||
#include "speculate_msg.h"
|
||||
|
||||
void MTPSaveFirstToken(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
int64_t rank_id,
|
||||
int msg_queue_id,
|
||||
bool save_each_rank) {
|
||||
bool save_each_rank,
|
||||
bool skip_chunk_prefill) {
|
||||
if (!save_each_rank && rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
int x_dim = x.shape()[1];
|
||||
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
|
||||
int64_t* x_data = x_cpu.data<int64_t>();
|
||||
static struct msgdata msg_sed;
|
||||
|
||||
auto seq_lens_decoder_cpu =
|
||||
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
|
||||
int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
|
||||
|
||||
auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
|
||||
int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
|
||||
|
||||
auto step_idx_cpu = step_idx.copy_to(paddle::CPUPlace(), true);
|
||||
int64_t* step_idx_data = step_idx_cpu.data<int64_t>();
|
||||
|
||||
static struct speculate_msgdata msg_sed;
|
||||
|
||||
if (const char* inference_msg_queue_id_env_p =
|
||||
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
|
||||
@@ -53,6 +59,7 @@ void MTPSaveFirstToken(const paddle::Tensor& x,
|
||||
#endif
|
||||
msg_queue_id = inference_msg_queue_id_from_env;
|
||||
}
|
||||
|
||||
static key_t key = ftok("./", msg_queue_id);
|
||||
static int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
|
||||
@@ -98,11 +105,25 @@ void MTPSaveFirstToken(const paddle::Tensor& x,
|
||||
static_cast<int>(x_data[i * x_dim]),
|
||||
static_cast<int>(x_data[i * x_dim + 1]));
|
||||
#endif
|
||||
msg_sed.mtext[i + 2] = 2;
|
||||
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
|
||||
static_cast<int>(x_data[i * x_dim]);
|
||||
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
|
||||
static_cast<int>(x_data[i * x_dim + 1]);
|
||||
if ((skip_chunk_prefill &&
|
||||
seq_lens_decoder_data[i] < prompt_lens_data[i]) ||
|
||||
step_idx_data[i] == 0) {
|
||||
msg_sed.mtext[i + 2] = 0;
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
printf("bid[%d] skip save mtp output \n", i);
|
||||
#endif
|
||||
continue;
|
||||
} else if (step_idx_data[i] == 1) {
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
printf("bid[%d] save mtp tokens \n", i);
|
||||
#endif
|
||||
msg_sed.mtext[i + 2] = 2;
|
||||
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
|
||||
static_cast<int>(x_data[i * x_dim]);
|
||||
msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
|
||||
static_cast<int>(x_data[i * x_dim + 1]);
|
||||
}
|
||||
|
||||
#ifdef SAVE_WITH_OUTPUT_DEBUG
|
||||
printf("mtext[%d]:%d. mtext[%d]:%d. \n",
|
||||
i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
|
||||
@@ -131,29 +152,60 @@ void MTPSaveFirstToken(const paddle::Tensor& x,
|
||||
|
||||
// Fixed-queue entry point: forwards its arguments to MTPSaveFirstToken and
// passes the literal 1 in the msg_queue_id position.
// NOTE(review): this span is a rendered diff without +/- markers. Two forms
// are interleaved below: one whose parameter list ends at `save_each_rank`
// and makes a 5-argument call, and one that adds `skip_chunk_prefill` and
// makes a 9-argument call that also forwards seq_lens_decoder, prompt_lens
// and step_idx. Presumably the 9-argument form is the post-change code --
// confirm against the real file.
void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
int64_t rank_id,
|
||||
bool save_each_rank) {
|
||||
MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank);
|
||||
bool save_each_rank,
|
||||
bool skip_chunk_prefill) {
|
||||
MTPSaveFirstToken(x,
|
||||
not_need_stop,
|
||||
seq_lens_decoder,
|
||||
prompt_lens,
|
||||
step_idx,
|
||||
rank_id,
|
||||
1,
|
||||
save_each_rank,
|
||||
skip_chunk_prefill);
|
||||
}
|
||||
|
||||
// Caller-selected-queue entry point: forwards its arguments, including the
// caller's msg_queue_id, to MTPSaveFirstToken.
// NOTE(review): this span is a rendered diff without +/- markers. Two forms
// are interleaved below: one whose parameter list ends at `save_each_rank`
// and makes a 5-argument call, and one that adds `skip_chunk_prefill` and
// makes a 9-argument call that also forwards seq_lens_decoder, prompt_lens
// and step_idx. Presumably the 9-argument form is the post-change code --
// confirm against the real file.
void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& prompt_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
int64_t rank_id,
|
||||
int msg_queue_id,
|
||||
bool save_each_rank) {
|
||||
MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
|
||||
bool save_each_rank,
|
||||
bool skip_chunk_prefill) {
|
||||
MTPSaveFirstToken(x,
|
||||
not_need_stop,
|
||||
seq_lens_decoder,
|
||||
prompt_lens,
|
||||
step_idx,
|
||||
rank_id,
|
||||
msg_queue_id,
|
||||
save_each_rank,
|
||||
skip_chunk_prefill);
|
||||
}
|
||||
|
||||
// Registers the custom op `mtp_save_first_token` with MTPSaveFirstTokenStatic
// as its kernel. Output `x_out` is declared and mapped in place onto input
// `x`.
// NOTE(review): this span is a rendered diff without +/- markers, so two
// `.Inputs(...)` / `.Attrs(...)` pairs appear. The second pair (five inputs,
// with the extra `skip_chunk_prefill` attribute) is presumably the
// post-change registration -- confirm against the real file.
PD_BUILD_OP(mtp_save_first_token)
|
||||
.Inputs({"x", "not_need_stop"})
|
||||
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
|
||||
.Inputs(
|
||||
{"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
|
||||
.Attrs({"rank_id: int64_t",
|
||||
"save_each_rank: bool",
|
||||
"skip_chunk_prefill: bool"})
|
||||
.Outputs({"x_out"})
|
||||
.SetInplaceMap({{"x", "x_out"}})
|
||||
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));
|
||||
|
||||
// Registers the custom op `mtp_save_first_token_dynamic` with
// MTPSaveFirstTokenDynamic as its kernel. Output `x_out` is declared and
// mapped in place onto input `x`.
// NOTE(review): this span is a rendered diff without +/- markers, so two
// `.Inputs(...)` / `.Attrs(...)` pairs appear. The second pair (five inputs,
// with the extra `skip_chunk_prefill` attribute) is presumably the
// post-change registration -- confirm against the real file.
PD_BUILD_OP(mtp_save_first_token_dynamic)
|
||||
.Inputs({"x", "not_need_stop"})
|
||||
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
|
||||
.Inputs(
|
||||
{"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
|
||||
.Attrs({"rank_id: int64_t",
|
||||
"msg_queue_id: int",
|
||||
"save_each_rank: bool",
|
||||
"skip_chunk_prefill: bool"})
|
||||
.Outputs({"x_out"})
|
||||
.SetInplaceMap({{"x", "x_out"}})
|
||||
.SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
#include <sys/types.h>
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#define MAX_BSZ 256
|
||||
#define MAX_BSZ 512
|
||||
#define MAX_DRAFT_TOKENS 6
|
||||
|
||||
struct speculate_msgdata {
|
||||
|
||||
Reference in New Issue
Block a user