mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL][Cherry-Pick] Support Fully Async and PrefixCache (#6599)
* cherry-pick Support Fully Async and PrefixCache step 1 * copy routing_indices_cache.py from 2.4 * cherry-pick [RL] R3 Fix the bug for determining the end of a request (#6388) * cherry-pick [RL] Clear Requests status of R3 (#6569) * delete code * fix rename bug * fix status shape bug * fix ci
This commit is contained in:
@@ -106,6 +106,9 @@ from fastdeploy.model_executor.entropy_utils import (
|
||||
calculate_logits_entropy,
|
||||
speculate_calculate_logits_entropy,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
|
||||
RoutingReplayManager,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput
|
||||
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
|
||||
@@ -256,6 +259,7 @@ def post_process_normal(
|
||||
think_end_id: int = -1,
|
||||
splitwise_role_is_decode: bool = False,
|
||||
enable_entropy: bool = False,
|
||||
routing_replay_manager: RoutingReplayManager = None,
|
||||
):
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
if think_end_id > 0:
|
||||
@@ -319,6 +323,21 @@ def post_process_normal(
|
||||
if enable_entropy:
|
||||
calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
|
||||
|
||||
# Routing replay
|
||||
if routing_replay_manager is not None:
|
||||
# Update host cache
|
||||
slot_mapping = routing_replay_manager.compute_slot_mapping(
|
||||
positions=routing_replay_manager.pending_update_positions
|
||||
)
|
||||
routing_replay_manager.update_host_cache(
|
||||
positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping
|
||||
)
|
||||
|
||||
# Put routing of finished requests to store
|
||||
finished_batch_ids = paddle.flatten(paddle.isin(sampler_output.sampled_token_ids, model_output.eos_token_id))
|
||||
context_lens = model_output.seq_lens_decoder + model_output.seq_lens_encoder
|
||||
routing_replay_manager.put_finished_batch(finished_batch_ids=finished_batch_ids, seq_lens_decoder=context_lens)
|
||||
|
||||
# 2. Update the input buffer of the model
|
||||
with paddle.framework._no_check_dy2st_diff():
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
@@ -425,6 +444,7 @@ def post_process_specualate(
|
||||
enable_entropy: bool = False,
|
||||
is_naive_mode: bool = False,
|
||||
prefill_one_step_stop: bool = False,
|
||||
routing_replay_manager: RoutingReplayManager = None,
|
||||
):
|
||||
if think_end_id > 0:
|
||||
speculate_limit_thinking_content_length(
|
||||
@@ -457,6 +477,30 @@ def post_process_specualate(
|
||||
if enable_entropy:
|
||||
speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
|
||||
|
||||
# Routing replay
|
||||
if routing_replay_manager is not None:
|
||||
# Update host cache
|
||||
slot_mapping = routing_replay_manager.compute_slot_mapping(
|
||||
positions=routing_replay_manager.pending_update_positions
|
||||
)
|
||||
routing_replay_manager.update_host_cache(
|
||||
positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping
|
||||
)
|
||||
|
||||
# Put routing of finished requests to store
|
||||
last_accept_token = paddle.full_like(model_output.accept_tokens, -1)
|
||||
col_indices = paddle.arange(model_output.accept_tokens.shape[1], dtype=model_output.accept_num.dtype)
|
||||
mask = col_indices < paddle.unsqueeze(model_output.accept_num, 1)
|
||||
last_accept_token[mask] = model_output.accept_tokens[mask]
|
||||
eos_tokens_flat = model_output.eos_token_id.flatten()
|
||||
isin_mask = paddle.isin(last_accept_token, eos_tokens_flat)
|
||||
finished_batch_ids = isin_mask.any(axis=-1)
|
||||
context_lens = model_output.seq_lens_encoder + model_output.seq_lens_decoder
|
||||
routing_replay_manager.put_finished_batch(
|
||||
finished_batch_ids=finished_batch_ids,
|
||||
seq_lens_decoder=context_lens,
|
||||
)
|
||||
|
||||
# Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx
|
||||
# into a single kernel launch. For MTP/ngram paths, verify_draft_tokens has already
|
||||
# handled EOS/max_dec_len detection (replacing tokens + updating step_idx), so
|
||||
@@ -550,6 +594,7 @@ def post_process(
|
||||
enable_entropy: bool = False,
|
||||
is_naive_mode: bool = False,
|
||||
prefill_one_step_stop: bool = False,
|
||||
routing_replay_manager: RoutingReplayManager = None,
|
||||
) -> None:
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
|
||||
@@ -562,6 +607,7 @@ def post_process(
|
||||
save_each_rank,
|
||||
skip_save_output,
|
||||
async_output_queue,
|
||||
routing_replay_manager,
|
||||
)
|
||||
else:
|
||||
if speculative_decoding:
|
||||
@@ -577,6 +623,7 @@ def post_process(
|
||||
enable_entropy,
|
||||
is_naive_mode,
|
||||
prefill_one_step_stop,
|
||||
routing_replay_manager,
|
||||
)
|
||||
else:
|
||||
post_process_normal(
|
||||
@@ -588,6 +635,7 @@ def post_process(
|
||||
think_end_id,
|
||||
splitwise_role_is_decode,
|
||||
enable_entropy,
|
||||
routing_replay_manager,
|
||||
)
|
||||
share_inputs["last_preempted_idx"].copy_(share_inputs["preempted_idx"])
|
||||
share_inputs["preempted_idx"][:] = 0
|
||||
@@ -883,6 +931,7 @@ def post_process_pooling(
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
async_output_queue: queue.Queue = None,
|
||||
routing_replay_manager: RoutingReplayManager = None,
|
||||
) -> None:
|
||||
|
||||
paddle.assign(
|
||||
@@ -900,6 +949,10 @@ def post_process_pooling(
|
||||
model_output.stop_flags,
|
||||
)
|
||||
|
||||
# Routing replay
|
||||
if routing_replay_manager is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
with paddle.framework._no_check_dy2st_diff():
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
dummy_sampled_tokens = paddle.full_like(model_output.next_tokens, -1, dtype="int64")
|
||||
|
||||
Reference in New Issue
Block a user