[XPU] fix dp4 (#5946)

This commit is contained in:
zhupengyang
2026-01-09 20:36:53 +08:00
committed by GitHub
parent 384ffd6952
commit 9db48ecb34
8 changed files with 152 additions and 46 deletions
+16 -5
View File
@@ -20,6 +20,7 @@
#include "msg_utils.h"
#include "paddle/extension.h"
// #define GET_OUTPUT_DEBUG
void GetOutputKVSignal(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
@@ -114,16 +115,26 @@ void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
GetOutput(x, rank_id, wait_flag, 1);
}
void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
if (rank_id > 0) {
return;
}
GetOutput(x, rank_id, wait_flag, msg_queue_id);
}
void GetOutputEPStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
GetOutput(x, rank_id, wait_flag, 1);
}
void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
void GetOutputEPDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
GetOutput(x, rank_id, wait_flag, msg_queue_id);
}
@@ -153,4 +164,4 @@ PD_BUILD_OP(get_output_ep_dynamic)
.Attrs({"rank_id: int64_t", "wait_flag: bool", "msg_queue_id: int"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(GetOutputDynamic));
.SetKernelFn(PD_KERNEL(GetOutputEPDynamic));
+17 -12
View File
@@ -362,15 +362,20 @@ std::vector<paddle::Tensor> GetInferParam(
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);
void GetOutputEPStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag);
void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id);
void GetOutputEPStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag);
void GetOutputEPDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
@@ -842,13 +847,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("wait_flag"),
"get_output function");
m.def("get_output_ep",
&GetOutputEPStatic,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
"get_output_ep function");
m.def("get_output_dynamic",
&GetOutputDynamic,
py::arg("x"),
@@ -857,8 +855,15 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("msg_queue_id"),
"get_output_dynamic function");
m.def("get_output_ep",
&GetOutputEPStatic,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
"get_output_ep function");
m.def("get_output_ep_dynamic",
&GetOutputDynamic,
&GetOutputEPDynamic,
py::arg("x"),
py::arg("rank_id"),
py::arg("wait_flag"),
@@ -20,13 +20,17 @@
#include "msg_utils.h"
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// #define SAVE_WITH_OUTPUT_DEBUG
void SaveOutMmsg(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
if (!save_each_rank && rank_id > 0) {
if (rank_id > 0) {
return;
}
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
@@ -119,14 +123,14 @@ void SaveOutMmsgDynamic(const paddle::Tensor& x,
SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
}
PD_BUILD_OP(save_output)
PD_BUILD_STATIC_OP(save_output)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SaveOutMmsgStatic));
PD_BUILD_OP(save_output_dynamic)
PD_BUILD_STATIC_OP(save_output_dynamic)
.Inputs({"x", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Outputs({"x_out"})