""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import os import time from typing import TYPE_CHECKING, List import numpy as np import paddle from paddleformers.utils.log import logger from fastdeploy import envs from fastdeploy.engine.request import Request, RequestType from fastdeploy.inter_communicator import IPCSignal from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import MTPSampler from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.models import ModelForCasualLM from fastdeploy.platforms import current_platform if current_platform.is_xpu(): from fastdeploy.model_executor.ops.xpu import ( draft_model_postprocess, draft_model_preprocess, draft_model_update, eagle_get_hidden_states, eagle_get_self_hidden_states, mtp_save_first_token, mtp_step_paddle, set_data_ipc, share_external_data, update_attn_mask_offsets, ) # temporary solution from fastdeploy.model_executor.xpu_pre_and_post_process import ( async_set_value, xpu_pre_process, xpu_process_output, ) else: from fastdeploy.model_executor.ops.gpu import ( draft_model_postprocess, draft_model_preprocess, draft_model_update, eagle_get_hidden_states, eagle_get_self_hidden_states, eagle_gather_hidden_states, hybrid_mtp_ngram, mtp_step_paddle, share_external_data, speculate_get_logits, speculate_save_output_topk, update_attn_mask_offsets, set_data_ipc, unset_data_ipc, ) from fastdeploy.model_executor.pre_and_post_process import async_set_value, pre_process from fastdeploy.worker.input_batch import ( ProposerInputBatch, recover_batch_index_for_output, recover_batch_index_for_sampler_output, reorder_split_prefill_and_decode_form_index_to_batch_id, ) from .base import Proposer if TYPE_CHECKING: from fastdeploy.config import FDConfig class MTPProposer(Proposer): """ Proposer for Multi-Token-Prediction(MTP) """ def __init__( self, fd_config: "FDConfig", main_model: ModelForCasualLM, local_rank: int, device_id: int, # physical device id target_model_inputs, # main model share inputs ): super().__init__(fd_config) self.num_main_model_layers = self.model_config.num_hidden_layers self.local_rank = local_rank self.device_id = device_id self.use_attn_mask_offset = self.enable_mm self._update_mtp_config(main_model) self._load_model() self.target_model_inputs = target_model_inputs self.mtp_strategy = self.speculative_config.mtp_strategy self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps self.enable_logprob = self.model_config.enable_logprob self.enable_draft_logprob = self.speculative_config.enable_draft_logprob self.cache_kvs_map = {} # [mixed, prefill, decoder] self.role = self.scheduler_config.splitwise_role self.pd_disaggregation_mode = fd_config.parallel_config.pd_disaggregation_mode if current_platform.is_xpu(): self._prepare_inputs = self._prepare_inputs_xpu self._propose = self._propose_xpu elif current_platform.is_cuda() or current_platform.is_maca(): self._prepare_inputs = self._prepare_inputs_cuda self._propose = self._propose_cuda else: raise RuntimeError( f"Unsupported platform for MTP: {current_platform}. " f"Supported platforms: CUDA, MACA, XPU" ) self.sampler = MTPSampler(fd_config) self.model_inputs = ProposerInputBatch(self.fd_config, self.target_model_inputs) self.model_inputs.init_share_inputs() # CUDA Graph self.draft_model_use_cudagraph = self.graph_opt_config.draft_model_use_cudagraph self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes self.attn_backends: list[AttentionBackend] = [] self._initialize_attn_backend() # Forward meta store the global meta information of the forward self.forward_meta = None self.exist_prefill_flag = False def _update_mtp_config(self, main_model): """ Update config for MTP from global config """ self.forward_meta: ForwardMeta = None self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP") self.speculative_config.sharing_model = main_model # TODO (wangyanpeng): The number of MTP layers should be read from model config self.model_config.num_hidden_layers = 1 self.model_config.model = self.speculative_config.model if "Ernie" in self.model_config.architectures[0]: self.model_config.pretrained_config.prefix_name = "ernie.mtp_block" self.model_config.prefix_layer_name = "mtp_block" if self.speculative_config.quantization != "": self.model_config.quantization = self.speculative_config.quantization self.model_config.start_layer_index = self.num_main_model_layers self.speculative_config.model_type = "mtp" if not self.use_attn_mask_offset: self.model_config.causal = True def _load_model(self): """ Load MTP Layer """ model_loader = get_model_loader(load_config=self.fd_config.load_config) self.model = model_loader.load_model(fd_config=self.fd_config) def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): """Set dummy prefill inputs to model_inputs""" max_dec_len = expected_decode_len + 1 input_length = min( num_tokens // batch_size, self.model_config.max_model_len - max_dec_len, ) # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP. if self.fd_config.parallel_config.enable_expert_parallel: input_length = min(input_length, 32) block_num = ( input_length + self.cache_config.block_size - 1 ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num for i in range(batch_size): idx = i self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1) self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = input_length self.model_inputs["seq_lens_encoder"][idx : idx + 1] = input_length self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0 self.model_inputs["step_idx"][idx : idx + 1] = 0 self.model_inputs["max_dec_len"][idx : idx + 1] = max_dec_len self.model_inputs["stop_flags"][idx : idx + 1] = False self.model_inputs["encoder_block_lens"][idx : idx + 1] = block_num self.model_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange( idx * block_num, (idx + 1) * block_num, 1 ) self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"] def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): """ Initialize kv cache """ self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio) self.cache_kvs = {} # Get kv cache dtype cache_type = self.model_config.dtype kv_cache_quant_type = None if ( self.quant_config and hasattr(self.quant_config, "kv_cache_quant_type") and self.quant_config.kv_cache_quant_type is not None ): cache_type = self._get_cache_type() kv_cache_quant_type = self.quant_config.kv_cache_quant_type # Get kv cache shape key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type ) if kv_cache_quant_type == "block_wise_fp8": kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] local_rank = self.local_rank % self.parallel_config.tensor_parallel_size cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32) cache_ready_signal = IPCSignal( name="cache_ready_signal", array=cache_ready_signal_data, dtype=np.int32, suffix=self.parallel_config.local_engine_worker_queue_port, create=False, ) # Check if gpu runner needs to create kv cache # 1. During profiling, it creates its own kv cache. # 2. If no need to profile, create kv cache if cache managers do not exist. create_cache_tensor = profile or not ( self.fd_config.cache_config.num_cpu_blocks > 0 or self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) if not create_cache_tensor: logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}") while cache_ready_signal.value[local_rank] != 1: time.sleep(1) logger.info(f"OK! Stop waiting. {cache_ready_signal.value}") logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}") if not create_cache_tensor: cache_kvs_list = [] for i in range( self.num_main_model_layers, self.num_main_model_layers + self.model_config.num_hidden_layers, ): logger.info( f"..attaching kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}" ) key_cache = paddle.empty(shape=[], dtype=cache_type) key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" key_cache = self._share_external_data(key_cache, key_cache_name, key_cache_shape) self.cache_kvs_map[key_cache_name] = key_cache cache_kvs_list.append(key_cache) value_cache = paddle.empty(shape=[], dtype=cache_type) value_cache = self._share_external_data(value_cache, val_cache_name, value_cache_shape) self.cache_kvs_map[val_cache_name] = value_cache cache_kvs_list.append(value_cache) if kv_cache_quant_type == "block_wise_fp8": scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}" scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}" key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) key_scale_cache = self._share_external_data( key_scale_cache, scale_key_cache_name, kv_cache_scale_shape ) self.cache_kvs_map[scale_key_cache_name] = key_scale_cache cache_kvs_list.append(key_scale_cache) value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) value_scale_cache = self._share_external_data( value_scale_cache, scale_val_cache_name, kv_cache_scale_shape ) self.cache_kvs_map[scale_val_cache_name] = value_scale_cache cache_kvs_list.append(value_scale_cache) self.model_inputs["caches"] = cache_kvs_list else: cache_kvs_list = [] for i in range( self.num_main_model_layers, self.num_main_model_layers + self.model_config.num_hidden_layers, ): logger.info(f"..creating kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}") key_cache = paddle.full( shape=key_cache_shape, fill_value=0, dtype=cache_type, ) key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" set_data_ipc(key_cache, key_cache_name) self.cache_kvs_map[key_cache_name] = key_cache cache_kvs_list.append(key_cache) val_cache = paddle.full( shape=value_cache_shape, fill_value=0, dtype=cache_type, ) val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" set_data_ipc(val_cache, val_cache_name) self.cache_kvs_map[val_cache_name] = val_cache cache_kvs_list.append(val_cache) if kv_cache_quant_type == "block_wise_fp8": key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype(), ) key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}" set_data_ipc(key_cache_scales, key_cache_scales_name) self.cache_kvs_map[key_cache_scales_name] = key_cache_scales cache_kvs_list.append(key_cache_scales) val_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype(), ) val_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}" set_data_ipc(val_cache_scales, val_cache_scales_name) self.cache_kvs_map[val_cache_scales_name] = val_cache_scales cache_kvs_list.append(val_cache_scales) self.model_inputs["caches"] = cache_kvs_list self._empty_cache() def _initialize_attn_backend( self, ) -> None: """ Initialize attention backends and forward metadata """ assert len(self.attn_backends) == 0 num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size self.model_config.kv_num_heads = max( 1, int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size, ) head_dim = self.model_config.head_dim # Initialize AttentionBackend buffers encoder_block_shape_q = 64 decoder_block_shape_q = 16 self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"]) self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like( self.target_model_inputs["decoder_tile_ids_per_batch"] ) if current_platform.is_xpu() or current_platform.is_maca(): self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( self.target_model_inputs["decoder_num_blocks_cpu"] ).cpu() else: self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( self.target_model_inputs["decoder_num_blocks_cpu"] ).pin_memory() self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like( self.target_model_inputs["decoder_num_blocks_device"] ) self.model_inputs["decoder_chunk_size_device"] = paddle.zeros_like( self.target_model_inputs["decoder_chunk_size_device"] ) self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like( self.target_model_inputs["max_len_tensor_cpu"] ).cpu() self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"]) self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like( self.target_model_inputs["encoder_tile_ids_per_batch"] ) self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like( self.target_model_inputs["encoder_num_blocks_x_cpu"] ).cpu() self.model_inputs["kv_batch_ids"] = paddle.zeros_like(self.target_model_inputs["kv_batch_ids"]) self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like( self.target_model_inputs["kv_tile_ids_per_batch"] ) self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like( self.target_model_inputs["kv_num_blocks_x_cpu"] ).cpu() # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( self.fd_config, kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim, encoder_block_shape_q=encoder_block_shape_q, decoder_block_shape_q=decoder_block_shape_q, ) if attn_backend is None: raise NotImplementedError( "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly." ) self.attn_backends.append(attn_backend) def clear_mtp_cache(self, profile=False): """ Clear allocated cacheKV """ create_cache_tensor = profile or not ( self.fd_config.cache_config.num_cpu_blocks > 0 or self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) if not create_cache_tensor: for name, tensor in self.cache_kvs_map.items(): unset_data_ipc(tensor, name, True, False) self.cache_kvs_map.clear() del self.model_inputs["caches"] if self.forward_meta is not None: del self.forward_meta.caches def update_mtp_block_num(self, num_gpu_blocks, skip_cache_init: bool = False) -> None: """ Update MTP block num by theoretical calculation Args: num_gpu_blocks: Main model GPU block count. skip_cache_init: When True, skip internal initialize_kv_cache call. Set this when the caller (e.g. gpu_model_runner with enable_cache_manager_v1) has already re-created MTP cache via cache_controller. """ # Reset block table and kv cache with global block num self.main_model_num_gpu_blocks = num_gpu_blocks if not skip_cache_init: self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks) # Reset free list free_list = list( range( self.num_gpu_blocks - 1, int(self.main_model_num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1, -1, ) ) self.free_list_len = len(free_list) self.model_inputs.update( { "free_list": paddle.to_tensor(free_list, dtype="int32"), "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"), } ) def insert_tasks_v1( self, req_dicts: List[Request], num_running_requests: int, target_model_index_to_batch_id: dict = {} ): if "caches" not in self.model_inputs: self.initialize_kv_cache() req_len = len(req_dicts) self.model_inputs["num_running_requests"] = num_running_requests self.model_inputs["running_requests_ids"] = range(num_running_requests) if target_model_index_to_batch_id: self.model_inputs.index_to_batch_id = dict(target_model_index_to_batch_id) for i in range(req_len): request = req_dicts[i] logger.debug(f"{i}th request-{request.request_id}: {request}") idx = self.model_inputs.get_index_by_batch_id(request.idx) if request.task_type.value == RequestType.PREFILL.value: # prefill task prefill_start_index = request.prefill_start_index prefill_end_index = request.prefill_end_index length = prefill_end_index - prefill_start_index input_ids = request.prompt_token_ids + request.output_token_ids self.model_inputs["input_ids_len"][idx] = length - 1 async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1) self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][ idx : idx + 1, 1:length ] # TODO: use token_all_ids replace with input_ids_cpu if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs: self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[ "input_ids" ][idx : idx + 1, 1:length].cpu() encoder_block_num = len(request.block_tables) async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) async_set_value( self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables ) async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], False) async_set_value(self.model_inputs["batch_drop"][idx : idx + 1], False) async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], length) self.exist_prefill_flag = True async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index) async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length) async_set_value( self.model_inputs["step_idx"][idx : idx + 1], len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0, ) if self.use_attn_mask_offset: inputs = request.multimodal_inputs self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = ( paddle.to_tensor( inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], dtype="int32" ) ) # GPU don't need it anymore # NOTE: XPU backend needs decoder attention mask offset; GPU backend does not use it if current_platform.is_xpu(): self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = ( inputs["attention_mask_offset"][prefill_end_index - 1] + 1 ) if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generates first token async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0) self.exist_prefill_flag = False async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length + 1) # NOTE(liuzichang): # extra 1 : P-D split need rollback one step async_set_value(self.model_inputs["recompute_token_num"][idx : idx + 1], 0) async_set_value(self.model_inputs["mask_rollback"][idx : idx + 1], 1) # has_prefill_task = True elif request.task_type.value == RequestType.DECODE.value: # decode task encoder_block_num = len(request.block_tables) async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) if current_platform.is_cuda(): async_set_value( self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables ) else: self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.block_tables, dtype="int32" ) else: async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], True) async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0) async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], 0) async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0) async_set_value(self.model_inputs["is_block_step"][idx : idx + 1], False) continue # TODO(liuzichang): Solve splitewise-p bug to restore # self.model_inputs["seq_lens_this_time"] = self.model_inputs["seq_lens_this_time_buffer"][:num_running_requests] self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"] def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): """ Process inputs for prefill tasks and insert it to model_inputs buffer """ # TODO:Init role in initialize process if req_dicts[-1].disaggregate_info is not None: if req_dicts[-1].disaggregate_info["role"] == "prefill": self.role = "prefill" os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1" elif req_dicts[-1].disaggregate_info["role"] == "decode": self.role = "decode" else: self.role = "mixed" req_len = len(req_dicts) for i in range(req_len): request = req_dicts[i] idx = request.idx length = len(request.prompt_token_ids) self.model_inputs.input_ids_len[idx] = length - 1 if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode": length = len(request.prompt_token_ids) if length > 1: self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[ "input_ids" ][idx : idx + 1, 1:length] self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array( request.prompt_token_ids )[1:] self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1] prefill_token_num = self.max_draft_token_num + 1 self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor( request.draft_token_ids[1:2], dtype="int64" ) self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 self.model_inputs["seq_lens_decoder"][idx : idx + 1] = length self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = prefill_token_num self.model_inputs["stop_flags"][idx : idx + 1] = False self.model_inputs["batch_drop"][idx : idx + 1] = False self.model_inputs["step_idx"][idx : idx + 1] = 1 encoder_block_num = len(request.block_tables) self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num self.model_inputs["block_tables"][idx : idx + 1, :] = -1 self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.block_tables, dtype="int32" ) else: length = len(request.prompt_token_ids) if length > 1: self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[ "input_ids" ][idx : idx + 1, 1:length] self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array( request.prompt_token_ids )[1:] self.model_inputs["pre_ids"][idx : idx + 1] = -1 self.model_inputs["step_idx"][idx : idx + 1] = 0 if self.cache_config.enable_chunked_prefill: token_chunk_size = request.prefill_chunk_info[0] self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = token_chunk_size else: self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length self.model_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) self.model_inputs["stop_flags"][idx : idx + 1] = False self.model_inputs["batch_drop"][idx : idx + 1] = False encoder_block_num = len(request.get("block_tables")) self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num self.model_inputs["block_tables"][idx : idx + 1, :] = -1 self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.get("block_tables"), dtype="int32" ) self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"] def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0): """ Initialize forward meta and attention meta data """ # Initialize forward meta self.forward_meta = ForwardMeta( ids_remove_padding=self.model_inputs["ids_remove_padding"], rotary_embs=self.model_inputs["rope_emb"], attn_backend=self.attn_backends[0], decoder_batch_ids=self.model_inputs["decoder_batch_ids"], decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"], decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"], decoder_num_blocks_device=self.model_inputs["decoder_num_blocks_device"], decoder_chunk_size_device=self.model_inputs["decoder_chunk_size_device"], max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"], seq_lens_encoder=self.model_inputs["seq_lens_encoder"], seq_lens_decoder=self.model_inputs["seq_lens_decoder"], seq_lens_this_time=self.model_inputs["seq_lens_this_time"], batch_id_per_token=self.model_inputs["batch_id_per_token"], cu_seqlens_q=self.model_inputs["cu_seqlens_q"], cu_seqlens_k=self.model_inputs["cu_seqlens_k"], block_tables=self.model_inputs["block_tables"], caches=self.model_inputs["caches"], encoder_batch_ids=self.model_inputs["encoder_batch_ids"], encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"], encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"], kv_batch_ids=self.model_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"], attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None, ) # Initialize attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) # Notes(liuzichang): # 1. CUDA Graph capture sizes must be recorded in descending order (large → small). # 2. In multi-step execution, only the first step should be captured. self.forward_meta.step_use_cudagraph = ( step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run) ) def _initialize_forward_meta_xpu(self): self.forward_meta.decoder_batch_ids = (self.model_inputs["decoder_batch_ids"],) self.forward_meta.decoder_tile_ids_per_batch = (self.model_inputs["decoder_tile_ids_per_batch"],) self.forward_meta.decoder_num_blocks_cpu = (self.model_inputs["decoder_num_blocks_cpu"],) self.forward_meta.decoder_num_blocks_device = (self.model_inputs["decoder_num_blocks_device"],) self.forward_meta.decoder_chunk_size_device = (self.model_inputs["decoder_chunk_size_device"],) self.forward_meta.max_len_tensor_cpu = (self.model_inputs["max_len_tensor_cpu"],) self.forward_meta.encoder_batch_ids = (self.model_inputs["encoder_batch_ids"],) self.forward_meta.encoder_tile_ids_per_batch = (self.model_inputs["encoder_tile_ids_per_batch"],) self.forward_meta.encoder_num_blocks_x_cpu = (self.model_inputs["encoder_num_blocks_x_cpu"],) self.forward_meta.kv_batch_ids = (self.model_inputs["kv_batch_ids"],) self.forward_meta.kv_tile_ids_per_batch = (self.model_inputs["kv_tile_ids_per_batch"],) self.forward_meta.kv_num_blocks_x_cpu = (self.model_inputs["kv_num_blocks_x_cpu"],) self.forward_meta.attn_backend = self.attn_backends[0] if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query": self.forward_meta.kv_signal_sender = self.target_model_inputs["kv_signal_sender"] self.forward_meta.is_draft = True # Initialize attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) def exist_prefill(self): """ check whether prefill stage exist """ return self.exist_prefill_flag def _prepare_inputs_cuda(self, full_hidden_states): """ Prepare MTP inputs MTP state (seq_lens_decoder, step_idx) is "shadow state": - Initialized from target model state each round - Used for MTP forward, but not committed until verify - No rollback needed since it's always re-initialized """ draft_model_preprocess( self.model_inputs["draft_tokens"], self.model_inputs["input_ids"], self.model_inputs["stop_flags"], self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], self.model_inputs["step_idx"], self.model_inputs["not_need_stop_device"], self.model_inputs["pre_ids"], self.target_model_inputs["accept_tokens"], self.target_model_inputs["accept_num"], self.target_model_inputs["seq_lens_encoder"], self.target_model_inputs["seq_lens_decoder"], self.target_model_inputs["step_idx"], self.target_model_inputs["stop_flags"], self.model_inputs["max_dec_len"], self.target_model_inputs["draft_tokens"], self.num_model_steps, self.role == "prefill", # is_splitwise_prefill ) target_hidden_states, _ = eagle_get_hidden_states( full_hidden_states, self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], self.model_inputs["stop_flags"], self.target_model_inputs["accept_num"], self.target_model_inputs["seq_lens_this_time"], self.target_model_inputs["seq_lens_encoder"], self.num_model_steps, ) self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) def _prepare_inputs_xpu(self, full_hidden_states): use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER) draft_model_preprocess( self.model_inputs["draft_tokens"], self.model_inputs["input_ids"], self.model_inputs["stop_flags"], self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], self.model_inputs["step_idx"], self.model_inputs["not_need_stop"], self.model_inputs["batch_drop"], self.model_inputs["is_block_step"], self.model_inputs["pre_ids"], self.model_inputs["mask_rollback"], self.model_inputs["recompute_token_num"], self.target_model_inputs["accept_tokens"], self.target_model_inputs["accept_num"], self.target_model_inputs["seq_lens_this_time"], self.target_model_inputs["seq_lens_encoder"], self.target_model_inputs["seq_lens_decoder"], self.target_model_inputs["step_idx"], self.target_model_inputs["stop_flags"], self.target_model_inputs["is_block_step"], self.target_model_inputs["draft_tokens"], self.num_model_steps, True, self.role == "prefill", use_v1_cache_scheduler, ) target_hidden_states = eagle_get_hidden_states( full_hidden_states, self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], self.model_inputs["stop_flags"], self.target_model_inputs["accept_num"], self.target_model_inputs["seq_lens_this_time"], self.target_model_inputs["seq_lens_encoder"], self.num_model_steps, ) self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) def _post_process(self, sampled_token_ids): """ PostProcess for generation """ draft_model_update( sampled_token_ids, self.model_inputs["draft_tokens"], self.model_inputs["pre_ids"], self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], self.model_inputs["step_idx"], # Note(ZKK): # I strongly advise xpu student delete the fuck `output_cum_offsets` name in XPU backend # like my pr https://github.com/PaddlePaddle/FastDeploy/pull/6358 self.model_inputs["cu_seqlens_q_output"], self.model_inputs["stop_flags"], ( self.model_inputs["not_need_stop_device"] if current_platform.is_cuda() else self.model_inputs["not_need_stop"] ), self.model_inputs["max_dec_len"], self.model_inputs["eos_token_id"], self.model_inputs["base_model_draft_tokens"], self.max_model_len, self.model_inputs["substep"], ) if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0: if current_platform.is_xpu(): # Note(wangyanpeng): mtp_save_first_token for GPU platforms has been moved to model_runner. # Only XPU platform is retained here. skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) recover_model_output_map = recover_batch_index_for_output( self.model_inputs, self.model_inputs.index_to_batch_id, self.model_inputs.enable_pd_reorder, ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], ) mtp_save_first_token( recover_model_output_map["base_model_draft_tokens"], self.model_inputs["not_need_stop"], recover_model_output_map["seq_lens_decoder"], recover_model_output_map["prompt_lens"], recover_model_output_map["step_idx"], self.local_rank, self.parallel_config.use_ep, skip_save, ) # Ensure only save first token once. paddle.assign( paddle.where( self.model_inputs["stop_flags"], paddle.zeros_like(self.model_inputs["step_idx"]), self.model_inputs["step_idx"], ), self.model_inputs["step_idx"], ) def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0): """ Main process for MTP inference. Args: step_use_cudagraph: bool Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP. """ is_blocking = ( (not self.fd_config.scheduler_config.enable_overlap_schedule) or is_dummy_run or self.exist_prefill() or real_bsz == 0 ) for substep in range(self.num_model_steps): if is_blocking: token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item() else: if substep == 0: token_num_cpu = real_bsz * (self.max_draft_token_num + 1) else: token_num_cpu = real_bsz if token_num_cpu > 0: self.model_inputs["substep"] = substep # Remove padding ( ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k, cu_seqlens_q_output, batch_id_per_token_output, real_output_token_num, ) = pre_process( token_num_cpu, self.model_inputs["input_ids"], self.model_inputs["seq_lens_this_time"], True, self.model_inputs["draft_tokens"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], ) if self.use_attn_mask_offset: attn_mask_offsets = update_attn_mask_offsets( ids_remove_padding, getattr( self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"] ), self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], cu_seqlens_q, self.model_inputs["attn_mask_offsets_full"], self.model_inputs["is_block_step"], self.model_inputs["decode_states"], ) self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) # Initialize forward meta data self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) self.model_inputs["batch_id_per_token"][:] = -1 self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) # For speculative decoding self.model_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False) self.model_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False) # Initialize forward meta data self._initialize_forward_meta( step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep ) self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False) self.forward_meta.real_bsz = real_bsz # Padding inputs for cuda graph self.padding_cudagraph_inputs() # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.model_inputs["temperature"], top_p=self.model_inputs["top_p"], top_k=self.model_inputs["top_k"], seed=self.model_inputs["infer_seed"], step_idx=self.model_inputs["step_idx"], token_ids_all=self.model_inputs["token_ids_all"], pre_token_ids=self.model_inputs["pre_ids"], prompt_lens=self.model_inputs["prompt_lens"], fake_prompt_lens=self.model_inputs["fake_prompt_lens"], frequency_penalties=self.model_inputs["frequency_score"], presence_penalties=self.model_inputs["presence_score"], repetition_penalties=self.model_inputs["penalty_score"], min_dec_lens=self.model_inputs["min_dec_len"], bad_words_token_ids=self.model_inputs["bad_tokens"], bad_words_token_len=self.model_inputs["bad_tokens_len"], eos_token_ids=self.model_inputs["eos_token_id"], max_num_logprobs=20 if self.enable_logprob else None, temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], share_inputs=self.model_inputs, ) real_num = self.model_inputs["ids_remove_padding"].shape[0] target_hidden_states = self.model_inputs["target_hidden_states"][:real_num] model_output = self.model( ids_remove_padding=self.model_inputs["ids_remove_padding"], previous_hidden_states=target_hidden_states, forward_meta=self.forward_meta, ) if self.forward_meta.step_use_cudagraph: model_output = model_output[: self.real_token_num] hidden_states, _ = eagle_gather_hidden_states( model_output, self.model_inputs["cu_seqlens_q"], self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_decoder"], self.model_inputs["seq_lens_encoder"], self.model_inputs["batch_id_per_token_output"], self.model_inputs["cu_seqlens_q_output"], real_output_token_num, ) # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) if self.enable_logprob and self.enable_draft_logprob and substep == 0: first_token_logits = self.model.compute_logits( self.model_inputs["first_token_hidden_states"], forward_meta=self.forward_meta ) speculate_get_logits( self.model_inputs["draft_logits"], self.model_inputs["next_token_num"], self.model_inputs["batch_token_num"], self.model_inputs["cu_next_token_offset"], self.model_inputs["cu_batch_token_offset"], logits, first_token_logits, self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_encoder"], ) sampled_token_ids, sampler_output = self.sampler( logits, self.sampling_metadata, self.max_model_len, self.model_inputs, ) if ( not is_dummy_run and self.parallel_config.tensor_parallel_rank == 0 and substep == 0 and sampler_output.logprobs_tensors is not None ): real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) recover_model_output_map = recover_batch_index_for_output( self.model_inputs, self.model_inputs.index_to_batch_id, self.model_inputs.enable_pd_reorder, ["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"], ) speculate_save_output_topk( sampler_output.sampled_token_ids, sampler_output.logprobs_tensors.logprob_token_ids, sampler_output.logprobs_tensors.logprobs, sampler_output.logprobs_tensors.selected_token_ranks, recover_model_output_map["batch_token_num"][:real_bsz], recover_model_output_map["cu_batch_token_offset"][:real_bsz], self.model_inputs["not_need_stop"], recover_model_output_map["seq_lens_decoder"], recover_model_output_map["prompt_lens"], 4, # mtype self.local_rank, self.parallel_config.use_ep, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( sampled_token_ids, self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) self._post_process(sampled_token_ids) self.model_inputs["target_hidden_states"].copy_(hidden_states, False) else: if hasattr(self.model, "empty_input_forward") and not is_dummy_run: self.model.empty_input_forward(forward_meta=self.forward_meta) self.exist_prefill_flag = False def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0): """ Main process for MTP inference. Args: step_use_cudagraph: bool Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP. """ for substep in range(self.num_model_steps): if self.model_inputs["not_need_stop"]: self.model_inputs["substep"] = substep # Remove padding self.forward_meta = xpu_pre_process( self.model_inputs["input_ids"], self.model_inputs["seq_lens_this_time"], self.model_inputs, True, self.cache_config.block_size, self.model_inputs["draft_tokens"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], num_speculative_tokens=self.speculative_config.num_speculative_tokens, ) if self.enable_mm: attn_mask_offsets = update_attn_mask_offsets( self.model_inputs["ids_remove_padding"], getattr( self.model_inputs, "seq_lens_this_time", self.model_inputs["seq_lens_this_time_buffer"] ), self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], self.model_inputs["cu_seqlens_q"], self.model_inputs["attn_mask_offsets_full"], self.model_inputs["attn_mask_offsets_decoder"], self.model_inputs["is_block_step"], self.model_inputs["decode_states"], self.model_inputs["mask_rollback"], ) self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) self._initialize_forward_meta_xpu() # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.model_inputs["temperature"], top_p=self.model_inputs["top_p"], top_k=self.model_inputs["top_k"], seed=self.model_inputs["infer_seed"], step_idx=self.model_inputs["step_idx"], pre_token_ids=self.model_inputs["pre_ids"], frequency_penalties=self.model_inputs["frequency_score"], presence_penalties=self.model_inputs["presence_score"], repetition_penalties=self.model_inputs["penalty_score"], min_dec_lens=self.model_inputs["min_dec_len"], bad_words_token_ids=self.model_inputs["bad_tokens"], eos_token_ids=self.model_inputs["eos_token_id"], max_num_logprobs=20 if self.enable_logprob else None, temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], share_inputs=self.model_inputs, ) if self.num_model_steps > 1: self.model_inputs.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"]) model_output = self.model( ids_remove_padding=self.model_inputs["ids_remove_padding"], previous_hidden_states=self.model_inputs["target_hidden_states"], forward_meta=self.forward_meta, ) hidden_states = xpu_process_output(model_output, self.forward_meta, self.model_inputs) # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) sampled_token_ids, sampler_output = self.sampler( logits, self.sampling_metadata, self.max_model_len, self.model_inputs, ) if substep == 0 and sampler_output.logprobs_tensors is not None: real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] recover_batch_index_for_sampler_output(sampler_output, self.model_inputs.index_to_batch_id) recover_model_output_map = recover_batch_index_for_output( self.model_inputs, self.model_inputs.index_to_batch_id, self.model_inputs.enable_pd_reorder, ["batch_token_num", "cu_batch_token_offset"], ) speculate_save_output_topk( sampler_output.sampled_token_ids, sampler_output.logprobs_tensors.logprob_token_ids, sampler_output.logprobs_tensors.logprobs, sampler_output.logprobs_tensors.selected_token_ranks, recover_model_output_map["batch_token_num"][:real_bsz], recover_model_output_map["cu_batch_token_offset"][:real_bsz], self.model_inputs["not_need_stop"], 4, # mtype self.local_rank, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( sampled_token_ids, self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) self._post_process(sampled_token_ids) if substep != self.num_model_steps - 1: self._get_self_hidden_states_xpu(hidden_states) else: if hasattr(self.model, "empty_input_forward") and not is_dummy_run: self.model.empty_input_forward(self.forward_meta) def _get_self_hidden_states_xpu(self, hidden_states): target_hidden_states = eagle_get_self_hidden_states( hidden_states, self.model_inputs.last_seq_lens_this_time, self.model_inputs["seq_lens_this_time"], self.model_inputs["step_idx"], ) self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) def update_task_chunk_prefill(self, task): """ Update single task's chunk_prefill info """ idx = self.model_inputs.get_index_by_batch_id(task.idx) start_idx = sum(task.prefill_chunk_info[: task.chunk_idx]) if task.chunk_idx == len(task.prefill_chunk_info): self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 self.model_inputs["step_idx"][idx : idx + 1] = 1 self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0) else: token_chunk_size = task.prefill_chunk_info[task.chunk_idx] if task.chunk_idx < len(task.prefill_chunk_info) - 1: self.model_inputs["input_ids"][idx, :token_chunk_size] = np.array( task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size + 1] ) # Last prefill else: self.model_inputs["input_ids"][idx, : token_chunk_size - 1] = np.array( task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size] ) self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size self.model_inputs["step_idx"][idx : idx + 1] = 0 self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0) def _update_status(self): """ Update main-model's forward info in next step. Allocate/Free block of MPT. """ draft_model_postprocess( self.target_model_inputs["draft_tokens"], self.target_model_inputs["seq_lens_this_time"], self.target_model_inputs["seq_lens_encoder"], self.target_model_inputs["stop_flags"], ) if not envs.ENABLE_V1_KVCACHE_SCHEDULER: mtp_step_paddle( self.target_model_inputs["stop_flags"], self.model_inputs["stop_flags"], self.model_inputs["batch_drop"], self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], self.model_inputs["block_tables"], self.model_inputs["encoder_block_lens"], self.model_inputs["used_list_len"], self.model_inputs["free_list"], self.model_inputs["free_list_len"], self.cache_config.block_size, self.max_draft_token_num, ) def _extend_draft_token_with_ngram_match(self): # TODO: replace with gpu tensor hybrid_mtp_ngram( self.model_inputs["input_ids_cpu"].cuda(), self.model_inputs["input_ids_len"].cuda(), self.model_inputs["pre_ids"], self.model_inputs["step_idx"], self.target_model_inputs["actual_draft_token_num"], self.target_model_inputs["draft_tokens"], self.target_model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_decoder"], self.model_inputs["max_dec_len"], self.max_ngram_size, self.min_ngram_size, self.max_draft_token_num, ) def _run_impl( self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0, ): """Execute Draft Model""" self._prepare_inputs(full_hidden_states) self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, real_bsz=real_bsz) self._update_status() if self.hybrid_mode: self._extend_draft_token_with_ngram_match() def is_chunk_prefill_enabled(self): """""" return True def padding_cudagraph_inputs(self) -> None: """ Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. """ # In init_attention_metadata, the decode buffer has already been cleared # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. if self.forward_meta.step_use_cudagraph: self.forward_meta.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"] self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] return def _empty_cache(self): if current_platform.is_cuda(): paddle.device.cuda.empty_cache() elif current_platform.is_xpu(): paddle.device.xpu.empty_cache() else: paddle.device.empty_cache() def _get_cache_type(self): cache_type = None if current_platform.is_cuda(): cache_type = "uint8" elif current_platform.is_xpu(): cache_type = "int8" else: raise NotImplementedError return cache_type def reorder_inputs(self, target_model_input_batch): """ Reorder inputs to split prefill and decode. """ reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs, target_model_input_batch) def _share_external_data(self, cache, cache_name, cache_shape): if current_platform.is_xpu(): return share_external_data(cache, cache_name, cache_shape, False) else: return share_external_data(cache, cache_name, cache_shape)