mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] fix dp4 (#5946)
This commit is contained in:
@@ -20,6 +20,7 @@
|
||||
#include "msg_utils.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// #define GET_OUTPUT_DEBUG
|
||||
void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag) {
|
||||
@@ -114,16 +115,26 @@ void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
|
||||
GetOutput(x, rank_id, wait_flag, 1);
|
||||
}
|
||||
|
||||
void GetOutputDynamic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag,
|
||||
int msg_queue_id) {
|
||||
if (rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
GetOutput(x, rank_id, wait_flag, msg_queue_id);
|
||||
}
|
||||
|
||||
void GetOutputEPStatic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag) {
|
||||
GetOutput(x, rank_id, wait_flag, 1);
|
||||
}
|
||||
|
||||
void GetOutputDynamic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag,
|
||||
int msg_queue_id) {
|
||||
void GetOutputEPDynamic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag,
|
||||
int msg_queue_id) {
|
||||
GetOutput(x, rank_id, wait_flag, msg_queue_id);
|
||||
}
|
||||
|
||||
@@ -153,4 +164,4 @@ PD_BUILD_OP(get_output_ep_dynamic)
|
||||
.Attrs({"rank_id: int64_t", "wait_flag: bool", "msg_queue_id: int"})
|
||||
.Outputs({"x_out"})
|
||||
.SetInplaceMap({{"x", "x_out"}})
|
||||
.SetKernelFn(PD_KERNEL(GetOutputDynamic));
|
||||
.SetKernelFn(PD_KERNEL(GetOutputEPDynamic));
|
||||
|
||||
@@ -362,15 +362,20 @@ std::vector<paddle::Tensor> GetInferParam(
|
||||
|
||||
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);
|
||||
|
||||
void GetOutputEPStatic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag);
|
||||
|
||||
void GetOutputDynamic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag,
|
||||
int msg_queue_id);
|
||||
|
||||
void GetOutputEPStatic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag);
|
||||
|
||||
void GetOutputEPDynamic(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag,
|
||||
int msg_queue_id);
|
||||
|
||||
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& token_num,
|
||||
@@ -842,13 +847,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("wait_flag"),
|
||||
"get_output function");
|
||||
|
||||
m.def("get_output_ep",
|
||||
&GetOutputEPStatic,
|
||||
py::arg("x"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("wait_flag"),
|
||||
"get_output_ep function");
|
||||
|
||||
m.def("get_output_dynamic",
|
||||
&GetOutputDynamic,
|
||||
py::arg("x"),
|
||||
@@ -857,8 +855,15 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("msg_queue_id"),
|
||||
"get_output_dynamic function");
|
||||
|
||||
m.def("get_output_ep",
|
||||
&GetOutputEPStatic,
|
||||
py::arg("x"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("wait_flag"),
|
||||
"get_output_ep function");
|
||||
|
||||
m.def("get_output_ep_dynamic",
|
||||
&GetOutputDynamic,
|
||||
&GetOutputEPDynamic,
|
||||
py::arg("x"),
|
||||
py::arg("rank_id"),
|
||||
py::arg("wait_flag"),
|
||||
|
||||
@@ -20,13 +20,17 @@
|
||||
#include "msg_utils.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
// #define SAVE_WITH_OUTPUT_DEBUG
|
||||
void SaveOutMmsg(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int64_t rank_id,
|
||||
int msg_queue_id,
|
||||
bool save_each_rank) {
|
||||
if (!save_each_rank && rank_id > 0) {
|
||||
if (rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
|
||||
@@ -119,14 +123,14 @@ void SaveOutMmsgDynamic(const paddle::Tensor& x,
|
||||
SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
|
||||
}
|
||||
|
||||
PD_BUILD_OP(save_output)
|
||||
PD_BUILD_STATIC_OP(save_output)
|
||||
.Inputs({"x", "not_need_stop"})
|
||||
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
|
||||
.Outputs({"x_out"})
|
||||
.SetInplaceMap({{"x", "x_out"}})
|
||||
.SetKernelFn(PD_KERNEL(SaveOutMmsgStatic));
|
||||
|
||||
PD_BUILD_OP(save_output_dynamic)
|
||||
PD_BUILD_STATIC_OP(save_output_dynamic)
|
||||
.Inputs({"x", "not_need_stop"})
|
||||
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
|
||||
.Outputs({"x_out"})
|
||||
|
||||
Reference in New Issue
Block a user