[Model Runner] Refactor execute_model for GPU async scheduling (#6176)

Author:    sunxin
Date:      2026-01-28 14:19:33 +08:00
Committed: GitHub
Parent:    ce06c6dfb3
Commit:    27f8799f04

9 changed files with 188 additions and 62 deletions
@@ -316,9 +316,6 @@ def post_process_normal(
     share_inputs: Dict[str, paddle.Tensor],
     sampling_metadata: SamplingMetadata,
     block_size: int = 64,
-    save_each_rank: bool = False,
-    skip_save_output: bool = False,
-    async_output_queue: queue.Queue = None,
     think_end_id: int = -1,
     line_break_id: int = -1,
     enable_entropy: bool = False,
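With the three save-path parameters (`save_each_rank`, `skip_save_output`, `async_output_queue`) dropped from `post_process_normal`, output transmission is no longer interleaved with GPU post-processing; it moves into the new `save_output_normal` below. A minimal sketch of the resulting call pattern, assuming an `execute_model`-style call site (variable names and argument order are inferred from the hunks in this commit, not copied from it):

```python
# Post-process on the GPU stream first...
post_process(
    sampler_output,
    model_output,
    share_inputs,
    sampling_metadata,
    block_size,
    think_end_id=think_end_id,
    line_break_id=line_break_id,
    enable_entropy=enable_entropy,
)
# ...then transmit separately; the skip decision now lives at the call site.
if not skip_save_output:
    save_output_normal(
        model_output,
        sampler_output,
        share_inputs,
        async_output_queue=async_output_queue,
        save_each_rank=save_each_rank,
    )
```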
@@ -388,7 +385,7 @@ def post_process_normal(
     if envs.ENABLE_V1_KVCACHE_SCHEDULER:
         update_inputs_v1(
             model_output.stop_flags,
-            model_output.not_need_stop,
+            model_output.not_need_stop_device,
             model_output.seq_lens_this_time,
             model_output.seq_lens_encoder,
             model_output.seq_lens_decoder,
@@ -404,7 +401,7 @@ def post_process_normal(
     else:
         update_inputs(
             model_output.stop_flags,
-            model_output.not_need_stop,
+            model_output.not_need_stop_device,
            model_output.seq_lens_this_time,
             model_output.seq_lens_encoder,
             model_output.seq_lens_decoder,
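Both update paths now receive `model_output.not_need_stop_device`, a GPU-resident flag, in place of the host-side `not_need_stop`. Keeping the flag on the device lets `update_inputs` / `update_inputs_v1` write it entirely on the GPU stream instead of forcing a device-to-host sync on every step; the host reads it back only when the scheduler actually needs the value. A hedged sketch of that pattern (the explicit sync point is an assumption for illustration, not code from this commit):

```python
import paddle

# The "keep generating" flag lives on the GPU so post-processing kernels
# can update it without blocking the launch thread.
not_need_stop_device = paddle.full([1], True, dtype="bool")

# ... update_inputs / update_inputs_v1 write the flag on the GPU stream ...

# A single, deferred device-to-host sync when a scheduling decision is due.
not_need_stop = bool(not_need_stop_device.item())
```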
@@ -412,36 +409,45 @@ def post_process_normal(
             sampler_output.sampled_token_ids,
             model_output.is_block_step,
         )
-    # 3. Transmit the model's output and stop generation signal via message queue.
-    # In the future, we will abandon this approach.
-    if not skip_save_output:
-        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
-            if save_each_rank or model_output.mp_rank == 0:
-                output = _build_stream_transfer_data(
-                    sampler_output.sampled_token_ids,
-                    logprobs=sampler_output.logprobs_tensors,
-                    prompt_logprobs_list=model_output.prompt_logprobs_list,
-                )
-                async_output_queue.put(output)
+
+
+def save_output_normal(
+    model_output: ModelOutputData,
+    sampler_output: SamplerOutput,
+    share_inputs: Dict[str, paddle.Tensor],
+    async_output_queue: queue.Queue = None,
+    save_each_rank: bool = False,
+):
+    # Transmit the model's output and stop generation signal via message queue.
+    # In the future, we will abandon this approach.
+    if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+        if save_each_rank or model_output.mp_rank == 0:
+            output = _build_stream_transfer_data(
+                sampler_output.sampled_token_ids,
+                logprobs=sampler_output.logprobs_tensors,
+                prompt_logprobs_list=model_output.prompt_logprobs_list,
+            )
+            async_output_queue.put(output)
+    else:
+        if sampler_output.logprobs_tensors is None:
+            save_output(
+                share_inputs["sampled_token_ids"],
+                model_output.not_need_stop,
+                share_inputs["preempted_idx"],
+                model_output.mp_rank,
+                save_each_rank,
+            )
-        else:
-            if sampler_output.logprobs_tensors is None:
-                save_output(
-                    sampler_output.sampled_token_ids,
-                    model_output.not_need_stop,
-                    share_inputs["preempted_idx"],
-                    model_output.mp_rank,
-                    save_each_rank,
-                )
-            else:
-                save_output_topk(
-                    sampler_output.sampled_token_ids,
-                    sampler_output.logprobs_tensors.logprob_token_ids,
-                    sampler_output.logprobs_tensors.logprobs,
-                    sampler_output.logprobs_tensors.selected_token_ranks,
-                    model_output.not_need_stop,
-                    share_inputs["preempted_idx"],
-                    model_output.mp_rank,
-                )
+        else:
+            save_output_topk(
+                sampler_output.sampled_token_ids,
+                sampler_output.logprobs_tensors.logprob_token_ids,
+                sampler_output.logprobs_tensors.logprobs,
+                sampler_output.logprobs_tensors.selected_token_ranks,
+                model_output.not_need_stop,
+                share_inputs["preempted_idx"],
+                model_output.mp_rank,
+            )
+    share_inputs["preempted_idx"][:] = 0


 def post_process_specualate(
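The transmit logic now lives entirely in `save_output_normal`, which ends by clearing `share_inputs["preempted_idx"]`. On the `FD_USE_GET_SAVE_OUTPUT_V1` path its only side effect is a `queue.Queue.put`, so the runner can hand results off cheaply while a separate worker drains the queue. A hypothetical consumer sketch (the worker thread and `send_to_client` stub are illustrative, not part of this commit):

```python
import queue
import threading

async_output_queue: queue.Queue = queue.Queue()

def send_to_client(output) -> None:
    # Placeholder transport; the real code streams via message queue.
    print("streamed", output)

def _output_worker() -> None:
    # Drain outputs off the model runner's critical path.
    while True:
        output = async_output_queue.get()
        if output is None:  # shutdown sentinel (illustrative)
            break
        send_to_client(output)

threading.Thread(target=_output_worker, daemon=True).start()
```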
@@ -540,6 +546,7 @@ def post_process_specualate(
         model_output.seq_lens_decoder,
         model_output.step_idx,
     )
+    share_inputs["preempted_idx"][:] = 0


 def post_process(
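Each output path now finishes with `share_inputs["preempted_idx"][:] = 0`, and the slice assignment is deliberate: it zeroes the existing buffer in place. Rebinding the dict entry to a fresh tensor would change the buffer's address, so anything holding the old pointer (a kernel argument, a captured CUDA graph) would keep seeing stale preemption state. A small sketch of the distinction (assumed rationale, not text from the commit):

```python
import paddle

share_inputs = {"preempted_idx": paddle.zeros([8], dtype="int32")}
captured = share_inputs["preempted_idx"]  # e.g. a buffer a kernel holds by address

share_inputs["preempted_idx"][:] = 0  # in place: `captured` observes the reset

# Rebinding instead would orphan `captured` on the old buffer:
share_inputs["preempted_idx"] = paddle.zeros([8], dtype="int32")
assert captured is not share_inputs["preempted_idx"]
```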
@@ -588,14 +595,10 @@ def post_process(
             share_inputs,
             sampling_metadata,
             block_size,
-            save_each_rank,
-            skip_save_output,
-            async_output_queue,
             think_end_id,
             line_break_id,
             enable_entropy,
         )
+    share_inputs["preempted_idx"][:] = 0


 def step_cuda(
@@ -936,3 +939,5 @@ def post_process_pooling(
     if save_each_rank or model_output.mp_rank == 0:
         output = _build_stream_transfer_data(output_tokens=None, pooler_outputs=pooler_output.outputs)
         async_output_queue.put(output)
+    share_inputs["preempted_idx"][:] = 0