[Model Runner] Support overlap schedule (#6259)

This commit is contained in:
sunxin
2026-02-04 10:49:44 +08:00
committed by GitHub
parent 6225439778
commit 9b0a82cfa9
8 changed files with 132 additions and 57 deletions
@@ -449,13 +449,13 @@ def save_output_normal(
share_inputs,
model_output.index_to_batch_id,
model_output.enable_pd_reorder,
["preempted_idx"],
["last_preempted_idx"],
)
if sampler_output.logprobs_tensors is None:
save_output(
share_inputs["sampled_token_ids"],
model_output.not_need_stop,
recover_share_inputs_map["preempted_idx"],
recover_share_inputs_map["last_preempted_idx"],
model_output.mp_rank,
save_each_rank,
)
@@ -469,10 +469,10 @@ def save_output_normal(
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
recover_share_inputs_map["preempted_idx"],
recover_share_inputs_map["last_preempted_idx"],
model_output.mp_rank,
)
share_inputs["preempted_idx"][:] = 0
share_inputs["last_preempted_idx"][:] = 0
def post_process_specualate(
@@ -592,7 +592,6 @@ def post_process_specualate(
model_output.seq_lens_decoder,
model_output.step_idx,
)
share_inputs["preempted_idx"][:] = 0
def post_process(
@@ -645,6 +644,8 @@ def post_process(
line_break_id,
enable_entropy,
)
share_inputs["last_preempted_idx"].copy_(share_inputs["preempted_idx"])
share_inputs["preempted_idx"][:] = 0
def step_cuda(
@@ -985,5 +986,3 @@ def post_process_pooling(
if save_each_rank or model_output.mp_rank == 0:
output = _build_stream_transfer_data(output_tokens=None, pooler_outputs=pooler_output.outputs)
async_output_queue.put(output)
share_inputs["preempted_idx"][:] = 0