[Optim] Robust sync status when preemption happens (#5796)

* [Bug fix] Sync status for caching output cache

* fix

* fix

* fix bug

* fix

* fix

* support xpu

* fix

* fix

* fix

* fix

* fix

* fix ci

* fix ci

* fix xpu

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
chenjian
2026-01-14 12:07:33 +08:00
committed by GitHub
parent 0d1a5e70bc
commit 74d0f1c01f
17 changed files with 442 additions and 354 deletions
@@ -430,6 +430,7 @@ def post_process_normal(
save_output(
sampler_output.sampled_token_ids,
model_output.not_need_stop,
share_inputs["preempted_idx"],
model_output.mp_rank,
save_each_rank,
)
@@ -440,6 +441,7 @@ def post_process_normal(
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
share_inputs["preempted_idx"],
model_output.mp_rank,
)
@@ -505,6 +507,7 @@ def post_process_specualate(
model_output.not_need_stop,
model_output.seq_lens_decoder,
model_output.prompt_lens,
share_inputs["preempted_idx"],
model_output.mp_rank,
save_each_rank,
envs.ENABLE_V1_KVCACHE_SCHEDULER,
@@ -520,6 +523,7 @@ def post_process_specualate(
model_output.not_need_stop,
model_output.seq_lens_decoder,
model_output.prompt_lens,
share_inputs["preempted_idx"],
3, # mtype
model_output.mp_rank,
save_each_rank,
@@ -592,6 +596,7 @@ def post_process(
line_break_id,
enable_entropy,
)
share_inputs["preempted_idx"][:] = 0
def step_cuda(