[XPU] Support EP4TP1 in pd disaggregation (#5860)

Co-authored-by: ddchenhao66 <dhaochen@163.com>
This commit is contained in:
ddchenhao66
2026-01-06 15:25:36 +08:00
committed by GitHub
parent e99ec4c9d5
commit 733014bf32
4 changed files with 355 additions and 14 deletions
+18 -12
View File
@@ -20,11 +20,11 @@
#include "msg_utils.h"
#include "paddle/extension.h"
void GetOutputKVSignal(const paddle::Tensor &x,
void GetOutputKVSignal(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
int msg_queue_id = 1024;
if (const char *msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
if (const char* msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string msg_que_str(msg_que_str_tmp);
msg_queue_id = std::stoi(msg_que_str);
}
@@ -33,7 +33,7 @@ void GetOutputKVSignal(const paddle::Tensor &x,
static key_t key = ftok("/opt/", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
int *out_data = const_cast<int *>(x.data<int>());
int* out_data = const_cast<int*>(x.data<int>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
@@ -53,15 +53,12 @@ void GetOutputKVSignal(const paddle::Tensor &x,
return;
}
void GetOutput(const paddle::Tensor &x,
void GetOutput(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
if (rank_id > 0) {
return;
}
static struct msgdata msg_rcv;
if (const char *inference_msg_queue_id_env_p =
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
@@ -82,7 +79,7 @@ void GetOutput(const paddle::Tensor &x,
std::cout << "get_output wait_flag: " << wait_flag << std::endl;
#endif
int64_t *out_data = const_cast<int64_t *>(x.data<int64_t>());
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
@@ -110,11 +107,20 @@ void GetOutput(const paddle::Tensor &x,
return;
}
void GetOutputStatic(const paddle::Tensor &x, int64_t rank_id, bool wait_flag) {
// Static-queue entrypoint for the "get_output" op: only rank 0 drains the
// SysV message queue; every other rank is a no-op. Delegates to the shared
// GetOutput implementation with the default msg_queue_id of 1.
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
  // Only the coordinator rank (<= 0) actually reads from the queue.
  if (rank_id <= 0) {
    GetOutput(x, rank_id, wait_flag, 1);
  }
}
void GetOutputDynamic(const paddle::Tensor &x,
// Expert-parallel variant of GetOutputStatic: unlike the TP path, it does
// NOT early-return on rank_id > 0 — in EP mode every rank consumes its own
// output queue. Forwards to the shared GetOutput with the default queue id.
void GetOutputEPStatic(const paddle::Tensor& x,
                       int64_t rank_id,
                       bool wait_flag) {
  // Same default queue id as the static (non-EP) path.
  constexpr int kDefaultMsgQueueId = 1;
  GetOutput(x, rank_id, wait_flag, kDefaultMsgQueueId);
}
void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
@@ -140,7 +146,7 @@ PD_BUILD_OP(get_output_ep)
.Attrs({"rank_id: int64_t", "wait_flag: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(GetOutputStatic));
.SetKernelFn(PD_KERNEL(GetOutputEPStatic));
PD_BUILD_OP(get_output_ep_dynamic)
.Inputs({"x"})
+5 -1
View File
@@ -362,6 +362,10 @@ std::vector<paddle::Tensor> GetInferParam(
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);
void GetOutputEPStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag);
void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
@@ -839,7 +843,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
"get_output function");
m.def("get_output_ep",
&GetOutputStatic,
&GetOutputEPStatic,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),