Files
FastDeploy/fastdeploy/model_executor/pre_and_post_process.py
T
Zero Rains 7af95be052 [KSM] fix logz when topk (#7232)
* [KSM] fix logz when topk

* add the logz renormalize
2026-04-08 06:19:40 -07:00

1020 lines
38 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import queue
from typing import Dict, List, Optional, Union
import numpy as np
import paddle
from fastdeploy import envs
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import ZmqIpcClient
from fastdeploy.platforms import current_platform
from fastdeploy.worker.input_batch import (
InputBatch,
recover_batch_index_for_output,
recover_batch_index_for_sampler_output,
)
if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
get_padding_offset,
limit_thinking_content_length,
save_output,
set_stop_value_multi_ends,
step_paddle,
update_inputs,
update_inputs_v1,
)
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import (
get_padding_offset,
save_output,
set_stop_value_multi_ends,
update_inputs,
)
elif current_platform.is_dcu():
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset,
save_output,
set_stop_value_multi_ends,
step_paddle,
update_inputs,
)
elif current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset,
limit_thinking_content_length,
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_get_seq_lens_output,
speculate_limit_thinking_content_length,
speculate_save_output,
speculate_save_output_topk,
speculate_set_stop_value_multi_seqs,
speculate_set_value_by_flags_and_idx,
speculate_step_paddle,
speculate_step_reschedule,
speculate_step_system_cache,
speculate_update,
step_paddle,
step_reschedule,
step_system_cache,
update_inputs,
update_inputs_v1,
)
elif current_platform.is_intel_hpu():
pass
else:
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset,
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_get_seq_lens_output,
speculate_save_output,
speculate_save_output_topk,
speculate_set_value_by_flags_and_idx,
speculate_step_paddle,
speculate_step_system_cache,
speculate_update,
speculate_set_stop_value_multi_seqs,
step_paddle,
step_system_cache,
update_inputs,
step_reschedule,
update_inputs_v1,
speculate_step_reschedule,
limit_thinking_content_length,
speculate_limit_thinking_content_length,
)
from fastdeploy.model_executor.entropy_utils import (
calculate_logits_entropy,
speculate_calculate_logits_entropy,
)
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
RoutingReplayManager,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, SamplerOutput
# True when the FD_DISABLED_RECOVER setting is exactly the string "1".
# step_cuda() uses it to dispatch to the (speculate_)step_reschedule kernels
# instead of the (speculate_)step_system_cache / (speculate_)step_paddle ones.
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
def pre_process(
    token_num_cpu: int,
    input_ids: paddle.Tensor,
    seq_lens_this_time: paddle.Tensor,
    speculative_decoding: bool,
    draft_tokens: Optional[paddle.Tensor] = None,
    seq_lens_encoder: Optional[paddle.Tensor] = None,
    seq_lens_decoder: Optional[paddle.Tensor] = None,
):
    """
    Preprocessing before embedding: remove padding from the batched inputs.

    Args:
        token_num_cpu: Host-side token count for this step, passed as the last
            argument of ``get_padding_offset`` (the speculative output call below
            passes ``sum(seq_lens_output)`` in the same position).
        input_ids: Padded input token ids.
        seq_lens_this_time: Number of tokens each request contributes this step.
        speculative_decoding: Whether speculative decoding is active.
        draft_tokens: Draft tokens; only used when ``speculative_decoding`` is True.
        seq_lens_encoder: Encoder lengths; only used when ``speculative_decoding`` is True.
        seq_lens_decoder: Decoder lengths; only used when ``speculative_decoding`` is True.

    Return:
        A 6-tuple ``(ids_remove_padding, batch_id_per_token, cu_seqlens_q,
        cu_seqlens_k, cu_seqlens_q_output, batch_id_per_token_output)``.
        The last two entries are ``None`` unless ``speculative_decoding`` is
        True, in which case they describe the per-request output-token layout
        derived from ``speculate_get_seq_lens_output``.
    """
    if not speculative_decoding:
        # Note(ZKK): This case's code is very simple!
        # This path is intentionally not gated by platform: every caller needs
        # these four values, and without an assignment here the function would
        # reach its return with unbound locals.
        ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
            input_ids, seq_lens_this_time, None, None, token_num_cpu
        )
        return (
            ids_remove_padding,
            batch_id_per_token,
            cu_seqlens_q,
            cu_seqlens_k,
            None,
            None,
        )

    # Remove padding; draft tokens participate in the packed input.
    (
        ids_remove_padding,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
    ) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)

    # Compute each request's number of output tokens.
    seq_lens_output = speculate_get_seq_lens_output(
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
    )
    if isinstance(seq_lens_output, list):
        seq_lens_output = seq_lens_output[0]
    output_token_num = paddle.sum(seq_lens_output)

    # Only the output-layout tensors are needed from this second call; the ids
    # argument is required by the op's signature but its result is discarded.
    useless_input_ids = input_ids
    _, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
        useless_input_ids,
        seq_lens_output,
        None,
        None,
        output_token_num.item(),  # host sync: the op takes a Python int
    )
    return (
        ids_remove_padding,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_q_output,
        batch_id_per_token_output,
    )
def _build_stream_transfer_data(
    output_tokens: Optional[paddle.Tensor],
    pooler_outputs: Optional[List[PoolingSequenceGroupOutput]] = None,
    logprobs: Optional[LogprobsTensors] = None,
    prompt_logprobs_list: Optional[List] = None,
    sampling_mask: Optional[List[np.ndarray]] = None,
):
    """Split batched step outputs into one ``StreamTransferData`` per request.

    Exactly one source is used: ``output_tokens`` when it is not None,
    otherwise ``pooler_outputs``. If both are None an empty list is returned.

    Args:
        output_tokens: Sampled token ids. The tensor is flattened with
            ``reshape([-1])`` and each element becomes the token of the request
            whose batch id equals its position.
        pooler_outputs: Per-request pooling outputs; ``None`` entries are
            skipped and bfloat16 tensors are cast to float32 before ``.numpy()``.
        logprobs: Optional logprobs; row ``bid`` is attached to request ``bid``.
        prompt_logprobs_list: Optional per-request prompt logprobs indexed by
            batch id; ignored when empty.
        sampling_mask: Optional per-request sampling masks indexed by batch id.

    Returns:
        List of ``StreamTransferData`` with ``decoder_state=DecoderState.TEXT``.
    """
    stream_transfer_datas = []
    if output_tokens is not None:
        flat_tokens = output_tokens.reshape([-1]).numpy()
        # One length-1 view per request. Iterating the rows of an (N, 1) view
        # yields the same arrays as np.split(flat_tokens, N), but also handles
        # an empty batch (np.split with 0 sections raises ZeroDivisionError).
        for bid, output_token_per_sample in enumerate(flat_tokens.reshape(-1, 1)):
            stream_transfer_data = StreamTransferData(
                decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
            )
            if logprobs is not None:
                stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1)
            if prompt_logprobs_list:
                stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid]
            if sampling_mask is not None:
                stream_transfer_data.sampling_mask = sampling_mask[bid]
            stream_transfer_datas.append(stream_transfer_data)
    elif pooler_outputs is not None:
        for bid, pooler_output in enumerate(pooler_outputs):
            if pooler_output is None:
                continue
            # numpy has no bfloat16, so upcast before converting.
            if pooler_output.dtype == paddle.bfloat16:
                pooler_output = pooler_output.astype("float32")
            pooler_output = pooler_output.numpy()
            stream_transfer_data = StreamTransferData(
                decoder_state=DecoderState.TEXT, pooler_output=pooler_output, batch_id=bid
            )
            stream_transfer_datas.append(stream_transfer_data)
    return stream_transfer_datas
def post_process_normal(
    sampler_output: SamplerOutput,
    model_output: ModelOutputData,
    share_inputs: InputBatch,
    sampling_metadata: SamplingMetadata,
    block_size: int = 64,
    think_end_id: int = -1,
    splitwise_role_is_decode: bool = False,
    enable_entropy: bool = False,
    routing_replay_manager: Optional[RoutingReplayManager] = None,
) -> None:
    """Post-processing steps after completing a single token generation.

    Non-speculative path. Stages, in order:
      * Thinking-length limit (only when ``think_end_id > 0``).
      * Stop bookkeeping: ``step_idx`` is incremented for requests whose
        ``stop_flags`` is False, requests with ``step_idx >= max_dec_len`` are
        marked stopped, then ``set_stop_value_multi_ends`` is run (it receives
        ``eos_token_id`` and, on CUDA/Iluvatar/DCU/MACA, also ``pre_ids``,
        ``step_idx``, ``stop_token_ids``, ``stop_seqs_len`` and ``min_tokens``).
      * Optional logits entropy (``enable_entropy``).
      * Optional routing replay: pending positions are written to the host
        cache and requests whose sampled token is an EOS id are handed to
        ``put_finished_batch``.
      * Input-buffer update for the next step: ``update_inputs_v1`` when
        ``envs.ENABLE_V1_KVCACHE_SCHEDULER`` is set, otherwise ``update_inputs``.
      * When both ``sampler_output.logprobs_tensors`` and
        ``sampler_output.logz_per_batch`` are set, ``logz`` is subtracted from
        every finite logprob (non-finite entries become ``-inf``) and
        ``sampler_output.logprobs_tensors`` is replaced with the result.

    Args:
        sampler_output: Sampling result for this step.
        model_output: Per-step model state. ``step_idx`` / ``stop_flags`` are
            overwritten via ``paddle.assign``; the custom ops below are
            presumably in-place as well (their return values are not used).
        share_inputs: Shared input buffers, indexed by string keys.
        sampling_metadata: Only ``temperature`` is read (entropy stage).
        block_size: Forwarded to ``update_inputs_v1``.
        think_end_id: Thinking-end token id; the limiter is skipped when ``<= 0``.
        splitwise_role_is_decode: Forwarded to ``limit_thinking_content_length``.
        enable_entropy: Whether to call ``calculate_logits_entropy``.
        routing_replay_manager: Routing-replay manager; the stage is skipped when None.
    """
    if think_end_id > 0:
        limit_thinking_content_length(
            sampler_output.sampled_token_ids,
            share_inputs["max_think_lens"],
            share_inputs["max_reply_lens"],
            share_inputs["step_idx"],
            share_inputs["limit_think_status"],
            share_inputs["stop_flags"],
            share_inputs["eos_token_id"],
            share_inputs["inject_token_ids"],
            think_end_id,
            splitwise_role_is_decode,
        )
    # 1. Set stop value
    # step_idx += 1 only for requests that are not already stopped.
    paddle.assign(
        paddle.where(
            model_output.stop_flags,
            model_output.step_idx,
            model_output.step_idx + 1,
        ),
        model_output.step_idx,
    )
    # Stop requests that reached their maximum decode length.
    length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len)
    paddle.assign(
        paddle.logical_or(model_output.stop_flags, length_cond),
        model_output.stop_flags,
    )
    # These platforms' set_stop_value_multi_ends takes the extended argument
    # list (pre_ids, step_idx, stop sequences, min_tokens); others take the short one.
    if (
        current_platform.is_cuda()
        or current_platform.is_iluvatar()
        or current_platform.is_dcu()
        or current_platform.is_maca()
    ):
        set_stop_value_multi_ends(
            sampler_output.sampled_token_ids,
            model_output.stop_flags,
            model_output.seq_lens_this_time,
            model_output.eos_token_id,
            model_output.next_tokens,
            model_output.pre_ids,
            model_output.step_idx,
            model_output.stop_token_ids,
            model_output.stop_seqs_len,
            model_output.min_tokens,
            False,
        )  # multi ends
    else:
        set_stop_value_multi_ends(
            sampler_output.sampled_token_ids,
            model_output.stop_flags,
            model_output.seq_lens_this_time,
            model_output.eos_token_id,
            model_output.next_tokens,
            False,
        )
    if enable_entropy:
        calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
    # Routing replay
    if routing_replay_manager is not None:
        # Update host cache
        slot_mapping = routing_replay_manager.compute_slot_mapping(
            positions=routing_replay_manager.pending_update_positions
        )
        routing_replay_manager.update_host_cache(
            positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping
        )
        # Put routing of finished requests to store
        # A request counts as finished here when its sampled token is an EOS id.
        # NOTE(review): requests stopped by max_dec_len / stop sequences are only
        # covered if set_stop_value_multi_ends rewrites their token to EOS — confirm.
        finished_batch_ids = paddle.flatten(paddle.isin(sampler_output.sampled_token_ids, model_output.eos_token_id))
        context_lens = model_output.seq_lens_decoder + model_output.seq_lens_encoder
        routing_replay_manager.put_finished_batch(finished_batch_ids=finished_batch_ids, seq_lens_decoder=context_lens)
    # 2. Update the input buffer of the model
    # Presumably suppresses Paddle's dy2st difference check around these custom ops — TODO confirm.
    with paddle.framework._no_check_dy2st_diff():
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            update_inputs_v1(
                model_output.stop_flags,
                model_output.not_need_stop_device,
                model_output.seq_lens_this_time,
                model_output.seq_lens_encoder,
                model_output.seq_lens_decoder,
                share_inputs["step_seq_lens_decoder"],
                share_inputs["prompt_lens"],
                sampler_output.sampled_token_ids,
                model_output.input_ids,
                share_inputs["block_tables"],
                model_output.next_tokens,
                model_output.is_block_step,
                block_size,
            )
        else:
            update_inputs(
                model_output.stop_flags,
                model_output.not_need_stop_device,
                model_output.seq_lens_this_time,
                model_output.seq_lens_encoder,
                model_output.seq_lens_decoder,
                model_output.input_ids,
                sampler_output.sampled_token_ids,
                model_output.is_block_step,
            )
    # Renormalize logprobs to match truncated sampling distribution (when enabled).
    if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
        # logprobs_tensors.logprobs: [B, max_num_logprobs + 1]
        logprobs = sampler_output.logprobs_tensors.logprobs
        # logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
        logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
        # Renormalize: log π_masked = log π_full - log Z_K
        # Only normalize valid candidates; padding positions use -inf
        valid_mask = paddle.isfinite(logprobs)
        normalized_logprobs = paddle.where(
            valid_mask,
            logprobs - logz.unsqueeze(1),  # broadcast subtraction
            paddle.full_like(logprobs, float("-inf")),
        )
        # Update logprobs_tensors with normalized values
        sampler_output.logprobs_tensors = LogprobsTensors(
            logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
            logprobs=normalized_logprobs,
            selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
        )
def save_output_normal(
    model_output: ModelOutputData,
    sampler_output: SamplerOutput,
    share_inputs: Dict[str, paddle.Tensor],
    async_output_queue: queue.Queue = None,
    save_each_rank: bool = False,
    sampling_mask_zmq_client: Optional[ZmqIpcClient] = None,
):
    """Publish this step's sampled tokens, plus optional logprobs and sampling masks.

    Two transports exist:
      * ``envs.FD_USE_GET_SAVE_OUTPUT_V1``: on rank 0 (or on every rank when
        ``save_each_rank``), batch order is restored on ``sampler_output`` and
        the list built by ``_build_stream_transfer_data`` is put on
        ``async_output_queue``.
      * Legacy: tokens are written with ``save_output`` (no logprobs) or
        ``save_output_topk`` (with logprobs). Afterwards, on rank 0, a non-None
        ``sampler_output.sampling_mask`` is sent as ``{request_index: list}``
        through ``sampling_mask_zmq_client``.
    In both cases ``share_inputs["last_preempted_idx"]`` is zeroed at the end.
    """
    # Transmit the model's output and stop generation signal via message queue.
    # In the future, we will abandon this approach.
    use_v1_output = envs.FD_USE_GET_SAVE_OUTPUT_V1
    if use_v1_output:
        is_publishing_rank = save_each_rank or model_output.mp_rank == 0
        if is_publishing_rank:
            recover_batch_index_for_sampler_output(
                sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
            )
            async_output_queue.put(
                _build_stream_transfer_data(
                    sampler_output.sampled_token_ids,
                    logprobs=sampler_output.logprobs_tensors,
                    prompt_logprobs_list=model_output.prompt_logprobs_list,
                    sampling_mask=sampler_output.sampling_mask,
                )
            )
    elif sampler_output.logprobs_tensors is None:
        restored = recover_batch_index_for_output(
            share_inputs,
            model_output.index_to_batch_id,
            model_output.enable_pd_reorder,
            ["last_preempted_idx", "sampled_token_ids"],
        )
        save_output(
            restored["sampled_token_ids"],
            model_output.not_need_stop,
            restored["last_preempted_idx"],
            model_output.mp_rank,
            save_each_rank,
        )
    else:
        restored = recover_batch_index_for_output(
            share_inputs,
            model_output.index_to_batch_id,
            model_output.enable_pd_reorder,
            ["last_preempted_idx"],
        )
        recover_batch_index_for_sampler_output(
            sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
        )
        # NOTE(review): unlike the branch above, sampled_token_ids is taken from
        # share_inputs without batch-index recovery while the logprobs were
        # recovered; confirm this cannot misalign tokens and logprobs when
        # enable_pd_reorder is on.
        topk = sampler_output.logprobs_tensors
        save_output_topk(
            share_inputs["sampled_token_ids"],
            topk.logprob_token_ids,
            topk.logprobs,
            topk.selected_token_ranks,
            model_output.not_need_stop,
            restored["last_preempted_idx"],
            model_output.mp_rank,
        )

    # Send sampling_mask via ZMQ side-channel when enabled (legacy transport only;
    # the V1 transport carries it inside StreamTransferData).
    if not use_v1_output and sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
        # sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
        sampling_mask_zmq_client.send_pyobj(dict(enumerate(arr.tolist() for arr in sampler_output.sampling_mask)))
    share_inputs["last_preempted_idx"][:] = 0
def post_process_specualate(
    sampler_output: SamplerOutput,
    model_output: ModelOutputData,
    share_inputs: InputBatch,
    sampling_metadata: SamplingMetadata,
    save_each_rank: bool = False,
    skip_save_output: bool = False,
    think_end_id: int = -1,
    splitwise_role_is_decode: bool = False,
    enable_entropy: bool = False,
    routing_replay_manager: Optional[RoutingReplayManager] = None,
    sampling_mask_zmq_client: Optional[ZmqIpcClient] = None,
) -> None:
    """Post-processing for one speculative-decoding step.

    Stages, in order:
      * Thinking-length limit on ``accept_tokens`` (only when ``think_end_id > 0``).
      * ``speculate_set_stop_value_multi_seqs`` over the accepted tokens (it
        receives ``stop_token_ids``, ``stop_seqs_len``, ``eos_token_id`` and
        ``min_tokens``).
      * Optional logits entropy (``enable_entropy``).
      * Optional routing replay: pending positions go to the host cache and
        requests with an EOS id among their first ``accept_num`` accepted tokens
        are handed to ``put_finished_batch``.
      * ``speculate_update`` on the sequence-length / draft / stop state.
      * Logprob renormalization by ``logz_per_batch`` (when both are set).
      * Unless ``skip_save_output``: save accepted tokens with
        ``speculate_save_output`` (no logprobs) or ``speculate_save_output_topk``
        (with logprobs), then, on rank 0, send ``sampling_mask`` grouped per
        request through ``sampling_mask_zmq_client``.
      * ``speculate_set_value_by_flags_and_idx`` to update ``pre_ids``.

    Args:
        sampler_output: Sampling/verification result for this step.
        model_output: Per-step model state read and updated by the custom ops
            (presumably in place; their return values are not used).
        share_inputs: Shared input buffers, indexed by string keys.
        sampling_metadata: Only ``temperature`` is read (entropy stage).
        save_each_rank: Forwarded to the save ops.
        skip_save_output: When True, nothing is saved or sent over ZMQ.
        think_end_id: Thinking-end token id; the limiter is skipped when ``<= 0``.
        splitwise_role_is_decode: Forwarded to the thinking-length limiter.
        enable_entropy: Whether to call ``speculate_calculate_logits_entropy``.
        routing_replay_manager: Routing-replay manager; skipped when None.
        sampling_mask_zmq_client: Used only when ``sampler_output.sampling_mask``
            is not None on rank 0. NOTE(review): it is dereferenced without a
            None check in that case.
    """
    if think_end_id > 0:
        speculate_limit_thinking_content_length(
            share_inputs["accept_tokens"],
            share_inputs["max_think_lens"],
            share_inputs["max_reply_lens"],
            share_inputs["step_idx"],
            share_inputs["limit_think_status"],
            share_inputs["accept_num"],
            share_inputs["stop_flags"],
            share_inputs["eos_token_id"],
            share_inputs["inject_token_ids"],
            think_end_id,
            splitwise_role_is_decode,
        )
    speculate_set_stop_value_multi_seqs(
        model_output.accept_tokens,
        model_output.accept_num,
        model_output.pre_ids,
        model_output.step_idx,
        model_output.stop_flags,
        model_output.seq_lens_this_time,
        model_output.stop_token_ids,
        model_output.stop_seqs_len,
        model_output.eos_token_id,
        model_output.min_tokens,
    )
    if enable_entropy:
        speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
    # Routing replay
    if routing_replay_manager is not None:
        # Update host cache
        slot_mapping = routing_replay_manager.compute_slot_mapping(
            positions=routing_replay_manager.pending_update_positions
        )
        routing_replay_manager.update_host_cache(
            positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping
        )
        # Put routing of finished requests to store
        # Keep the first accept_num[i] tokens of row i and set the rest to -1
        # (despite its name, last_accept_token holds all accepted tokens), then
        # flag a request as finished if any kept token is an EOS id.
        last_accept_token = paddle.full_like(model_output.accept_tokens, -1)
        col_indices = paddle.arange(model_output.accept_tokens.shape[1], dtype=model_output.accept_num.dtype)
        mask = col_indices < paddle.unsqueeze(model_output.accept_num, 1)
        last_accept_token[mask] = model_output.accept_tokens[mask]
        eos_tokens_flat = model_output.eos_token_id.flatten()
        isin_mask = paddle.isin(last_accept_token, eos_tokens_flat)
        finished_batch_ids = isin_mask.any(axis=-1)
        context_lens = model_output.seq_lens_encoder + model_output.seq_lens_decoder
        routing_replay_manager.put_finished_batch(
            finished_batch_ids=finished_batch_ids,
            seq_lens_decoder=context_lens,
        )
    speculate_update(
        model_output.seq_lens_encoder,
        model_output.seq_lens_decoder,
        model_output.not_need_stop,
        model_output.draft_tokens,
        model_output.actual_draft_token_num,
        model_output.accept_tokens,
        model_output.accept_num,
        model_output.stop_flags,
        model_output.seq_lens_this_time,
        model_output.is_block_step,
        model_output.mask_rollback,
    )
    # Renormalize logprobs to match truncated sampling distribution (when enabled).
    # log π_masked = log π_full - log Z_K for finite entries; others become -inf.
    # NOTE(review): in this path the logprob rows appear to be per accepted token
    # (speculate_save_output_topk below also receives token_num_per_batch and
    # cu_batch_token_offset), while the name logz_per_batch suggests one value per
    # request; confirm the row counts match, otherwise the broadcast misaligns or fails.
    if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
        logprobs = sampler_output.logprobs_tensors.logprobs
        logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
        valid_mask = paddle.isfinite(logprobs)
        normalized_logprobs = paddle.where(
            valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf"))
        )
        sampler_output.logprobs_tensors = LogprobsTensors(
            logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
            logprobs=normalized_logprobs,
            selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
        )
    if not skip_save_output:
        if sampler_output.logprobs_tensors is None:
            # Restore original batch order before saving.
            recover_model_output_map = recover_batch_index_for_output(
                model_output,
                model_output.index_to_batch_id,
                model_output.enable_pd_reorder,
                ["accept_tokens", "accept_num", "seq_lens_decoder", "prompt_lens"],
            )
            recover_share_inputs = recover_batch_index_for_output(
                share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
            )
            speculate_save_output(
                recover_model_output_map["accept_tokens"],
                recover_model_output_map["accept_num"],
                model_output.not_need_stop,
                recover_model_output_map["seq_lens_decoder"],
                recover_model_output_map["prompt_lens"],
                recover_share_inputs["preempted_idx"],
                model_output.mp_rank,
                save_each_rank,
                bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
            )
        else:
            recover_batch_index_for_sampler_output(
                sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
            )
            recover_model_output_map = recover_batch_index_for_output(
                model_output,
                model_output.index_to_batch_id,
                model_output.enable_pd_reorder,
                ["seq_lens_decoder", "prompt_lens"],
            )
            recover_share_inputs = recover_batch_index_for_output(
                share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
            )
            speculate_save_output_topk(
                sampler_output.sampled_token_ids,
                sampler_output.logprobs_tensors.logprob_token_ids,
                sampler_output.logprobs_tensors.logprobs,
                sampler_output.logprobs_tensors.selected_token_ranks,
                sampler_output.token_num_per_batch,
                sampler_output.cu_batch_token_offset,
                model_output.not_need_stop,
                recover_model_output_map["seq_lens_decoder"],
                recover_model_output_map["prompt_lens"],
                recover_share_inputs["preempted_idx"],
                3,  # mtype
                model_output.mp_rank,
                save_each_rank,
            )
        # Send sampling_mask via ZMQ side-channel when enabled.
        if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
            # sampling_mask is List[np.ndarray] of sparse int indices, length = total_accepted_tokens.
            # Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
            # NOTE(review): accept_num here is the non-recovered model_output field, and
            # [:real_bsz] slices to the tensor's own length (a no-op) — confirm intent.
            real_bsz = model_output.accept_num.shape[0]
            accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
            mask_dict = {}
            offset = 0
            for i, n in enumerate(accept_nums):
                n = int(n)
                if n > 0:
                    # List of n sparse index arrays, one per accepted token
                    mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
                offset += n
            sampling_mask_zmq_client.send_pyobj(mask_dict)
    # Update pre_ids through accept tokens
    speculate_set_value_by_flags_and_idx(
        model_output.pre_ids,
        model_output.accept_tokens,
        model_output.accept_num,
        model_output.stop_flags,
        model_output.seq_lens_this_time,
        model_output.seq_lens_encoder,
        model_output.seq_lens_decoder,
        model_output.step_idx,
    )
def post_process(
    sampler_or_pooler_output: Union[SamplerOutput, PoolerOutput],
    model_output: ModelOutputData,
    share_inputs: InputBatch,
    sampling_metadata: SamplingMetadata = None,
    block_size: int = 64,
    save_each_rank: bool = False,
    speculative_decoding: bool = False,
    skip_save_output: bool = False,
    async_output_queue: queue.Queue = None,
    think_end_id: int = -1,
    splitwise_role_is_decode: bool = False,
    enable_entropy: bool = False,
    routing_replay_manager: RoutingReplayManager = None,
    sampling_mask_zmq_client: ZmqIpcClient = None,
) -> None:
    """Post-processing steps after completing a single token generation.

    Dispatches on the kind of output:
      * ``PoolerOutput``          -> ``post_process_pooling``
      * sampler output, speculative decoding -> ``post_process_specualate``
      * sampler output, otherwise -> ``post_process_normal``
    Afterwards, this step's ``preempted_idx`` is copied into
    ``last_preempted_idx`` and ``preempted_idx`` is cleared.
    """
    if isinstance(sampler_or_pooler_output, PoolerOutput):
        post_process_pooling(
            pooler_output=sampler_or_pooler_output,
            model_output=model_output,
            share_inputs=share_inputs,
            block_size=block_size,
            save_each_rank=save_each_rank,
            skip_save_output=skip_save_output,
            async_output_queue=async_output_queue,
            routing_replay_manager=routing_replay_manager,
        )
    elif speculative_decoding:
        post_process_specualate(
            sampler_output=sampler_or_pooler_output,
            model_output=model_output,
            share_inputs=share_inputs,
            sampling_metadata=sampling_metadata,
            save_each_rank=save_each_rank,
            skip_save_output=skip_save_output,
            think_end_id=think_end_id,
            splitwise_role_is_decode=splitwise_role_is_decode,
            enable_entropy=enable_entropy,
            routing_replay_manager=routing_replay_manager,
            sampling_mask_zmq_client=sampling_mask_zmq_client,
        )
    else:
        post_process_normal(
            sampler_output=sampler_or_pooler_output,
            model_output=model_output,
            share_inputs=share_inputs,
            sampling_metadata=sampling_metadata,
            block_size=block_size,
            think_end_id=think_end_id,
            splitwise_role_is_decode=splitwise_role_is_decode,
            enable_entropy=enable_entropy,
            routing_replay_manager=routing_replay_manager,
        )

    # Roll this step's preemption flags over for the next save, then reset them.
    preempted = share_inputs["preempted_idx"]
    share_inputs["last_preempted_idx"].copy_(preempted)
    share_inputs["preempted_idx"][:] = 0
def step_cuda(
    share_inputs: InputBatch,
    block_size: int,
    enc_dec_block_num: int,
    speculative_config: SpeculativeConfig,
    enable_prefix_caching: bool = False,
) -> None:
    """
    Run the per-step block-scheduling kernel that matches the configuration.

    TODO(gongshaotian): normalization name

    Kernel selection (the ``speculate_`` variant is used when
    ``speculative_config.method`` is not None):
      * ``DISABLE_RECOVER``       -> ``(speculate_)step_reschedule``
      * ``enable_prefix_caching`` -> ``(speculate_)step_system_cache``
      * otherwise                 -> ``(speculate_)step_paddle``

    All six kernels share one positional layout: a run of ``share_inputs``
    tensors followed by ``block_size`` and ``enc_dec_block_num``. The
    system-cache kernels also take ``step_seq_lens_decoder`` right after
    ``step_seq_lens_encoder``; the speculative kernels also take ``accept_num``
    after ``first_token_ids`` and ``num_speculative_tokens`` as the last argument.

    Args:
        share_inputs: Shared buffers indexed by string key and passed to the kernel.
        block_size: Forwarded to the kernel.
        enc_dec_block_num: Forwarded to the kernel.
        speculative_config: ``method`` selects the speculative variant;
            ``num_speculative_tokens`` is forwarded in that case.
        enable_prefix_caching: Selects the system-cache kernels when
            ``DISABLE_RECOVER`` is False.
    """
    speculative = speculative_config.method is not None

    if DISABLE_RECOVER:
        step_kernel = speculate_step_reschedule if speculative else step_reschedule
        takes_step_decoder_lens = False
    elif enable_prefix_caching:
        step_kernel = speculate_step_system_cache if speculative else step_system_cache
        takes_step_decoder_lens = True
    else:
        step_kernel = speculate_step_paddle if speculative else step_paddle
        takes_step_decoder_lens = False

    tensor_keys = ["stop_flags", "seq_lens_this_time", "step_seq_lens_encoder"]
    if takes_step_decoder_lens:
        tensor_keys.append("step_seq_lens_decoder")
    tensor_keys += [
        "seq_lens_encoder",
        "seq_lens_decoder",
        "block_tables",
        "encoder_block_lens",
        "is_block_step",
        "step_block_list",
        "step_lens",
        "recover_block_list",
        "recover_lens",
        "need_block_list",
        "need_block_len",
        "used_list_len",
        "free_list",
        "free_list_len",
        "input_ids",
        "pre_ids",
        "step_idx",
        "next_tokens",
        "first_token_ids",
    ]
    if speculative:
        tensor_keys.append("accept_num")

    kernel_args = [share_inputs[key] for key in tensor_keys]
    kernel_args += [block_size, enc_dec_block_num]
    if speculative:
        kernel_args.append(speculative_config.num_speculative_tokens)
    step_kernel(*kernel_args)
def rebuild_padding(
    tmp_out: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    seq_len_this_time: paddle.Tensor,
    seq_lens_decoder: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor,
    batch_id_per_token_output: Optional[paddle.Tensor] = None,
    cu_seqlens_q_output: Optional[paddle.Tensor] = None,
    first_token_out: Optional[paddle.Tensor] = None,
    enable_logprob: Optional[bool] = False,
):
    """
    Dispatch to the platform's ``rebuild_padding`` custom op and return its output.

    Every backend receives ``tmp_out, cu_seqlens_q, seq_len_this_time,
    seq_lens_decoder, seq_lens_encoder, batch_id_per_token_output``. In addition:
      * CUDA and MACA also receive ``cu_seqlens_q_output, first_token_out, enable_logprob``;
      * Iluvatar also receives ``first_token_out, enable_logprob``;
      * DCU, GCU and CPU (``rebuild_padding_cpu``) receive only the shared arguments.

    Returns:
        The op's result (the rebuilt hidden states).

    Raises:
        RuntimeError: if the current platform is none of the above.
    """
    shared_args = (
        tmp_out,
        cu_seqlens_q,
        seq_len_this_time,
        seq_lens_decoder,
        seq_lens_encoder,
        batch_id_per_token_output,
    )
    if current_platform.is_cuda():
        from fastdeploy.model_executor.ops.gpu import rebuild_padding as padding_op

        return padding_op(*shared_args, cu_seqlens_q_output, first_token_out, enable_logprob)
    if current_platform.is_dcu():
        from fastdeploy.model_executor.ops.gpu import rebuild_padding as padding_op

        return padding_op(*shared_args)
    if current_platform.is_iluvatar():
        from fastdeploy.model_executor.ops.iluvatar import rebuild_padding as padding_op

        return padding_op(*shared_args, first_token_out, enable_logprob)
    if current_platform.is_gcu():
        from fastdeploy.model_executor.ops.gcu import rebuild_padding as padding_op

        return padding_op(*shared_args)
    if current_platform.is_cpu():
        from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu as padding_op

        return padding_op(*shared_args)
    if current_platform.is_maca():
        from fastdeploy.model_executor.ops.gpu import rebuild_padding as padding_op

        return padding_op(*shared_args, cu_seqlens_q_output, first_token_out, enable_logprob)
    raise RuntimeError("Not supported platform")
def post_process_pooling(
    pooler_output: PoolerOutput,
    model_output: ModelOutputData,
    share_inputs: InputBatch,
    block_size: int = 64,
    save_each_rank: bool = False,
    skip_save_output: bool = False,
    async_output_queue: Optional[queue.Queue] = None,
    routing_replay_manager: Optional[RoutingReplayManager] = None,
) -> None:
    """Post-processing for pooling (embedding/reward-style) outputs.

    Stages, in order:
      * Advance ``step_idx`` for requests that are not stopped and stop those
        with ``step_idx >= max_dec_len`` (same as ``post_process_normal``).
      * Raise ``NotImplementedError`` if a routing-replay manager is given.
      * When ``envs.ENABLE_V1_KVCACHE_SCHEDULER`` is set: force every
        ``stop_flags`` entry to True and call ``update_inputs_v1`` with dummy
        sampled tokens (all ``-1``).
      * Unless ``skip_save_output``: on rank 0 (or every rank when
        ``save_each_rank``), put ``StreamTransferData`` built from
        ``pooler_output.outputs`` on ``async_output_queue``.

    Args:
        pooler_output: Pooling result; only ``outputs`` is read.
        model_output: Per-step model state updated via ``paddle.assign`` and
            ``update_inputs_v1``.
        share_inputs: Shared buffers; ``step_seq_lens_decoder``, ``prompt_lens``
            and ``block_tables`` are passed to ``update_inputs_v1``.
        block_size: Forwarded to ``update_inputs_v1``.
        save_each_rank: Publish from every rank instead of rank 0 only.
        skip_save_output: When True, nothing is put on the queue.
        async_output_queue: Destination queue for the stream data.
        routing_replay_manager: Not supported; must be None.

    Raises:
        NotImplementedError: if ``routing_replay_manager`` is not None.
    """
    # step_idx += 1 only for requests that are not already stopped.
    paddle.assign(
        paddle.where(
            model_output.stop_flags,
            model_output.step_idx,
            model_output.step_idx + 1,
        ),
        model_output.step_idx,
    )
    length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len)
    paddle.assign(
        paddle.logical_or(model_output.stop_flags, length_cond),
        model_output.stop_flags,
    )
    # Routing replay
    if routing_replay_manager is not None:
        raise NotImplementedError
    # NOTE(review): with ENABLE_V1_KVCACHE_SCHEDULER off, no input-buffer update
    # happens here at all (post_process_normal calls update_inputs) — confirm intent.
    with paddle.framework._no_check_dy2st_diff():
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            # Pooling emits no tokens: pass -1 as the sampled token and mark every
            # slot stopped (presumably each pooling request finishes in one step).
            dummy_sampled_tokens = paddle.full_like(model_output.next_tokens, -1, dtype="int64")
            paddle.assign(
                paddle.ones_like(model_output.stop_flags, dtype="bool"),
                model_output.stop_flags,
            )
            # NOTE(review): post_process_normal passes model_output.not_need_stop_device
            # as the second argument of update_inputs_v1; this path passes
            # not_need_stop — confirm which tensor the op expects.
            update_inputs_v1(
                model_output.stop_flags,
                model_output.not_need_stop,
                model_output.seq_lens_this_time,
                model_output.seq_lens_encoder,
                model_output.seq_lens_decoder,
                share_inputs["step_seq_lens_decoder"],
                share_inputs["prompt_lens"],
                dummy_sampled_tokens,
                model_output.input_ids,
                share_inputs["block_tables"],
                model_output.next_tokens,
                model_output.is_block_step,
                block_size,
            )
    if not skip_save_output:
        if save_each_rank or model_output.mp_rank == 0:
            output = _build_stream_transfer_data(output_tokens=None, pooler_outputs=pooler_output.outputs)
            async_output_queue.put(output)