[Optim] Robust sync status when preemption happens (#5796)

* [Bug fix] Sync status for caching output cache

* fix

* fix

* fix bug

* fix

* fix

* support xpu

* fix

* fix

* fix

* fix

* fix

* fix ci

* fix ci

* fix xpu

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
chenjian
2026-01-14 12:07:33 +08:00
committed by GitHub
parent 0d1a5e70bc
commit 74d0f1c01f
17 changed files with 442 additions and 354 deletions
@@ -496,11 +496,13 @@ void SpeculateGetLogits(const paddle::Tensor& draft_logits,
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& preempted_idx,
int64_t rank_id,
bool save_each_rank);
void SaveOutMmsgDynamic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& preempted_idx,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank);
@@ -971,6 +973,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&SaveOutMmsgStatic,
py::arg("x"),
py::arg("not_need_stop"),
py::arg("preempted_idx"),
py::arg("rank_id"),
py::arg("save_each_rank"),
"Save output function");
@@ -979,6 +982,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&SaveOutMmsgDynamic,
py::arg("x"),
py::arg("not_need_stop"),
py::arg("preempted_idx"),
py::arg("rank_id"),
py::arg("msg_queue_id"),
py::arg("save_each_rank"),
@@ -39,6 +39,7 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
const paddle::Tensor& logprob_scores, // [bsz, k+1]
const paddle::Tensor& ranks,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& preempted_idx,
int64_t rank_id) {
if (rank_id > 0) {
return;
@@ -52,6 +53,7 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
float* logprob_scores_data = logprob_scores_cpu.data<float>();
int64_t* ranks_data = ranks_cpu.data<int64_t>();
const int32_t* preempted_idx_data = preempted_idx.data<int32_t>();
static struct msgdata msg_sed;
int msg_queue_id = 1;
if (const char* inference_msg_queue_id_env_p =
@@ -121,6 +123,9 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
msg_sed.mtext[offset + 2] = -1;
msg_sed.mtext_f[offset] = 0.0;
}
if (preempted_idx_data[i] == 1) {
msg_sed.mtext[offset + 2] = -9;
}
}
msg_sed.mtext_ranks[i] = (int)ranks_data[i];
}
@@ -142,7 +147,12 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
}
PD_BUILD_STATIC_OP(save_output_topk)
.Inputs({"x", "topk_ids", "logprob_scores", "ranks", "not_need_stop"})
.Inputs({"x",
"topk_ids",
"logprob_scores",
"ranks",
"not_need_stop",
"preempted_idx"})
.Attrs({"rank_id: int64_t"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
@@ -25,8 +25,9 @@
#endif
// #define SAVE_WITH_OUTPUT_DEBUG
void SaveOutMmsg(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
void SaveOutMmsg(const paddle::Tensor &x,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &preempted_idx,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
@@ -34,10 +35,11 @@ void SaveOutMmsg(const paddle::Tensor& x,
return;
}
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
int64_t* x_data = x_cpu.data<int64_t>();
int64_t *x_data = x_cpu.data<int64_t>();
static struct msgdata msg_sed;
const int32_t *preempted_idx_data = preempted_idx.data<int32_t>();
if (const char* inference_msg_queue_id_env_p =
if (const char *inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
@@ -54,7 +56,7 @@ void SaveOutMmsg(const paddle::Tensor& x,
#endif
}
int inference_msg_id_from_env = 1;
if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
if (const char *inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
std::string inference_msg_id_env_str(inference_msg_id_env_p);
inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
if (inference_msg_id_from_env == 2) {
@@ -94,6 +96,9 @@ void SaveOutMmsg(const paddle::Tensor& x,
msg_sed.mtext[1] = bsz;
for (int i = 2; i < bsz + 2; i++) {
msg_sed.mtext[i] = (int)x_data[i - 2];
if (preempted_idx_data[i - 2] == 1) {
msg_sed.mtext[i] = -9;
}
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "save_output msg data: ";
@@ -108,30 +113,33 @@ void SaveOutMmsg(const paddle::Tensor& x,
return;
}
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
void SaveOutMmsgStatic(const paddle::Tensor &x,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &preempted_idx,
int64_t rank_id,
bool save_each_rank) {
SaveOutMmsg(x, not_need_stop, rank_id, 1, save_each_rank);
SaveOutMmsg(x, not_need_stop, preempted_idx, rank_id, 1, save_each_rank);
}
void SaveOutMmsgDynamic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
void SaveOutMmsgDynamic(const paddle::Tensor &x,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &preempted_idx,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
SaveOutMmsg(
x, not_need_stop, preempted_idx, rank_id, msg_queue_id, save_each_rank);
}
PD_BUILD_STATIC_OP(save_output)
.Inputs({"x", "not_need_stop"})
PD_BUILD_OP(save_output)
.Inputs({"x", "not_need_stop", "preempted_idx"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SaveOutMmsgStatic));
PD_BUILD_STATIC_OP(save_output_dynamic)
.Inputs({"x", "not_need_stop"})
PD_BUILD_OP(save_output_dynamic)
.Inputs({"x", "not_need_stop", "preempted_idx"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})