""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import paddle from paddleformers.utils.log import logger from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig from fastdeploy.model_executor.layers.rotary_embedding import get_rope from fastdeploy.model_executor.logits_processor import build_logits_processors from fastdeploy.platforms import current_platform class InputBatch: def __getitem__(self, key): """Support dictionary-style attribute access""" if hasattr(self, key): return getattr(self, key) raise KeyError(f"'{key}' is not a valid attribute of InputBatch") def __setitem__(self, key, value): """Support dictionary-style attribute setting, overwrite if exists, add if not exists""" setattr(self, key, value) def __contains__(self, key): """Support 'in' operator to check if attribute exists""" return hasattr(self, key) def update(self, values: dict): """Batch update attributes, similar to dict's update method""" for key, value in values.items(): setattr(self, key, value) def get(self, key, default=None): if hasattr(self, key): return getattr(self, key) return default def pop(self, key, default=None): """ Pop an attribute, similar to dict's pop method Args: key: Name of the attribute to pop default: Default value to return if attribute does not exist Returns: Popped attribute value, or default if attribute doesn't exist and default is provided """ if hasattr(self, key): value = 
getattr(self, key) delattr(self, key) return value elif default is not None: return default else: raise KeyError(f"'{key}' is not a valid attribute of InputBatch") def __delitem__(self, key): """ Delete an attribute using del operator Args: key: Name of the attribute to delete Raises: KeyError: If attribute does not exist """ if hasattr(self, key): delattr(self, key) else: raise KeyError(f"'{key}' is not a valid attribute of InputBatch") def __init__(self, fd_config: FDConfig) -> None: """ Initialize all share buffers for model inputs. """ self.num_running_requests = 0 self.running_requests_ids = [] self.fd_config: FDConfig = fd_config self.model_config: ModelConfig = fd_config.model_config self.cache_config: CacheConfig = fd_config.cache_config self.scheduler_config = fd_config.scheduler_config self.speculative_config: SpeculativeConfig = fd_config.speculative_config self.speculative_decoding = self.speculative_config.method is not None self.enable_mm = self.model_config.enable_mm self.enable_expert_parallel = fd_config.parallel_config.enable_expert_parallel self.index_to_batch_id = {} self.enable_pd_reorder = False # Qwen vl etc. 
do not support mm_max_tokens_per_item now if self.enable_mm and self.model_config.mm_max_tokens_per_item is None: self.max_chunk_tokens = self.model_config.max_model_len else: self.max_chunk_tokens = self.fd_config.get_max_chunk_tokens(self.model_config.mm_max_tokens_per_item) def init_share_inputs(self): max_num_seqs = self.scheduler_config.max_num_seqs self.token_ids_all = paddle.full( [max_num_seqs, self.model_config.max_model_len], -1, dtype="int64", ) self.input_ids = paddle.full( [max_num_seqs, self.max_chunk_tokens], self.model_config.pad_token_id, dtype="int64", ) self.eos_token_id = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64") self.top_p = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") self.top_k = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.top_k_list = [0] * max_num_seqs self.min_p = paddle.full([max_num_seqs, 1], 0.0, dtype="float32") self.min_p_list = [0.0] * max_num_seqs self.temperature = paddle.full([max_num_seqs, 1], self.model_config.temperature, dtype="float32") self.penalty_score = paddle.full([max_num_seqs, 1], self.model_config.penalty_score, dtype="float32") self.frequency_score = paddle.full( [max_num_seqs, 1], self.model_config.frequency_score, dtype="float32", ) self.presence_score = paddle.full([max_num_seqs, 1], self.model_config.presence_score, dtype="float32") self.temp_scaled_logprobs = paddle.full([max_num_seqs, 1], False, dtype="bool") self.top_p_normalized_logprobs = paddle.full([max_num_seqs, 1], False, dtype="bool") self.min_dec_len = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") self.max_dec_len = paddle.full([max_num_seqs, 1], self.model_config.max_model_len, dtype="int64") self.seq_lens_this_time_buffer = paddle.full([max_num_seqs], 0, dtype="int32") self.seq_lens_this_time = paddle.full([max_num_seqs], 0, dtype="int32") self.seq_lens_encoder = paddle.full([max_num_seqs], 0, dtype="int32") self.seq_lens_decoder = 
paddle.full([max_num_seqs], 0, dtype="int32") self.step_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.step_seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.prompt_lens = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.step_idx = paddle.full([max_num_seqs, 1], 0, dtype="int64") if current_platform.is_maca(): self.not_need_stop = paddle.full([1], False, dtype="bool", device="cpu") self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64", device="cpu") self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32", device="cpu") self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool", device="cpu") else: self.not_need_stop = paddle.full([1], False, dtype="bool").cpu() self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory() self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory() self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory() self.not_need_stop_device = paddle.full([1], False, dtype="bool") self.stop_flags = paddle.full([max_num_seqs, 1], True, dtype="bool") self.bad_tokens = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64") self.bad_tokens_len = paddle.full([max_num_seqs], 1, dtype="int64") self.next_tokens = paddle.full([max_num_seqs, 1], -1, dtype="int64") self.is_block_step = paddle.full([max_num_seqs], False, dtype="bool") self.is_chunk_step = paddle.full([max_num_seqs], False, dtype="bool", device="cpu") self.encoder_block_lens = paddle.full([max_num_seqs], 0, dtype="int32") self.step_block_list = paddle.full([max_num_seqs], -1, dtype="int32") self.step_lens = paddle.full([1], 0, dtype="int32") self.recover_block_list = paddle.full([max_num_seqs], -1, dtype="int32") self.recover_lens = paddle.full([1], 0, dtype="int32") self.need_block_list = paddle.full([max_num_seqs], -1, dtype="int32") self.need_block_len = paddle.full([1], 
0, dtype="int32") self.used_list_len = paddle.full([max_num_seqs], 0, dtype="int32") self.infer_seed = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.first_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64") self.ori_seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.system_lens = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.system_ids = paddle.full([max_num_seqs, 1], -1, dtype="int32") self.generated_modality = paddle.full([max_num_seqs], -1, dtype="int32") self.ids_remove_padding = paddle.full( [max_num_seqs * self.max_chunk_tokens], 0, dtype="int64", ) self.batch_id_per_token = paddle.full([max_num_seqs * self.max_chunk_tokens, 1], 0, dtype="int32") self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32") self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32") # Declare AttentionBackend buffers self.decoder_batch_ids = None self.decoder_tile_ids_per_batch = None self.decoder_num_blocks_cpu = None # Pinning Memory self.decoder_num_blocks_device = None self.decoder_chunk_size_device = None self.max_len_tensor_cpu = None # CPU self.encoder_batch_ids = None self.encoder_tile_ids_per_batch = None self.encoder_num_blocks_x_cpu = None # CPU self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU # Initialize thinking related buffers self.enable_thinking = paddle.full(shape=[max_num_seqs, 1], fill_value=True, dtype="bool") self.max_think_lens = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32") self.max_reply_lens = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32") self.limit_think_status = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") self.inject_token_ids = paddle.to_tensor(self.model_config.think_truncate_prompt_ids, dtype="int64").reshape( [-1, 1] ) # NOTE(liuzichang): token after \n\n\n must be 100973 or 100975 # It is a hard code to cover up model's performance # Detailed notes can be found in 
FastDeploy/custom_ops/gpu_ops/reasoning_phase_token_constraint.cu self.reasoning_status = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") self.reasoning_allowed_tokens = paddle.to_tensor([100973, 100975], dtype="int64") # Initialize rotary position embedding if not self.enable_mm: self.rope_emb = get_rope( rotary_dim=self.model_config.head_dim, position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)), base=self.model_config.rope_theta, model_config=self.model_config, partial_rotary_factor=self.model_config.partial_rotary_factor, ) # Set block tables pre_max_block_num = ( self.model_config.max_model_len + self.cache_config.block_size - 1 ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num self.block_tables = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32") # Initialize free list free_list = list( range( self.cache_config.total_block_num - 1, int(self.cache_config.total_block_num * self.cache_config.kv_cache_ratio) - 1, -1, ) ) self.free_list_len = len(free_list) self.free_list = paddle.to_tensor(free_list, dtype="int32") self.free_list_len = paddle.full([1], self.free_list_len, dtype="int32") # Initialize stop seqs self.stop_seqs_len = paddle.full([max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32") self.stop_seqs = paddle.full( [ max_num_seqs, self.model_config.max_stop_seqs_num, self.model_config.stop_seqs_max_len, ], -1, dtype="int64", ) self.req_ids = [""] * max_num_seqs self.entropy_list = [[] for _ in range(max_num_seqs)] if self.speculative_decoding: max_draft_token_num = self.speculative_config.num_speculative_tokens self.input_ids_cpu = paddle.full( shape=[max_num_seqs, self.model_config.max_model_len], fill_value=-1, dtype="int64", device="cpu", ) self.accept_tokens = paddle.full( shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64", ) self.accept_num = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32") self.draft_tokens = 
paddle.full( shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64", ) self.actual_draft_token_num = paddle.full( shape=[max_num_seqs], fill_value=max_draft_token_num, dtype="int32", ) if current_platform.is_cuda(): self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32") self.batch_id_per_token_output = paddle.full( shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0, dtype="int32", ) else: self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") self.output_padding_offset = paddle.full( shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0, dtype="int32", ) # For V1_KVCACHE_SCHEDULER self.step_draft_tokens = paddle.full( shape=[max_num_seqs, max_draft_token_num + 1], fill_value=-1, dtype="int64", ) self.step_seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32") # For MTP Logprob self.draft_logits = paddle.full( [max_num_seqs * (self.speculative_config.num_speculative_tokens + 1), self.model_config.vocab_size], -1, dtype="float32", ) self.cu_batch_token_offset = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32") if self.enable_mm: head_dim = self.model_config.head_dim if ( "qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type ): # neox style = True rope_head_dim = head_dim else: # neox style = False rope_head_dim = head_dim // 2 self.rope_emb = paddle.full( shape=[ max_num_seqs, 2, 1, self.model_config.max_model_len, 1, rope_head_dim, ], fill_value=0, dtype="float32", ) self.image_features = None # Built before the forward self.image_features_list = None # For logits processors self.logits_processors = build_logits_processors(self.fd_config) self.logits_processors_args = [{} for _ in range(max_num_seqs)] logger.info(f"Enabled logits processors: {self.logits_processors}") self.mask_rollback = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") self.preempted_idx = 
paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32", device="cpu") self.last_preempted_idx = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32", device="cpu") def swap_states(self, i1, i2) -> None: """Swap the data at indices i1 and i2 for all array-like attributes""" def swap_data(tensor, idx1, idx2): """Safely swap tensor slices using clone""" temp = tensor[idx1].clone() tensor[idx1] = tensor[idx2].clone() tensor[idx2] = temp self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1] swap_data(self.token_ids_all, i1, i2) swap_data(self.input_ids, i1, i2) swap_data(self.top_p, i1, i2) swap_data(self.top_k, i1, i2) swap_data(self.min_p, i1, i2) swap_data(self.temperature, i1, i2) swap_data(self.penalty_score, i1, i2) swap_data(self.frequency_score, i1, i2) swap_data(self.presence_score, i1, i2) swap_data(self.temp_scaled_logprobs, i1, i2) swap_data(self.top_p_normalized_logprobs, i1, i2) swap_data(self.min_dec_len, i1, i2) swap_data(self.max_dec_len, i1, i2) swap_data(self.seq_lens_this_time_buffer, i1, i2) swap_data(self.seq_lens_this_time_cpu, i1, i2) swap_data(self.seq_lens_encoder, i1, i2) swap_data(self.seq_lens_decoder, i1, i2) swap_data(self.step_seq_lens_encoder, i1, i2) swap_data(self.step_seq_lens_decoder, i1, i2) swap_data(self.prompt_lens, i1, i2) swap_data(self.step_idx, i1, i2) swap_data(self.sampled_token_ids, i1, i2) swap_data(self.stop_flags, i1, i2) # swap_data(self.recompute_token_num, i1, i2) # # Swap list-based arrays (lists don't need clone) self.top_k_list[i1], self.top_k_list[i2] = self.top_k_list[i2], self.top_k_list[i1] self.min_p_list[i1], self.min_p_list[i2] = self.min_p_list[i2], self.min_p_list[i1] # Swap 1D arrays swap_data(self.bad_tokens, i1, i2) swap_data(self.bad_tokens_len, i1, i2) swap_data(self.next_tokens, i1, i2) swap_data(self.is_block_step, i1, i2) swap_data(self.is_block_step_cpu, i1, i2) swap_data(self.is_chunk_step, i1, i2) 
swap_data(self.encoder_block_lens, i1, i2) swap_data(self.step_block_list, i1, i2) swap_data(self.recover_block_list, i1, i2) swap_data(self.need_block_list, i1, i2) swap_data(self.used_list_len, i1, i2) swap_data(self.infer_seed, i1, i2) swap_data(self.first_token_ids, i1, i2) swap_data(self.ori_seq_lens_encoder, i1, i2) swap_data(self.system_lens, i1, i2) swap_data(self.system_ids, i1, i2) swap_data(self.enable_thinking, i1, i2) swap_data(self.max_think_lens, i1, i2) swap_data(self.limit_think_status, i1, i2) # # Swap block tables swap_data(self.block_tables, i1, i2) # # Swap stop sequences swap_data(self.stop_seqs_len, i1, i2) swap_data(self.stop_seqs, i1, i2) swap_data(self.preempted_idx, i1, i2) swap_data(self.last_preempted_idx, i1, i2) swap_data(self.reasoning_status, i1, i2) # Swap speculative decoding buffers if enabled if self.speculative_decoding: swap_data(self.input_ids_cpu, i1, i2) swap_data(self.accept_tokens, i1, i2) swap_data(self.accept_num, i1, i2) swap_data(self.draft_tokens, i1, i2) swap_data(self.actual_draft_token_num, i1, i2) if current_platform.is_cuda(): swap_data(self.cu_seqlens_q_output, i1, i2) else: swap_data(self.output_cum_offsets, i1, i2) swap_data(self.step_draft_tokens, i1, i2) swap_data(self.step_seq_lens_this_time, i1, i2) swap_data(self.draft_logits, i1, i2) swap_data(self.cu_batch_token_offset, i1, i2) if self.enable_mm: if self.image_features_list is not None: self.image_features_list[i1], self.image_features_list[i2] = ( self.image_features_list[i2], self.image_features_list[i1], ) swap_data(self.share_inputs["rope_emb"], i1, i2) # Swap mask rollback swap_data(self.mask_rollback, i1, i2) def condense(self) -> None: """ Condense the input batch by keeping only the running requests and moving their data to the front. Running requests are identified by self.running_requests_ids. Also updates index_to_batch_id to remove mappings for non-running requests. 
""" # Get the indices of running requests from index_to_batch_id running_indices = [ idx for idx, batch_id in self.index_to_batch_id.items() if batch_id in self.running_requests_ids ] # Sort the indices to maintain order running_indices.sort() if self.num_running_requests == len(self.index_to_batch_id): return # Move data of running requests to the front for new_idx, old_idx in enumerate(running_indices): if new_idx != old_idx: self.swap_states(new_idx, old_idx) # Update index_to_batch_id mapping - only keep mappings for running requests # After swap_states, the mapping has been updated, just remove non-running ones keys_to_remove = [ key for key in list(self.index_to_batch_id.keys()) if self.index_to_batch_id[key] not in self.running_requests_ids ] for key in keys_to_remove: del self.index_to_batch_id[key] def get_index_by_batch_id(self, batch_id): """ Get the index corresponding to the given batch_id Args: batch_id: The batch_id to look up Returns: The index corresponding to the batch_id, or add new key if not found """ for index, bid in self.index_to_batch_id.items(): if bid == batch_id: return index if batch_id in self.index_to_batch_id: # In PD reordering, some req_idx that are no longer used will be removed and # the remaining requests will be re-sorted by index. # # If req_idx = 2 was removed in the previous step and request 12 later occupied # slot 2 (i.e. {2: 12}), inserting a new request with req_id = 2 may overwrite # the existing request (req_idx = 12), leading to incorrect behavior. # # To avoid index collision, we always assign a new slot using the current length # as the new index, instead of reusing a previously freed req_idx. self.index_to_batch_id[len(self.index_to_batch_id)] = batch_id else: self.index_to_batch_id[batch_id] = batch_id return batch_id def reset_share_inputs(self): """ Reset all paddle tensors to their initial state. This method clears the content of the shared input buffers while preserving their shapes and data types. 
""" try: logger.info("Resetting share_inputs to initial state...") from fastdeploy.utils import fill_paddle_tensor # Reset all paddle tensors to their initial fill values max_num_seqs = self.scheduler_config.max_num_seqs # Reset basic tensors to their default values fill_paddle_tensor(self, "token_ids_all", -1) fill_paddle_tensor(self, "input_ids", self.model_config.pad_token_id) fill_paddle_tensor(self, "eos_token_id", 0) fill_paddle_tensor(self, "top_p", self.model_config.top_p) fill_paddle_tensor(self, "top_k", 0) fill_paddle_tensor(self, "min_p", 0.0) fill_paddle_tensor(self, "temperature", self.model_config.temperature) fill_paddle_tensor(self, "penalty_score", self.model_config.penalty_score) fill_paddle_tensor(self, "frequency_score", self.model_config.frequency_score) fill_paddle_tensor(self, "presence_score", self.model_config.presence_score) fill_paddle_tensor(self, "temp_scaled_logprobs", False) fill_paddle_tensor(self, "top_p_normalized_logprobs", False) # Reset list variables (not paddle tensors) self.top_k_list = [0] * max_num_seqs self.min_p_list = [0.0] * max_num_seqs fill_paddle_tensor(self, "min_dec_len", self.model_config.min_length) fill_paddle_tensor(self, "max_dec_len", self.model_config.max_model_len) # Reset sequence length related buffers fill_paddle_tensor(self, "seq_lens_this_time_buffer", 0) fill_paddle_tensor(self, "seq_lens_this_time", 0) fill_paddle_tensor(self, "seq_lens_encoder", 0) fill_paddle_tensor(self, "seq_lens_decoder", 0) fill_paddle_tensor(self, "step_seq_lens_encoder", 0) fill_paddle_tensor(self, "step_seq_lens_decoder", 0) fill_paddle_tensor(self, "prompt_lens", 0) fill_paddle_tensor(self, "step_idx", 0) # fill_paddle_tensor(self, "not_need_stop", False) fill_paddle_tensor(self, "not_need_stop_device", False) fill_paddle_tensor(self, "sampled_token_ids", -1) fill_paddle_tensor(self, "stop_flags", True) fill_paddle_tensor(self, "bad_tokens", -1) fill_paddle_tensor(self, "bad_tokens_len", 1) fill_paddle_tensor(self, 
"next_tokens", -1) fill_paddle_tensor(self, "is_block_step", False) fill_paddle_tensor(self, "is_chunk_step", False) fill_paddle_tensor(self, "encoder_block_lens", 0) fill_paddle_tensor(self, "step_block_list", -1) fill_paddle_tensor(self, "step_lens", 0) fill_paddle_tensor(self, "recover_block_list", -1) fill_paddle_tensor(self, "recover_lens", 0) fill_paddle_tensor(self, "need_block_list", -1) fill_paddle_tensor(self, "need_block_len", 0) fill_paddle_tensor(self, "used_list_len", 0) fill_paddle_tensor(self, "infer_seed", 0) fill_paddle_tensor(self, "first_token_ids", -1) fill_paddle_tensor(self, "ori_seq_lens_encoder", 0) fill_paddle_tensor(self, "system_lens", 0) fill_paddle_tensor(self, "system_ids", -1) fill_paddle_tensor(self, "ids_remove_padding", 0) fill_paddle_tensor(self, "batch_id_per_token", 0) fill_paddle_tensor(self, "cu_seqlens_q", 0) fill_paddle_tensor(self, "cu_seqlens_k", 0) # Reset thinking related buffers fill_paddle_tensor(self, "enable_thinking", True) fill_paddle_tensor(self, "max_think_lens", -1) fill_paddle_tensor(self, "limit_think_status", 0) # Reset reasoning buffers fill_paddle_tensor(self, "reasoning_status", 0) # Reset reasoning allowed tokens (not using fill_paddle_tensor since it's a fixed tensor) self.reasoning_allowed_tokens = paddle.to_tensor([100973, 100975], dtype="int64") # Reset block tables fill_paddle_tensor(self, "block_tables", -1) # Reset free list (requires special handling) free_list = list( range( self.cache_config.total_block_num - 1, int(self.cache_config.total_block_num * self.cache_config.kv_cache_ratio) - 1, -1, ) ) self.free_list = paddle.to_tensor(free_list, dtype="int32") self.free_list_len = paddle.full([1], len(free_list), dtype="int32") # Reset stop sequences fill_paddle_tensor(self, "stop_seqs_len", 0) fill_paddle_tensor(self, "stop_seqs", -1) # Reset other list variables self.req_ids = [""] * max_num_seqs self.entropy_list = [[] for _ in range(max_num_seqs)] self.logits_processors_args = [{} for _ in 
range(max_num_seqs)] # Reset speculative decoding tensors if enabled if self.speculative_decoding: max_draft_token_num = self.speculative_config.num_speculative_tokens fill_paddle_tensor(self, "input_ids_cpu", -1) fill_paddle_tensor(self, "accept_tokens", 0) fill_paddle_tensor(self, "accept_num", 0) fill_paddle_tensor(self, "draft_tokens", -1) fill_paddle_tensor(self, "actual_draft_token_num", max_draft_token_num) fill_paddle_tensor(self, "output_cum_offsets", 0) fill_paddle_tensor(self, "output_padding_offset", 0) fill_paddle_tensor(self, "step_draft_tokens", 0) fill_paddle_tensor(self, "step_seq_lens_this_time", 0) fill_paddle_tensor(self, "draft_logits", -1) fill_paddle_tensor(self, "cu_batch_token_offset", 0) # Reset multimodal related tensors if self.enable_mm: head_dim = self.model_config.head_dim if "qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type: rope_head_dim = head_dim else: rope_head_dim = head_dim // 2 self.rope_emb = paddle.full( shape=[ max_num_seqs, 2, 1, self.model_config.max_model_len, 1, rope_head_dim, ], fill_value=0, dtype="float32", ) self.image_features = None self.image_features_list = None else: # Reset non-multimodal rope_emb self.rope_emb = get_rope( rotary_dim=self.model_config.head_dim, position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)), base=self.model_config.rope_theta, model_config=self.model_config, partial_rotary_factor=self.model_config.partial_rotary_factor, ) # Reset other miscellaneous tensors fill_paddle_tensor(self, "mask_rollback", 0) fill_paddle_tensor(self, "preempted_idx", 0) logger.info("share_inputs reset completed") except Exception as e: logger.error(f"Resetting share inputs failed, skipping reset, error message is {e}") class ProposerInputBatch(InputBatch): def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) -> None: self.enable_mm = fd_config.model_config.enable_mm self.num_model_steps = 
fd_config.speculative_config.num_model_steps self.index_to_batch_id = {} self.target_model_input_batch = target_model_input_batch self.fd_config: FDConfig = fd_config self.scheduler_config = fd_config.scheduler_config self.model_config: ModelConfig = fd_config.model_config self.cache_config: CacheConfig = fd_config.cache_config self.speculative_config: SpeculativeConfig = fd_config.speculative_config self.enable_pd_reorder: bool = False def init_share_inputs(self): # share with targe model self.enable_pd_reorder = getattr(self.target_model_input_batch, "enable_pd_reorder", False) self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"]) self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"]) self.input_ids_cpu = paddle.full( shape=[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], fill_value=-1, dtype="int64", device="cpu", ) self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"]) self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"]) self.seq_lens_decoder = paddle.clone(self.target_model_input_batch["seq_lens_decoder"]) self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"]) self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"]) self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu") if current_platform.is_cuda(): self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"]) self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"]) if "token_ids_all" in self.target_model_input_batch: self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"]) # TODO: delete pre_ids in mtp self.pre_ids = paddle.full( [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int64", ) for bs_idx in range(self.scheduler_config.max_num_seqs): prompt_len = 
self.target_model_input_batch["prompt_lens"][bs_idx] pre_ids_len = self.model_config.max_model_len - prompt_len self.pre_ids[bs_idx, :pre_ids_len] = self.target_model_input_batch["token_ids_all"][ bs_idx, prompt_len: ] else: self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"]) self.token_ids_all = None else: self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"]) self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"]) self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"]) self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"]) self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"]) self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"]) self.cu_seqlens_k = paddle.clone(self.target_model_input_batch["cu_seqlens_k"]) self.target_hidden_states = paddle.full( [ self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens, self.model_config.hidden_size, ], 0, dtype="bfloat16", ) tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) self.rope_emb = get_rope( rotary_dim=self.model_config.head_dim, position_ids=tmp_position_ids, base=self.model_config.rope_theta, model_config=self.model_config, partial_rotary_factor=self.model_config.partial_rotary_factor, ) # self.caches = self.cache_kvs # Inherit generation hyperparameters from the main model for consistency self.prompt_lens = self.target_model_input_batch["prompt_lens"] self.fake_prompt_lens = paddle.full([self.scheduler_config.max_num_seqs, 1], 0, dtype="int64") self.top_p = self.target_model_input_batch["top_p"] self.top_k = self.target_model_input_batch["top_k"] self.temperature = self.target_model_input_batch["temperature"] self.eos_token_id = self.target_model_input_batch["eos_token_id"] self.penalty_score = 
self.target_model_input_batch["penalty_score"] self.frequency_score = self.target_model_input_batch["frequency_score"] self.presence_score = self.target_model_input_batch["presence_score"] self.infer_seed = self.target_model_input_batch["infer_seed"] self.max_dec_len = self.target_model_input_batch["max_dec_len"] self.min_dec_len = self.target_model_input_batch["min_dec_len"] self.bad_tokens = self.target_model_input_batch["bad_tokens"] self.bad_tokens_len = self.target_model_input_batch["bad_tokens_len"] # Integraad_tokens"]te the updated results in model forward self.base_model_draft_tokens = self.target_model_input_batch["draft_tokens"] self.substep = 0 # Declare AttentionBackend buffers self.decoder_batch_ids = None self.decoder_tile_ids_per_batch = None self.decoder_num_blocks_cpu = None # Pinning Memory self.decoder_num_blocks_device = None self.decoder_chunk_size_device = None self.max_len_tensor_cpu = None # CPU self.encoder_batch_ids = None self.encoder_tile_ids_per_batch = None self.encoder_num_blocks_x_cpu = None # CPU self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU # Input tokens self.draft_tokens = paddle.full( shape=[self.scheduler_config.max_num_seqs, self.speculative_config.num_speculative_tokens + 1], fill_value=-1, dtype="int64", ) self.encoder_block_lens = paddle.clone(self.target_model_input_batch["encoder_block_lens"]) self.free_list = list( range( self.cache_config.total_block_num - 1, int(self.cache_config.total_block_num * self.cache_config.kv_cache_ratio) - 1, -1, ) ) self.free_list_len = len(self.free_list) self.free_list = paddle.to_tensor(self.free_list, dtype="int32") self.free_list_len = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32") self.is_block_step = paddle.full(shape=[self.scheduler_config.max_num_seqs, 1], fill_value=False, dtype="bool") self.batch_drop = paddle.full(shape=[self.scheduler_config.max_num_seqs, 1], fill_value=False, dtype="bool") 
self.used_list_len = paddle.full(shape=[self.scheduler_config.max_num_seqs], fill_value=0, dtype="int32") if self.num_model_steps > 1: self.last_seq_lens_this_time = paddle.full_like( self.target_model_input_batch["seq_lens_this_time"], fill_value=-1, dtype="int32" ) self.input_ids_len = paddle.zeros(shape=[self.scheduler_config.max_num_seqs, 1], dtype="int64", device="cpu") self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"] self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"] self.accept_num = self.target_model_input_batch["accept_num"] self.accept_tokens = self.target_model_input_batch["accept_tokens"] self.draft_logits = self.target_model_input_batch["draft_logits"] self.first_token_hidden_states = paddle.full( [self.scheduler_config.max_num_seqs, self.model_config.hidden_size], -1 ) self.batch_token_num = paddle.full(shape=[self.scheduler_config.max_num_seqs], fill_value=0, dtype="int32") self.next_token_num = paddle.full(shape=[self.scheduler_config.max_num_seqs], fill_value=0, dtype="int32") self.cu_batch_token_offset = paddle.full_like( self.target_model_input_batch["cu_batch_token_offset"], fill_value=0, dtype="int32" ) self.cu_next_token_offset = paddle.full( shape=[self.scheduler_config.max_num_seqs + 1], fill_value=0, dtype="int32" ) self.mask_rollback = paddle.full([self.scheduler_config.max_num_seqs, 1], 0, dtype="int32") # NOTE(liuzichang): In speculative decoding, accepted tokens' KV cache is recomputed # using the target model's hidden states. 
        # Tokens whose KV cache must be recomputed per request (num_model_steps - 1).
        self.recompute_token_num = paddle.full(
            [self.scheduler_config.max_num_seqs, 1], self.num_model_steps - 1, dtype="int32"
        )
        # attn_mask offsets — multimodal models only (swap_states guards on the same flag).
        if self.enable_mm:
            self.attn_mask_offsets = paddle.full(
                shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
                fill_value=-1,
                dtype="int32",
            )
            self.attn_mask_offsets_full = paddle.full(
                [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
            )
            self.attn_mask_offsets_decoder = paddle.full([self.scheduler_config.max_num_seqs, 1], -1, dtype="int32")
        self.decode_states = paddle.full(
            [self.scheduler_config.max_num_seqs, self.speculative_config.num_speculative_tokens + 1],
            -1,
            dtype="int32",
        )

    def swap_states(self, i1: int, i2: int) -> None:
        """Swap the per-request state of batch slots *i1* and *i2* in place.

        Swaps the index_to_batch_id mapping entries plus every per-request
        tensor row listed below; multimodal attn-mask buffers are swapped only
        when enable_mm is set.
        """

        def swap_data(tensor, idx1, idx2):
            """Safely swap tensor slices using clone (avoids aliasing the rows)."""
            temp = tensor[idx1].clone()
            tensor[idx1] = tensor[idx2].clone()
            tensor[idx2] = temp

        self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1]
        swap_data(self.block_tables, i1, i2)
        swap_data(self.input_ids, i1, i2)
        swap_data(self.input_ids_cpu, i1, i2)
        swap_data(self.seq_lens_this_time_buffer, i1, i2)
        swap_data(self.seq_lens_encoder, i1, i2)
        swap_data(self.seq_lens_decoder, i1, i2)
        swap_data(self.step_idx, i1, i2)
        swap_data(self.pre_ids, i1, i2)
        swap_data(self.encoder_block_lens, i1, i2)
        swap_data(self.input_ids_len, i1, i2)
        swap_data(self.mask_rollback, i1, i2)
        swap_data(self.recompute_token_num, i1, i2)
        if self.enable_mm:
            swap_data(self.attn_mask_offsets_full, i1, i2)
            swap_data(self.attn_mask_offsets_decoder, i1, i2)

    def reset_model_inputs(self) -> None:
        """
        Reset all paddle tensors in self to their initial state.

        This method clears the content of the model input buffers while
        preserving their shapes and data types.  Any failure is logged and
        swallowed (best-effort reset).
        """
        try:
            logger.info("Resetting model_inputs to initial state...")
            from fastdeploy.utils import fill_paddle_tensor

            # Reset all paddle tensors to their default values.
            # Clone the target model inputs to restore initial values.
            self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
            self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
            fill_paddle_tensor(self, "input_ids_cpu", -1)
            # acceptance rate decline when reset seq_lens_this_time
            # self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"])
            self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"])
            self.seq_lens_decoder = paddle.clone(self.target_model_input_batch["seq_lens_decoder"])
            self.prompt_lens = self.target_model_input_batch["prompt_lens"]
            self.fake_prompt_lens = paddle.full([self.scheduler_config.max_num_seqs, 1], 0, dtype="int64")
            self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"])
            self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
            self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
            if current_platform.is_cuda():
                if "token_ids_all" in self.target_model_input_batch:
                    self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
                    # TODO: delete pre_ids in mtp
                    # Rebuild pre_ids per slot from the tail of token_ids_all
                    # (tokens after the prompt).
                    self.pre_ids = paddle.full(
                        [self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
                        -1,
                        dtype="int64",
                    )
                    for bs_idx in range(self.scheduler_config.max_num_seqs):
                        prompt_len = self.target_model_input_batch["prompt_lens"][bs_idx]
                        pre_ids_len = self.model_config.max_model_len - prompt_len
                        self.pre_ids[bs_idx, :pre_ids_len] = self.target_model_input_batch["token_ids_all"][
                            bs_idx, prompt_len:
                        ]
                else:
                    self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
                    self.token_ids_all = None
            else:
                self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
            self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"])
            self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"])
            self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
            self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
            self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"])
            self.cu_seqlens_k = paddle.clone(self.target_model_input_batch["cu_seqlens_k"])

            # Reset target hidden states
            fill_paddle_tensor(self, "target_hidden_states", 0)

            # Reset rope embedding by recreating with default position_ids
            tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
            self.rope_emb = get_rope(
                rotary_dim=self.model_config.head_dim,
                position_ids=tmp_position_ids,
                base=self.model_config.rope_theta,
                model_config=self.model_config,
                partial_rotary_factor=self.model_config.partial_rotary_factor,
            )

            # Reset generation hyperparameters from the main model (aliased, not cloned)
            self.top_p = self.target_model_input_batch["top_p"]
            self.top_k = self.target_model_input_batch["top_k"]
            self.temperature = self.target_model_input_batch["temperature"]
            self.eos_token_id = self.target_model_input_batch["eos_token_id"]
            self.penalty_score = self.target_model_input_batch["penalty_score"]
            self.frequency_score = self.target_model_input_batch["frequency_score"]
            self.presence_score = self.target_model_input_batch["presence_score"]
            self.infer_seed = self.target_model_input_batch["infer_seed"]
            self.max_dec_len = self.target_model_input_batch["max_dec_len"]
            self.min_dec_len = self.target_model_input_batch["min_dec_len"]
            self.bad_tokens = self.target_model_input_batch["bad_tokens"]
            self.bad_tokens_len = self.target_model_input_batch["bad_tokens_len"]

            # Reset speculative decoding specific tensors
            self.base_model_draft_tokens = self.target_model_input_batch["draft_tokens"]
            self.substep = 0

            # Reset draft tokens
            fill_paddle_tensor(self, "draft_tokens", -1)

            # Reset encoder block lens
            self.encoder_block_lens = paddle.clone(self.target_model_input_batch["encoder_block_lens"])

            # Reset free list (recreate with current cache config)
            free_list = list(
                range(
                    self.cache_config.total_block_num - 1,
                    int(self.cache_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                    -1,
                )
            )
            self.free_list = paddle.to_tensor(free_list, dtype="int32")
            self.free_list_len = paddle.full(shape=[1], fill_value=len(free_list), dtype="int32")

            # Reset step and drop flags
            fill_paddle_tensor(self, "is_block_step", False)
            fill_paddle_tensor(self, "batch_drop", False)
            fill_paddle_tensor(self, "used_list_len", 0)

            # Reset last sequence lengths if applicable
            if self.num_model_steps > 1:
                fill_paddle_tensor(self, "last_seq_lens_this_time", -1)

            # Reset input IDs length
            fill_paddle_tensor(self, "input_ids_len", 0)

            # Reset various scores and flags
            self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"]
            self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"]
            self.accept_num = self.target_model_input_batch["accept_num"]
            self.accept_tokens = self.target_model_input_batch["accept_tokens"]
            self.draft_logits = self.target_model_input_batch["draft_logits"]
            fill_paddle_tensor(self, "first_token_hidden_states", -1)
            fill_paddle_tensor(self, "batch_token_num", 0)
            fill_paddle_tensor(self, "next_token_num", 0)
            fill_paddle_tensor(self, "cu_batch_token_offset", 0)
            fill_paddle_tensor(self, "cu_next_token_offset", 0)
            fill_paddle_tensor(self, "mask_rollback", 0)
            fill_paddle_tensor(self, "recompute_token_num", self.num_model_steps - 1)

            # Reset multimodal tensors if enabled
            if self.enable_mm:
                fill_paddle_tensor(self, "attn_mask_offsets", -1)
                fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
                fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)
            fill_paddle_tensor(self, "decode_states", -1)

            logger.info("model_inputs reset completed")
        except Exception as e:
            # Best-effort: never let a reset failure crash the worker.
            logger.error(f"Resetting model inputs failed, skipping reset, error message is {e}")


def reorder_split_prefill_and_decode_form_index_to_batch_id(input_batch: InputBatch, target_model_input_batch: dict) -> None:
    """Align input_batch slot order with the target model's index_to_batch_id.

    For each target-model batch id, swap input_batch slots until its index
    matches the target model's index; then drop mapping entries for batch ids
    that no longer exist on the target side.  Mutates input_batch in place.
    """
    # Reverse mapping: slot index -> batch id on the draft (MTP) side.
    mtp_index_2_mtp_id = {v: k for k, v in input_batch.index_to_batch_id.items()}
    for target_model_id in target_model_input_batch:
        target_model_index = target_model_input_batch[target_model_id]
        if input_batch.index_to_batch_id[target_model_id] == target_model_index:
            continue
        mtp_id = mtp_index_2_mtp_id[target_model_index]
        v1 = input_batch.index_to_batch_id[target_model_id]
        v2 = input_batch.index_to_batch_id[mtp_id]
        input_batch.swap_states(target_model_id, mtp_id)
        # update mapping
        mtp_index_2_mtp_id[v1] = mtp_id
        mtp_index_2_mtp_id[v2] = target_model_id
    # Remove batch ids that are present here but gone from the target model.
    keys_to_remove = input_batch.index_to_batch_id.keys() - target_model_input_batch.keys()
    for key in keys_to_remove:
        del input_batch.index_to_batch_id[key]
        # break immediately after the delete so the dict is not iterated
        # further after its size changed.
        for k, v in mtp_index_2_mtp_id.items():
            if v == key:
                del mtp_index_2_mtp_id[k]
                break


def reorder_split_prefill_and_decode(input_batch: InputBatch) -> None:
    """
    Reorder input_batch data to place decode requests first and prefill requests last.

    Args:
        input_batch: Input batch data

    Returns:
        None: Modifies the input_batch data order in place
    """
    # 1. Identify decode (prefill_len=0) vs prefill (prefill_len>0) requests
    decode_mask = input_batch.seq_lens_encoder == 0
    # Get batch size
    batch_size = input_batch.num_running_requests
    # 2. Use two-pointer algorithm to swap prefill to the back and decode to the front
    left = 0  # Pointer for decode section start
    right = batch_size - 1  # Pointer for prefill section start
    while left <= right:
        if decode_mask[left]:
            # Left position is decode request, no swap needed, move right
            left += 1
        elif not decode_mask[right]:
            # Right position is prefill request, no swap needed, move left
            right -= 1
        else:
            # Swap: left position is prefill, right position is decode, need to swap
            input_batch.swap_states(left, right)
            left += 1
            right -= 1


def _recover_tensor(recover_tensor, index_to_batch_id_list):
    """
    Reorder recover_tensor according to index_to_batch_id_list mapping.

    Args:
        recover_tensor: paddle.Tensor to be reordered.
        index_to_batch_id_list: List mapping current indices to original batch IDs.

    Returns:
        A paddle.Tensor with elements restored to the original batch order.
        Rows beyond len(index_to_batch_id_list) are copied through unchanged.
    """
    sort_len = len(index_to_batch_id_list)
    # Pinned-memory tensors cannot be fancy-index-assigned in place on the
    # pinned device, so allocate the result on plain CPU in that case.
    if isinstance(recover_tensor.place, paddle.CUDAPinnedPlace):
        recover_res_tensor = paddle.empty_like(recover_tensor, device="cpu")
    else:
        recover_res_tensor = paddle.empty_like(recover_tensor)
    recover_res_tensor[:sort_len] = recover_tensor[:sort_len][index_to_batch_id_list]
    if sort_len < recover_res_tensor.shape[0]:
        recover_res_tensor[sort_len:] = recover_tensor[sort_len:]
    return recover_res_tensor


def recover_batch_index_for_output(output_cls, index_to_batch_id, enable_pd_reorder, recover_list):
    """
    Reorder the named fields of an output object according to index_to_batch_id.

    Args:
        output_cls: Output object (or dict) whose fields are to be reordered.
        index_to_batch_id: Dict mapping indices to original batch IDs.
        enable_pd_reorder: Whether prefill/decode reordering is enabled at all.
        recover_list: Names of the fields to recover.

    Returns:
        Dict mapping each name in recover_list to its (possibly reordered) value.
    """
    res_map = {}
    # No work needed when reordering is disabled or the mapping is the identity.
    is_not_swapped = all(i == v for i, v in index_to_batch_id.items()) or not enable_pd_reorder
    # Create a new tensor to store the reordered results
    if not is_not_swapped:
        # Slot order sorted by original batch id.
        src_order = [k for k, v in sorted(index_to_batch_id.items(), key=lambda x: x[1])]
    for recover_name in recover_list:
        if isinstance(output_cls, dict):
            recover_tensor = output_cls[recover_name]
        else:
            recover_tensor = getattr(output_cls, recover_name)
        if is_not_swapped:
            res_map[recover_name] = recover_tensor
            continue
        if isinstance(recover_tensor, paddle.Tensor):
            # Create a new tensor to store the reordered results
            res_map[recover_name] = _recover_tensor(recover_tensor, src_order)
        elif isinstance(recover_tensor, list):
            real_recover_tensor = recover_tensor.copy()
            # NOTE(review): enumerate over a dict yields (position, key) pairs,
            # and sequential pairwise swaps do not generally realize the same
            # permutation as the tensor branch above — verify this branch
            # against src_order-based indexing.
            for i1, i2 in enumerate(index_to_batch_id):
                real_recover_tensor[i1], real_recover_tensor[i2] = real_recover_tensor[i2], real_recover_tensor[i1]
            res_map[recover_name] = real_recover_tensor
        else:
            logger.info("Unsupported type of {}".format(recover_name))
    return res_map


def recover_batch_index_for_sampler_output(sampler_output, index_to_batch_id, enable_pd_reorder) -> None:
    """
    Reorder sampled_token_ids (and related fields) according to index_to_batch_id.

    Args:
        sampler_output: Sampler output object containing sampled_token_ids and
            other attributes; mutated in place.
        index_to_batch_id: Dict mapping indices to original batch IDs.
        enable_pd_reorder: Whether prefill/decode reordering is enabled.

    Returns:
        None: sampler_output is updated in place.
    """
    # Nothing to do when reordering is off or the mapping is already identity.
    if not enable_pd_reorder or all(i == v for i, v in index_to_batch_id.items()):
        return
    sampled_token_ids = sampler_output.sampled_token_ids
    # Create a new tensor to store the reordered results
    src_order = [k for k, v in sorted(index_to_batch_id.items(), key=lambda x: x[1])]
    real_token_ids = _recover_tensor(sampled_token_ids, src_order)
    sampler_output.sampled_token_ids = real_token_ids
    if sampler_output.logprobs_tensors is not None:
        logprob_token_ids = sampler_output.logprobs_tensors.logprob_token_ids
        logprobs = sampler_output.logprobs_tensors.logprobs
        selected_token_ranks = sampler_output.logprobs_tensors.selected_token_ranks
        real_logprob_token_ids = _recover_tensor(logprob_token_ids, src_order)
        real_logprobs = _recover_tensor(logprobs, src_order)
        real_selected_token_ranks = _recover_tensor(selected_token_ranks, src_order)
        sampler_output.logprobs_tensors.logprob_token_ids = real_logprob_token_ids
        sampler_output.logprobs_tensors.logprobs = real_logprobs
        # NOTE(review): the value is read from `.selected_token_ranks` but
        # written back to `.sampled_token_ranks` — if these are different
        # attributes, the original ranks are left un-reordered. Confirm the
        # intended attribute name on LogprobsTensors.
        sampler_output.logprobs_tensors.sampled_token_ranks = real_selected_token_ranks
    if sampler_output.token_num_per_batch is not None:
        token_num_per_batch = sampler_output.token_num_per_batch
        real_token_num_per_batch = _recover_tensor(token_num_per_batch, src_order)
        sampler_output.token_num_per_batch = real_token_num_per_batch
    if sampler_output.cu_batch_token_offset is not None:
        cu_batch_token_offset = sampler_output.cu_batch_token_offset
        real_cu_batch_token_offset = _recover_tensor(cu_batch_token_offset, src_order)
        sampler_output.cu_batch_token_offset = real_cu_batch_token_offset
    if sampler_output.logits is not None:
        logits = sampler_output.logits
        real_logits = _recover_tensor(logits, src_order)
        sampler_output.logits = real_logits