""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import copy import os import queue import time from concurrent.futures import Future from threading import Thread from typing import Dict, List, Optional, cast import numpy as np import paddle from paddle import nn from paddleformers.utils.log import logger from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig from fastdeploy.engine.pooling_params import PoolingParams from fastdeploy.engine.request import ImagePosition, Request, RequestType from fastdeploy.model_executor.graph_optimization.utils import ( profile_run_guard, sot_warmup_guard, ) from fastdeploy.model_executor.guided_decoding import ( LogitsProcessorBase, get_guided_backend, ) from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.append_attn_backend import ( allocate_launch_related_buffer, ) from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( RoutingReplayManager, ) from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.platforms import 
current_platform from fastdeploy.spec_decode import SpecMethod from fastdeploy.utils import print_gpu_memory_use from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode from fastdeploy.worker.tbo import GLOBAL_ATTN_BUFFERS if current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import ( recover_decode_task, set_data_ipc, set_value_by_flags_and_idx, ) share_external_data = None elif current_platform.is_dcu(): from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx recover_decode_task = None share_external_data = None else: from fastdeploy.model_executor.ops.gpu import ( recover_decode_task, set_value_by_flags_and_idx, share_external_data, speculate_schedule_cache, set_data_ipc, unset_data_ipc, ) import zmq from fastdeploy import envs from fastdeploy.engine.tasks import PoolingTask from fastdeploy.input.image_processors.adaptive_processor import AdaptiveImageProcessor from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient from fastdeploy.logger.deterministic_logger import DeterministicLogger from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling from fastdeploy.model_executor.pre_and_post_process import ( async_set_value, post_process, pre_process, rebuild_padding, save_output_normal, save_output_specualate, ) from fastdeploy.output.pooler import PoolerOutput from fastdeploy.worker.model_runner_base import ( DistributedOut, DistributedStatus, ModelRunnerBase, ) from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput class GPUModelRunner(ModelRunnerBase): def __init__( self, fd_config: FDConfig, device: str, # logic device device_id: int, # physical device id rank: int, local_rank: int, ): 
super().__init__(fd_config=fd_config, device=device) self.MAX_INFER_SEED = 9223372036854775806 self.enable_mm = self.fd_config.enable_mm_runtime self.rank = rank self.local_rank = local_rank self.device_id = device_id self.spec_method = self.fd_config.speculative_config.method self.speculative_decoding = self.spec_method is not None self.enable_logprob = fd_config.model_config.enable_logprob self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size self.max_logprobs = None if self.enable_logprob: self.max_logprobs = ( self.ori_vocab_size if fd_config.model_config.max_logprobs == -1 else fd_config.model_config.max_logprobs ) self.temp_scaled_logprobs = True self.top_p_normalized_logprobs = True self.prompt_logprobs_reqs: dict[str, Request] = {} self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {} self.forward_batch_reqs_list: list[Request] = [None for _ in range(self.scheduler_config.max_num_seqs)] self.cache_kvs_map: dict = {} self.exist_prefill_flag = False self.is_kvcache_sleeping = False self.is_weight_sleeping = False if self.speculative_decoding: self._real_output_token_num_host = paddle.empty([1], dtype="int32").pin_memory() self.output_token_num_event = paddle.device.cuda.Event() # VL model config: if self.enable_mm: if "Ernie4_5" in self.model_config.architectures[0]: self._init_image_preprocess() self.amp_black = [ "reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div", "sin", "cos", "sort", "multinomial", ] self.amp_white = [ "lookup_table", "lookup_table_v2", "flash_attn", "matmul", "matmul_v2", "fused_gemm_epilogue", ] if self.cache_config.max_encoder_cache > 0: self.encoder_cache: dict[str, paddle.Tensor] = {} else: self.encoder_cache = None # Sampler if not self.speculative_decoding: self.sampler = Sampler(fd_config) else: self.sampler = SpeculativeSampler(fd_config) 
self.guided_backend = None if self.fd_config.structured_outputs_config.guided_decoding_backend != "off": self.guided_backend = get_guided_backend(fd_config=self.fd_config) self.sampler.set_reasoning_parser(self.guided_backend.get_reasoning_parser()) # Lazy initialize kv cache after model loading # self.kv_caches: list[paddle.Tensor] = [] # CUDA Graph self.use_cudagraph = self.graph_opt_config.use_cudagraph self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) self.cudagraph_capture_sizes_prefill = list(reversed(self.graph_opt_config.cudagraph_capture_sizes_prefill)) self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill # Initialize input batch self.share_inputs = InputBatch(self.fd_config) self.share_inputs.init_share_inputs() self.increment_value = ( 4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4 ) self.infer_seed_increment = paddle.full( shape=[self.scheduler_config.max_num_seqs, 1], fill_value=self.increment_value, dtype="int64" ) self.restore_chunked_prefill_request = dict() # Initialize deterministic logger (only when deterministic debugging is enabled) self.deterministic_logger = ( DeterministicLogger(self.share_inputs) if envs.FD_DETERMINISTIC_MODE and envs.FD_DETERMINISTIC_LOG_MODE else None ) # Initialize attention Backend # NOTE(gonshaotian): Currently, all attention layers share one attention backend instance. # In the future, we will expand it as a list. 
self.attn_backends: list[AttentionBackend] = [] # self.attn_metadatas: list[AttentionMetadata] = [] self._initialize_attn_backend() # Forward meta store the global meta information of the forward self.forward_meta: ForwardMeta = None # Postprocess Env params os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.local_engine_worker_queue_port) logger.info(f"queue id is {str(self.parallel_config.local_engine_worker_queue_port)}") # Rollout routing replay config self.routing_replay_manager = None self.zmq_client = None self.async_output_queue = None if envs.FD_USE_GET_SAVE_OUTPUT_V1: port = self.fd_config.parallel_config.local_engine_worker_queue_port logger.info(f"zmq client get_save_output_rank{local_rank}_{port}") self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}_{port}", mode=zmq.PUSH) self.zmq_client.connect() self.zmq_client.socket.SNDTIMEO = 3000 self.async_output_queue: queue.Queue = queue.Queue() self.async_output_copy_thread = Thread( target=self._async_output_busy_loop, daemon=True, name="WorkerAsyncOutputCopy", ) self.async_output_copy_thread.start() self.enable_entropy = self.model_config.enable_entropy # init signal cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32) self.cache_ready_signal = IPCSignal( name="cache_ready_signal", array=cache_ready_signal_data, dtype=np.int32, suffix=self.parallel_config.local_engine_worker_queue_port, create=False, ) # for overlap self._cached_model_output_data = None self._cached_sampler_output = None self._cached_post_process_event = None # Cached token count for next batch prediction in overlap scheduling. # Used to avoid synchronization overhead when preparing inputs for the next batch. 
self._cached_launch_token_num = -1 self._cached_real_bsz = -1 self.enable_overlap_schedule = fd_config.scheduler_config.enable_overlap_schedule if self.enable_overlap_schedule: logger.info("Using overlap schedule") self.current_launch_token_num = 0 def _async_output_busy_loop(self): """Entrypoint for the thread which handles outputs asynchronously.""" while True: try: output = self.async_output_queue.get() self.zmq_client.send_pyobj(output) except Exception as e: logger.exception("Exception in async output loop: %s", e) def exist_prefill(self): """ check whether prefill stage exist """ return self.exist_prefill_flag @property def is_sleeping(self): return self.is_weight_sleeping or self.is_kvcache_sleeping def exist_decode(self): """ check whether decode stage exist """ seq_lens_decoder = self.share_inputs["seq_lens_decoder"] stop_flags = self.share_inputs["stop_flags"].squeeze(1) return ((seq_lens_decoder > 0) & ~stop_flags).any().cpu().numpy().item() def _resolve_current_launch_token_num( self, cached_token_num: int, cached_real_bsz: int, token_num_event, is_dummy_or_profile_run: bool ) -> int: """ Resolve token count for current batch. In overlap mode, uses cached value from previous batch prediction to avoid GPU-CPU sync. Falls back to fresh computation in certain conditions: - dummy/profile runs need accurate counts - non-overlap mode doesn't support caching - prefill stage changes batch composition - invalid cached value """ if ( is_dummy_or_profile_run or (not self.enable_overlap_schedule) or self.exist_prefill() or cached_token_num <= 0 or cached_real_bsz <= 0 ): token_num_event.synchronize() seq_lens_this_time_cpu = self.share_inputs["seq_lens_this_time_cpu"].numpy() return seq_lens_this_time_cpu.sum().item(), (seq_lens_this_time_cpu > 0).sum().item() return cached_token_num, cached_real_bsz def _predict_next_launch_token_num(self) -> int: """ Predict token count for next batch. 
In overlap scheduling, while current batch executes model forward, the scheduler may have prepared decode requests for next batch. This prediction allows next batch to skip synchronization. Returns -1 if prediction is not applicable (non-overlap or prefill exists). """ if self.exist_prefill(): return -1, -1 seq_lens_this_time_cpu = self.share_inputs["seq_lens_this_time_cpu"].numpy() is_block_step_cpu = self.share_inputs["is_block_step_cpu"].numpy() next_real_bsz = (seq_lens_this_time_cpu > 0).sum().item() + (is_block_step_cpu > 0).sum().item() token_num_one_step = (self.speculative_config.num_speculative_tokens + 1) if self.speculative_decoding else 1 next_launch_token_num = next_real_bsz * token_num_one_step return next_launch_token_num, next_real_bsz def only_prefill(self): """ check whether prefill only """ if_only_prefill = True decode_exists = None if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": only_prefill_batch_list = [] decode_exists = self.exist_decode() paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists) if_only_prefill = all(only_prefill_batch_list) if_only_prefill = if_only_prefill and not (decode_exists if decode_exists is not None else self.exist_decode()) return if_only_prefill def collect_distributed_status(self): """ Collect distributed status """ dist_status_list = [] dist_status_obj = DistributedStatus() dist_out = DistributedOut() prefill_exists = None if_only_decode = True # mix ep in single node if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": prefill_exists = self.exist_prefill() dist_status_obj.only_decode = not prefill_exists # whether chunked moe if self.fd_config.parallel_config.enable_chunked_moe: chunk_size = self.fd_config.parallel_config.chunked_moe_size token_num = self.share_inputs["ids_remove_padding"].shape[0] if token_num > chunk_size: self.forward_meta.moe_num_chunk = (token_num + 
chunk_size - 1) // chunk_size else: self.forward_meta.moe_num_chunk = 1 dist_status_obj.moe_num_chunk = self.forward_meta.moe_num_chunk # only ep need to collect and sync distributed status if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": # call once to gather all status paddle.distributed.all_gather_object(dist_status_list, dist_status_obj) # Update Batch type for cuda graph for if_only_decode if_only_decode = all(dist_status.only_decode for dist_status in dist_status_list) if_only_decode = if_only_decode and not ( prefill_exists if prefill_exists is not None else self.exist_prefill() ) max_moe_num_chunk = None if self.fd_config.parallel_config.enable_chunked_moe: max_moe_num_chunk = max(dist_status.moe_num_chunk for dist_status in dist_status_list) dist_out = DistributedOut( if_only_decode=if_only_decode, max_moe_num_chunk=max_moe_num_chunk, ) return dist_out def only_decode(self): """ check whether decode only """ # Update Batch type for cuda graph for if_only_decode if_only_decode = True prefill_exists = None # mix ep in single node if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": only_decode_batch_list = [] prefill_exists = self.exist_prefill() paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) if_only_decode = all(only_decode_batch_list) if_only_decode = if_only_decode and not ( prefill_exists if prefill_exists is not None else self.exist_prefill() ) return if_only_decode def _init_speculative_proposer(self): """ Init speculative proposer """ if self.spec_method is None: self.proposer = None return # MTP-specific: swap seq_lens_this_time to the buffer tensor if self.spec_method == SpecMethod.MTP: self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"] self.proposer = self.spec_method.create_proposer( self.fd_config, main_model=self.get_model(), local_rank=self.local_rank, 
device_id=self.device_id, share_inputs=self.share_inputs, ) def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]: """ init logits processor for guided decoding """ assert self.guided_backend is not None, ( "guided_backend is None, use " "--guided-decoding-backend to specify the backend at server startup." ) if request.guided_json is not None: schemata_key = ("json", request.guided_json) elif request.guided_regex is not None: schemata_key = ("regex", request.guided_regex) elif request.guided_grammar is not None: schemata_key = ("grammar", request.guided_grammar) elif request.structural_tag is not None: schemata_key = ("structural_tag", request.structural_tag) return ( self.guided_backend.get_logits_processor( schemata_key=schemata_key, enable_thinking=False, # TODO cfg ), schemata_key, ) def _process_mm_features(self, request_list: List[Request]): """ Process and cache vision features from model - add image_features, extract and cache vision features from model - add rope_emb, rotate position embeddings """ if not self.enable_mm: return self.share_inputs["image_features_list"] = [-1] * self.scheduler_config.max_num_seqs img_index = 0 req_idx_img_index_map = {} multi_vision_inputs = { "images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0], "encoder_cache_info": [], "feature_position_list": [], "grid_thw_lst_batches": [], "feature_position_list_batches": [], } for request in request_list: if request.task_type.value != RequestType.PREFILL.value: continue if self.encoder_cache is not None: evict_mm_hashes = request.get("evict_mm_hashes", None) if evict_mm_hashes: for mm_hash in evict_mm_hashes: self.encoder_cache.pop(mm_hash, None) idx = self.share_inputs.get_index_by_batch_id(request.idx) req_idx_img_index_map[idx] = -1 if request.with_image: req_idx_img_index_map[idx] = img_index img_index = img_index + 1 inputs = request.multimodal_inputs if self.encoder_cache is not None: if envs.FD_ENABLE_MAX_PREFILL: if 
"vit_seqlen" in inputs: vit_seqlen_list = inputs["vit_seqlen"][request.num_image_start : request.num_image_end] if "vit_position_ids" in inputs: vit_position_ids_list = inputs["vit_position_ids"][ request.num_image_start : request.num_image_end ] grid_thw_list = inputs["grid_thw"][request.num_image_start : request.num_image_end] mm_hashes_list = inputs["mm_hashes"][request.num_image_start : request.num_image_end] feature_positions = self._get_feature_positions( mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end], prefill_start_index=request.prefill_start_index, prefill_end_index=request.prefill_end_index, ) image_start_idx = request.num_image_start logger.debug( f"request {request.request_id} start process encoder info, image_start_idx: {image_start_idx} " f"grid_thw_list: {grid_thw_list}, feature_positions: {feature_positions}, mm_hashes_list: {mm_hashes_list}" ) encoder_cache_info_per_req = [] grid_thw_lst_per_req = [] for i, mm_hash in enumerate(mm_hashes_list): image_offset = np.prod(grid_thw_list[i]) logger.debug( f"run idx {i} with mm_hash {mm_hash} image_offset: {image_offset} grid_thw: {grid_thw_list[i]}" ) if mm_hash in self.encoder_cache: encoder_cache_info_per_req.append((mm_hash, feature_positions[i], True)) continue encoder_cache_info_per_req.append((mm_hash, feature_positions[i], False)) if envs.FD_ENABLE_MAX_PREFILL: multi_vision_inputs["images_lst"].append( inputs["images"][image_start_idx : image_start_idx + image_offset].to(self.device) ) multi_vision_inputs["grid_thw_lst"].append(paddle.to_tensor(grid_thw_list[i])) grid_thw_lst_per_req.append(paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64)) multi_vision_inputs["cu_seqlens"].append(vit_seqlen_list[i]) multi_vision_inputs["vit_position_ids_lst"].append(vit_position_ids_list[i]) else: multi_vision_inputs["images_lst"].append( paddle.to_tensor( inputs["images"][image_start_idx : image_start_idx + image_offset], dtype="uint8" if "ernie" in 
self.model_config.model_type else "bfloat16", ) ) multi_vision_inputs["grid_thw_lst"].append( paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64) ) grid_thw_lst_per_req.append(paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64)) image_start_idx += image_offset multi_vision_inputs["grid_thw_lst_batches"].append(grid_thw_lst_per_req) multi_vision_inputs["encoder_cache_info"].append(encoder_cache_info_per_req) else: if envs.FD_ENABLE_MAX_PREFILL: multi_vision_inputs["images_lst"].append( inputs["images"][request.image_start : request.image_end].to(self.device) ) multi_vision_inputs["grid_thw_lst"].extend( paddle.to_tensor(inputs["grid_thw"][request.num_image_start : request.num_image_end]) ) multi_vision_inputs["grid_thw_lst_batches"].append( paddle.to_tensor(inputs["grid_thw"][request.num_image_start : request.num_image_end]) ) multi_vision_inputs["cu_seqlens"].extend( inputs["vit_seqlen"][request.num_image_start : request.num_image_end] ) multi_vision_inputs["vit_position_ids_lst"].extend( inputs["vit_position_ids"][request.num_image_start : request.num_image_end] ) else: multi_vision_inputs["images_lst"].append( paddle.to_tensor( inputs["images"][request.image_start : request.image_end], dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16", ) ) multi_vision_inputs["grid_thw_lst"].extend( paddle.to_tensor( inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype=paddle.int64, ) ) multi_vision_inputs["grid_thw_lst_batches"].append( paddle.to_tensor( inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype=paddle.int64, ) ) multi_vision_inputs["feature_position_list"].extend( self._get_feature_positions( mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end], prefill_start_index=request.prefill_start_index, prefill_end_index=request.prefill_end_index, ) ) multi_vision_inputs["feature_position_list_batches"].append( self._get_feature_positions( 
mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end], prefill_start_index=request.prefill_start_index, prefill_end_index=request.prefill_end_index, ) ) if self.encoder_cache is not None: if len(multi_vision_inputs["images_lst"]) > 0 or len(multi_vision_inputs["encoder_cache_info"]) > 0: image_features_output = None if len(multi_vision_inputs["images_lst"]) > 0: image_features_output = self.extract_vision_features(multi_vision_inputs) logger.debug(f"encoder_cache_info: {multi_vision_inputs['encoder_cache_info']}") feature_idx = 0 image_features_list = [] for index, encoder_cache_info in enumerate(multi_vision_inputs["encoder_cache_info"]): merge_image_features, thw_idx = [], 0 for mm_hash, feature_position, use_cache in encoder_cache_info: if use_cache: assert mm_hash in self.encoder_cache, f"{mm_hash} not in encoder cache" mm_feature = self.encoder_cache[mm_hash].cuda() else: assert ( image_features_output is not None ), f"image_features_output is None, images_lst length: {len(multi_vision_inputs['images_lst'])}" grid_thw = multi_vision_inputs["grid_thw_lst_batches"][index][thw_idx] mm_token_length = inputs["mm_num_token_func"](grid_thw=grid_thw) mm_feature = image_features_output[feature_idx : feature_idx + mm_token_length] # add feature to encoder cache self.encoder_cache[mm_hash] = mm_feature.detach().cpu() feature_idx += mm_token_length thw_idx += 1 feature_start = feature_position.offset feature_end = feature_position.offset + feature_position.length merge_image_features.append(mm_feature[feature_start:feature_end]) image_features_list.append(paddle.concat(merge_image_features, axis=0)) for idx, index in req_idx_img_index_map.items(): if index != -1: self.share_inputs["image_features_list"][idx] = image_features_list[index] elif len(multi_vision_inputs["images_lst"]) > 0: image_features_output = self.extract_vision_features(multi_vision_inputs) image_features_list = [] feature_idx = 0 for index, feature_position_item in 
enumerate(multi_vision_inputs["feature_position_list_batches"]): grid_thw_lst = multi_vision_inputs["grid_thw_lst_batches"][index] assert len(feature_position_item) == len(grid_thw_lst), f"{feature_position_item} != {grid_thw_lst}" merge_image_features, thw_idx = [], 0 for feature_position in feature_position_item: grid_thw = grid_thw_lst[thw_idx] mm_token_length = inputs["mm_num_token_func"](grid_thw=grid_thw) mm_feature = image_features_output[feature_idx : feature_idx + mm_token_length] feature_start = feature_position.offset feature_end = feature_position.offset + feature_position.length merge_image_features.append(mm_feature[feature_start:feature_end]) feature_idx += mm_token_length thw_idx += 1 image_features_list.append(paddle.concat(merge_image_features, axis=0)) for idx, index in req_idx_img_index_map.items(): if index != -1: self.share_inputs["image_features_list"][idx] = image_features_list[index] def _get_feature_positions( self, mm_positions: List[ImagePosition], prefill_start_index: int, prefill_end_index: int ): """ Filter and adjust ImagePosition objects that fall within the specified prefill range. 
Args: mm_positions: List of ImagePosition objects to filter prefill_start_index: Start index of the prefill range prefill_end_index: End index of the prefill range Returns: List of ImagePosition objects that are within or intersect with the prefill range """ feature_positions = [] for position in mm_positions: position_start = position.offset position_end = position.offset + position.length if position_end <= prefill_start_index or position_start >= prefill_end_index: continue elif position_start >= prefill_start_index and position_end <= prefill_end_index: new_position = copy.deepcopy(position) new_position.offset = 0 feature_positions.append(new_position) else: new_position = copy.deepcopy(position) # Adjust offset if it starts before prefill_start_index if position_start < prefill_start_index: new_position.offset = prefill_start_index - position_start new_position.length = min(position_end, prefill_end_index) - prefill_start_index # Adjust length if it extends beyond prefill_end_index elif position_end > prefill_end_index: new_position.offset = 0 new_position.length = prefill_end_index - position_start feature_positions.append(new_position) logger.debug( f"get feature_positions, original positions: {mm_positions}, filtered positions: {feature_positions}" ) return feature_positions def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None): """ Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 req_dict: A list of Request dict num_running_requests: batch_size """ req_len = len(req_dicts) batch_pooling_params = [] self.share_inputs["num_running_requests"] = num_running_requests self.share_inputs["running_requests_ids"] = range(num_running_requests) rope_3d_position_ids = { "position_ids_idx": [], "position_ids_lst": [], "position_ids_offset": [0], "max_tokens_lst": [], } for i in range(req_len): request = req_dicts[i] idx = self.share_inputs.get_index_by_batch_id(request.idx) self.share_inputs["req_ids"][idx] = 
str(request.request_id) if hasattr(request, "pooling_params") and request.pooling_params is not None: batch_pooling_params.append(request.pooling_params) logits_info = None prefill_tokens = [] if request.task_type.value == RequestType.PREFILL.value: # prefill task self.share_inputs["preempted_idx"][idx : idx + 1, :] = 0 self.share_inputs["req_ids"][idx] = str(request.request_id) # rope 3d if self.enable_mm: position_ids = request.multimodal_inputs["position_ids"] rope_3d_position_ids["position_ids_idx"].append(idx) rope_3d_position_ids["position_ids_lst"].append(position_ids) rope_3d_position_ids["position_ids_offset"].append( len(position_ids) + rope_3d_position_ids["position_ids_offset"][-1] ) if self.is_pooling_model: rope_3d_position_ids["max_tokens_lst"].append(0) else: rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048)) # guided decoding if ( request.guided_json is not None or request.guided_regex is not None or request.structural_tag is not None or request.guided_grammar is not None ): logits_info, schemata_key = self._init_logits_processor(request) request.schemata_key = schemata_key if ( self.scheduler_config.splitwise_role == "decode" and hasattr(request, "prefill_end_index") and hasattr(request, "prompt_token_ids") and request.prefill_end_index > len(request.prompt_token_ids) and hasattr(request, "output_token_ids") ): prefill_tokens.extend(request.output_token_ids) prefill_start_index = request.prefill_start_index prefill_end_index = request.prefill_end_index length = prefill_end_index - prefill_start_index if not self.is_pooling_model: if request.get("enable_thinking") is not None: enable_thinking = bool(request.get("enable_thinking")) logger.debug(f"request {request.request_id} with {enable_thinking=} at idx {idx}") self.share_inputs["enable_thinking"][idx : idx + 1, :] = enable_thinking if enable_thinking: self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 if request.get("reasoning_max_tokens") is not None: # 
Enable thinking self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get( "reasoning_max_tokens" ) else: self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 if request.get("response_max_tokens") is not None: # Enable thinking self.share_inputs["max_reply_lens"][idx : idx + 1, :] = request.get( "response_max_tokens" ) else: self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1 else: # Disable thinking self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 self.share_inputs["max_reply_lens"][idx : idx + 1, :] = -1 self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 if isinstance(request.prompt_token_ids, np.ndarray): prompt_token_ids = request.prompt_token_ids.tolist() else: prompt_token_ids = request.prompt_token_ids input_ids = prompt_token_ids + request.output_token_ids prompt_len = len(prompt_token_ids) # prompt_tokens async_set_value(self.share_inputs["token_ids_all"][idx : idx + 1, :prompt_len], prompt_token_ids) # generated_token_ids fill -1 self.share_inputs["token_ids_all"][idx : idx + 1, prompt_len:] = -1 # Log complete input_ids for input determinism verification # Note: Only current request info is logged here; batch info is logged during forward if self.deterministic_logger is not None: self.deterministic_logger.log_prefill_input( request.request_id, idx, prefill_start_index, prefill_end_index, input_ids ) logger.debug( f"Handle prefill request {request} at idx {idx}, " f"{prefill_start_index=}, {prefill_end_index=}, " f"need_prefilled_token_num={len(input_ids)}" f"prompt_len={prompt_len}" ) async_set_value( self.share_inputs["input_ids"][idx : idx + 1, :length], input_ids[prefill_start_index:prefill_end_index], ) encoder_block_num = len(request.block_tables) async_set_value(self.share_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) async_set_value(self.share_inputs["block_tables"][idx : idx + 1, :], -1) async_set_value( self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num], 
request.block_tables ) async_set_value(self.share_inputs["stop_flags"][idx : idx + 1], False) async_set_value(self.share_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index) async_set_value(self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length) async_set_value(self.share_inputs["seq_lens_encoder"][idx : idx + 1], length) self.exist_prefill_flag = True async_set_value(self.share_inputs["step_seq_lens_decoder"][idx : idx + 1], 0) async_set_value(self.share_inputs["prompt_lens"][idx : idx + 1], len(input_ids)) async_set_value(self.share_inputs["is_block_step"][idx : idx + 1], False) self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids) async_set_value( self.share_inputs["step_idx"][idx : idx + 1], len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0, ) # pooling model request.sampling_params is None if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None: self.prompt_logprobs_reqs[request.request_id] = request self.forward_batch_reqs_list[idx] = request if self.speculative_decoding and self.spec_method == SpecMethod.SUFFIX and self.proposer is not None: if isinstance(request.prompt_token_ids, np.ndarray): prompt_token_ids = request.prompt_token_ids.tolist() else: prompt_token_ids = request.prompt_token_ids self.proposer.start_request(idx, request.request_id, prompt_token_ids) # Routing Replay if self.fd_config.routing_replay_config.enable_routing_replay: # 1.prefix task(need regist) 2. 
chunkend task(not need regist) self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id) if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generate first token # TODO: delete useless operation like this async_set_value(self.share_inputs["seq_lens_encoder"][idx : idx + 1], 0) self.exist_prefill_flag = False if self._cached_launch_token_num != -1: token_num_one_step = ( (self.speculative_config.num_speculative_tokens + 1) if self.speculative_decoding else 1 ) self._cached_launch_token_num += token_num_one_step self._cached_real_bsz += 1 if self.speculative_decoding: # D first decode step, [Target first token, MTP first draft token] # MTP in P only generate one draft token in any num_model_step config draft_tokens_to_write = request.draft_token_ids[0:2] if len(draft_tokens_to_write) != 2: raise ValueError( "Expected at least 2 draft tokens for speculative suffix decode, " f"but got {len(draft_tokens_to_write)} for request {request.request_id}." 
) async_set_value( self.share_inputs["draft_tokens"][idx : idx + 1, 0:2], draft_tokens_to_write, ) async_set_value(self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 2) logger.debug( f"insert request {request.request_id} idx: {idx} suffix tokens {request.draft_token_ids}" ) elif request.task_type.value == RequestType.DECODE.value: # decode task logger.debug(f"Handle decode request {request} at idx {idx}") encoder_block_num = len(request.block_tables) async_set_value(self.share_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) async_set_value(self.share_inputs["block_tables"][idx : idx + 1, :], -1) if current_platform.is_cuda(): async_set_value( self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables ) else: self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.block_tables, dtype="int32" ) # CPU Tensor self.share_inputs["preempted_idx"][idx : idx + 1, :] = 0 continue else: # preempted task if request.task_type.value == RequestType.PREEMPTED.value: logger.info(f"Handle preempted request {request} at idx {idx}") elif request.task_type.value == RequestType.ABORT.value: logger.info(f"Handle abort request {request} at idx {idx}") self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1 async_set_value(self.share_inputs["block_tables"][idx : idx + 1, :], -1) async_set_value(self.share_inputs["stop_flags"][idx : idx + 1], True) async_set_value(self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0) async_set_value(self.share_inputs["seq_lens_decoder"][idx : idx + 1], 0) async_set_value(self.share_inputs["seq_lens_encoder"][idx : idx + 1], 0) async_set_value(self.share_inputs["is_block_step"][idx : idx + 1], False) self.prompt_logprobs_reqs.pop(request.request_id, None) self.in_progress_prompt_logprobs.pop(request.request_id, None) self.forward_batch_reqs_list[idx] = None # Routing Replay if self.fd_config.routing_replay_config.enable_routing_replay: 
self.routing_replay_manager.clear_request(batch_id=idx) continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0) self.share_inputs["top_k_list"][idx] = request.get("top_k", 0) async_set_value(self.share_inputs["eos_token_id"][:], request.eos_token_ids) async_set_value(self.share_inputs["top_p"][idx : idx + 1], request.get("top_p", 0.7)) async_set_value(self.share_inputs["top_k"][idx : idx + 1], request.get("top_k", 0)) async_set_value(self.share_inputs["min_p"][idx : idx + 1], request.get("min_p", 0.0)) async_set_value(self.share_inputs["temperature"][idx : idx + 1], request.get("temperature", 0.95)) async_set_value(self.share_inputs["penalty_score"][idx : idx + 1], request.get("repetition_penalty", 1.0)) async_set_value(self.share_inputs["frequency_score"][idx : idx + 1], request.get("frequency_penalty", 0.0)) async_set_value(self.share_inputs["presence_score"][idx : idx + 1], request.get("presence_penalty", 0.0)) async_set_value( self.share_inputs["temp_scaled_logprobs"][idx : idx + 1], request.get("temp_scaled_logprobs", False) ) async_set_value( self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1], request.get("top_p_normalized_logprobs", False), ) async_set_value( self.share_inputs["generated_modality"][idx : idx + 1], request.get("generated_modality", 0) ) async_set_value(self.share_inputs["min_dec_len"][idx : idx + 1], request.get("min_tokens", 1)) async_set_value( self.share_inputs["max_dec_len"][idx : idx + 1], request.get("max_tokens", self.model_config.max_model_len), ) if request.get("seed") is not None: async_set_value(self.share_inputs["infer_seed"][idx : idx + 1], request.get("seed")) if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0: bad_words_len = len(request.get("bad_words_token_ids")) async_set_value(self.share_inputs["bad_tokens_len"][idx : idx + 1], bad_words_len) async_set_value( 
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len], request.get("bad_words_token_ids") ) else: async_set_value(self.share_inputs["bad_tokens_len"][idx : idx + 1], 1) async_set_value(self.share_inputs["bad_tokens"][idx : idx + 1, :], -1) if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: stop_seqs_num = len(request.get("stop_seqs_len")) for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): request.sampling_params.stop_seqs_len.append(0) async_set_value( self.share_inputs["stop_seqs_len"][idx : idx + 1, :], request.sampling_params.stop_seqs_len ) # 每条 stop sequence pad 到 stop_seqs_max_len,凑齐空行后整块写入 # 避免对第 3 维做部分切片(非连续内存)导致 async_set_value stride 错位 stop_token_ids = request.get("stop_token_ids") max_len = self.model_config.stop_seqs_max_len padded = [seq + [-1] * (max_len - len(seq)) for seq in stop_token_ids] padded.extend([[-1] * max_len] * (self.model_config.max_stop_seqs_num - stop_seqs_num)) async_set_value(self.share_inputs["stop_seqs"][idx : idx + 1, :, :], padded) else: async_set_value(self.share_inputs["stop_seqs_len"][idx : idx + 1, :], 0) self.pooling_params = batch_pooling_params # For logits processors self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {} self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens) self._process_mm_features(req_dicts) if len(rope_3d_position_ids["position_ids_idx"]) > 0 and self.enable_mm: packed_position_ids = paddle.to_tensor( np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="float32" ) rope_3d_lst = self.prepare_rope3d( packed_position_ids, rope_3d_position_ids["max_tokens_lst"], rope_3d_position_ids["position_ids_offset"], ) for i, idx in enumerate(rope_3d_position_ids["position_ids_idx"]): self.share_inputs["rope_emb"][idx : idx + 1, :] = rope_3d_lst[i] self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests] if self.spec_method == 
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
    """Legacy insertion entry point; unconditionally raises NotImplementedError."""
    raise NotImplementedError("GPUs only support KVCACHE SCHEDULER V1 in versions 2.6 and above.")


def get_input_length_list(
    self, num_tokens: int, batch_size: int, expected_decode_len: int, capture_prefill: bool = False
):
    """
    Build the per-request length lists consumed by `_dummy_prefill_inputs`.

    Why the prefill-capture case is special: in a pure prefill stage, different splits of the
    same token budget (e.g. `prompt[160, 0]` vs. `prompt[80, 80]`) can give kernels such as
    `split_q_block` different CUDA grid dimensions, which prevents CUDA Graph reuse. The block
    count of one sequence is
        `num_blocks = ceil((sequence_length * group_size) / block_shape_q)`
    so, because of the ceiling, spreading tokens over more (shorter) sequences yields a larger
    total. With `group_size` 5 and `block_shape_q` 64:
      - one sequence of 160 tokens needs `ceil((160 * 5) / 64) = 13` blocks;
      - two sequences of 80 tokens need `ceil((80 * 5) / 64) * 2 = 14` blocks.
    When `capture_prefill` is set, the list is therefore built as many length-1 requests plus
    one request holding the remaining tokens, which is intended to maximise the block count so
    the captured graph has the largest grid; smaller runtime grids rely on an early exit.

    Args:
        num_tokens (int): Total number of tokens across all sequences.
        batch_size (int): Number of sequences (requests) in the batch.
        expected_decode_len (int): Expected number of decoded tokens per request.
        capture_prefill (bool): Build the block-maximising layout described above instead of
            an even split.

    Returns:
        A 3-tuple `(input_length_list, max_dec_len_list, block_num)`:
        the sequence length of each dummy request, a same-sized list filled with
        `expected_decode_len + 1`, and the number of cache blocks derived from the
        (clamped) per-request input length.
    """
    # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
    decode_budget = expected_decode_len + 1

    # Even split across the batch, except when capturing prefill where the whole budget is used.
    split_by = 1 if capture_prefill else batch_size
    per_request_len = min(num_tokens // split_by, self.model_config.max_model_len - decode_budget)

    # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan.
    # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
    if self.fd_config.parallel_config.enable_expert_parallel:
        per_request_len = min(per_request_len, 32)

    # Ceil-divide the input length into cache blocks, then add the enc/dec reserve.
    block_size = self.cache_config.block_size
    block_num = (per_request_len + block_size - 1) // block_size + self.cache_config.enc_dec_block_num

    if not capture_prefill:
        lengths = [per_request_len] * batch_size
    elif num_tokens < batch_size:
        # Fewer tokens than slots: one token per request, fewer requests.
        lengths = [1] * num_tokens
    else:
        # (batch_size - 1) single-token requests plus one request with the remainder.
        lengths = [1] * (batch_size - 1) + [num_tokens - batch_size + 1]

    return lengths, [decode_budget] * len(lengths), block_num
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
    """
    Return the pooling tasks the loaded model's pooler reports as supported.

    Returns an empty list when the runner is not serving a pooling model. When chunked
    prefill is enabled in the cache config, the "encode" task is dropped from the result.
    """
    model = self.get_model()
    if not self.is_pooling_model:
        return []

    tasks = list(model.pooler.get_supported_tasks())
    if not (self.cache_config.enable_chunked_prefill and "encode" in tasks):
        return tasks

    tasks.remove("encode")
    logger.debug(
        "Chunked prefill is not supported with "
        "encode task which using ALL pooling. "
        "Please turn off chunked prefill by export=FD_DISABLE_CHUNKED_PREFILL=1 before using it."
    )
    # score not support
    return tasks
) # score not support return supported_tasks def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int): """Set dummy prefill inputs to share_inputs""" batch_size = len(input_length_list) for i in range(batch_size): idx = i input_length = input_length_list[i] max_dec_len = max_dec_len_list[i] self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["token_ids_all"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["eos_token_id"][:] = np.array( [2] * self.model_config.eos_tokens_lens, dtype="int64" ).reshape(-1, 1) self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = input_length self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length self.exist_prefill_flag = True self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 self.share_inputs["prompt_lens"][idx : idx + 1] = input_length self.share_inputs["step_idx"][idx : idx + 1] = 0 self.share_inputs["max_dec_len"][idx : idx + 1] = max_dec_len self.share_inputs["min_dec_len"][idx : idx + 1] = max_dec_len self.share_inputs["stop_flags"][idx : idx + 1] = False self.share_inputs["temperature"][idx : idx + 1] = 1 self.share_inputs["encoder_block_lens"][idx : idx + 1] = block_num self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange( idx * block_num, (idx + 1) * block_num, 1 ) self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"] def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_profile_run=False) -> None: """Prepare the model inputs""" if self.enable_mm and self.share_inputs["image_features_list"] is not None: tensor_feats = [t for t in self.share_inputs["image_features_list"] if isinstance(t, paddle.Tensor)] if tensor_feats: self.share_inputs["image_features"] = paddle.concat(tensor_feats, axis=0) 
def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_profile_run=False) -> tuple:
    """
    Prepare the model inputs for one forward step.

    Steps, in order: gather multimodal image features (if any), restore decode tasks via
    `recover_decode_task`, refresh the logprob settings from the in-flight requests,
    strip padding with `pre_process`, copy the results into the persistent `share_inputs`
    buffers, rebuild `forward_meta`, and build `sampling_metadata`.

    Args:
        cached_token_num: Forwarded to `_resolve_current_launch_token_num`.
        cached_real_bsz: Forwarded to `_resolve_current_launch_token_num`.
        is_dummy_or_profile_run: Forwarded to `_resolve_current_launch_token_num` and
            `initialize_forward_meta`.

    Returns:
        `(token_num, token_num_event)`: the token count resolved for this launch and the
        device event recorded after the CPU-mirror copies were issued.
    """
    if self.enable_mm and self.share_inputs["image_features_list"] is not None:
        # Only real tensors are concatenated; other placeholder entries are ignored.
        tensor_feats = [t for t in self.share_inputs["image_features_list"] if isinstance(t, paddle.Tensor)]
        if tensor_feats:
            self.share_inputs["image_features"] = paddle.concat(tensor_feats, axis=0)

    recover_decode_task(
        self.share_inputs["stop_flags"],
        self.share_inputs["seq_lens_this_time"],
        self.share_inputs["seq_lens_encoder"],
        self.share_inputs["seq_lens_decoder"],
        self.share_inputs["step_seq_lens_decoder"],
        self.share_inputs["block_tables"],
        self.share_inputs["is_block_step"],
        self.share_inputs["draft_tokens"] if self.speculative_decoding else None,
        self.share_inputs["step_draft_tokens"] if self.speculative_decoding else None,
        self.share_inputs["step_seq_lens_this_time"] if self.speculative_decoding else None,
        self.cache_config.block_size,
        self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
    )

    # Requests in the current batch that asked for logprobs.
    logprobs_reqs = [
        req
        for req in self.forward_batch_reqs_list
        if req is not None and req.sampling_params is not None and req.sampling_params.logprobs is not None
    ]
    if logprobs_reqs:
        # A negative `logprobs` value selects the whole original vocabulary; speculative
        # decoding uses the fixed value 20 instead of the per-request maximum.
        self.max_logprobs = (
            max(
                self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs
                for req in logprobs_reqs
            )
            if not self.speculative_decoding
            else 20
        )
        self.temp_scaled_logprobs = any(req.sampling_params.temp_scaled_logprobs for req in logprobs_reqs)
        self.top_p_normalized_logprobs = any(
            req.sampling_params.top_p_normalized_logprobs for req in logprobs_reqs
        )
    elif self.enable_logprob:
        self.max_logprobs = None if not self.speculative_decoding else 0

    # Remove padding
    self.share_inputs["seq_lens_this_time_cpu"].copy_(self.share_inputs["seq_lens_this_time"], False)
    self.share_inputs["is_block_step_cpu"].copy_(self.share_inputs["is_block_step"], False)
    # Event recorded right after the copies above; it is handed back to the caller.
    token_num_event = paddle.device.cuda.create_event()
    token_num_event.record()
    token_num, real_bsz = self._resolve_current_launch_token_num(
        cached_token_num, cached_real_bsz, token_num_event, is_dummy_or_profile_run
    )
    (
        ids_remove_padding,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_q_output,
        batch_id_per_token_output,
        real_output_token_num,
    ) = pre_process(
        token_num,
        self.share_inputs["input_ids"],
        self.share_inputs["seq_lens_this_time"],
        self.speculative_decoding,
        (self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
        self.share_inputs["seq_lens_encoder"],
        self.share_inputs["seq_lens_decoder"],
    )
    self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
    # NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions
    self.share_inputs["batch_id_per_token"][:] = -1
    self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
    self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
    self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)

    # For speculative decoding
    if self.speculative_decoding:
        self.share_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
        self.share_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)
        self._real_output_token_num_host.copy_(real_output_token_num, False)
        self.output_token_num_event.record()

    # Initialize forward meta data
    self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run)
    self.forward_meta.real_bsz = real_bsz

    # Get sampling metadata
    self.sampling_metadata = SamplingMetadata(
        temperature=self.share_inputs["temperature"],
        top_p=self.share_inputs["top_p"],
        top_k=self.share_inputs["top_k"],
        top_k_list=self.share_inputs["top_k_list"],
        min_p=self.share_inputs["min_p"],
        min_p_list=self.share_inputs["min_p_list"],
        seed=self.share_inputs["infer_seed"],
        step_idx=self.share_inputs["step_idx"],
        token_ids_all=self.share_inputs["token_ids_all"],
        prompt_lens=self.share_inputs["prompt_lens"],
        frequency_penalties=self.share_inputs["frequency_score"],
        presence_penalties=self.share_inputs["presence_score"],
        repetition_penalties=self.share_inputs["penalty_score"],
        min_dec_lens=self.share_inputs["min_dec_len"],
        bad_words_token_ids=self.share_inputs["bad_tokens"],
        bad_words_token_len=self.share_inputs["bad_tokens_len"],
        eos_token_ids=self.share_inputs["eos_token_id"],
        max_num_logprobs=self.max_logprobs,
        enable_early_stop=self.enable_early_stop,
        stop_flags=self.share_inputs["stop_flags"],
        temp_scaled_logprobs_flag=self.temp_scaled_logprobs,
        top_p_normalized_logprobs_flag=self.top_p_normalized_logprobs,
        temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
        top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
        logits_processors=self.share_inputs["logits_processors"],
        share_inputs=self.share_inputs,
    )
    return token_num, token_num_event
stop_flags=self.share_inputs["stop_flags"], temp_scaled_logprobs_flag=self.temp_scaled_logprobs, top_p_normalized_logprobs_flag=self.top_p_normalized_logprobs, temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"], top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"], logits_processors=self.share_inputs["logits_processors"], share_inputs=self.share_inputs, ) return token_num, token_num_event def _process_reorder(self) -> None: if self.attn_backends and getattr(self.attn_backends[0], "enable_ids_reorder", False): self.share_inputs.enable_pd_reorder = True self.share_inputs.condense() reorder_split_prefill_and_decode(input_batch=self.share_inputs) if self.speculative_decoding: if self.spec_method == SpecMethod.MTP: self.proposer.reorder_inputs(self.share_inputs.index_to_batch_id) def load_model(self) -> None: """load or download model""" logger.info(f"Starting to load model {self.model_config.architectures[0]}") # 1. Load original model model_loader = get_model_loader(load_config=self.fd_config.load_config) self.model = model_loader.load_model(fd_config=self.fd_config) # 2. Load lora model # 3. Load drafter model(for speculative decoding) # 4. Init proposer for speculative method self._init_speculative_proposer() # Load RL dynamic model if self.fd_config.load_config.dynamic_load_weight: from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager if self.spec_method == SpecMethod.MTP: self.dynamic_weight_manager = DynamicWeightManager( self.fd_config, [self.model, self.proposer.model], self.local_rank ) else: self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank) def get_model(self) -> nn.Layer: """Get current model""" return self.model def initialize_forward_meta(self, is_dummy_or_profile_run=False): """ Initialize forward meta, attention meta data and update some config. 
""" # Initialize forward meta routing_replay_table = None if self.routing_replay_manager is not None: routing_replay_table = self.routing_replay_manager.get_routing_table() num_running_requests = self.share_inputs["seq_lens_this_time"].shape[0] self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], attn_backend=self.attn_backends[0], decoder_batch_ids=self.share_inputs["decoder_batch_ids"], decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"], # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor, # adapted to cudagraph. decoder_num_blocks_device=self.share_inputs["decoder_num_blocks_device"], decoder_chunk_size_device=self.share_inputs["decoder_chunk_size_device"], max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"][:num_running_requests], seq_lens_decoder=self.share_inputs["seq_lens_decoder"][:num_running_requests], seq_lens_this_time=self.share_inputs["seq_lens_this_time"], batch_id_per_token=self.share_inputs["batch_id_per_token"], cu_seqlens_q=self.share_inputs["cu_seqlens_q"], cu_seqlens_k=self.share_inputs["cu_seqlens_k"], block_tables=self.share_inputs["block_tables"][:num_running_requests], caches=self.share_inputs["caches"], encoder_batch_ids=self.share_inputs["encoder_batch_ids"], encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"], encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"], kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], routing_replay_table=routing_replay_table, ) dist_status = self.collect_distributed_status() if_only_decode = dist_status.if_only_decode if self.fd_config.parallel_config.enable_chunked_moe: 
def initialize_forward_meta(self, is_dummy_or_profile_run=False):
    """
    Rebuild `self.forward_meta` from `share_inputs`, decide whether this step uses CUDA
    Graph, and initialize the attention metadata of every backend.

    Args:
        is_dummy_or_profile_run: Stored on the forward meta; per the original note it lets
            attention backends skip `init_kv_signal_per_query`.
    """
    inputs = self.share_inputs

    # Initialize forward meta
    replay_manager = self.routing_replay_manager
    routing_replay_table = replay_manager.get_routing_table() if replay_manager is not None else None

    num_running_requests = inputs["seq_lens_this_time"].shape[0]
    self.forward_meta = ForwardMeta(
        ids_remove_padding=inputs["ids_remove_padding"],
        rotary_embs=inputs["rope_emb"],
        attn_backend=self.attn_backends[0],
        decoder_batch_ids=inputs["decoder_batch_ids"],
        decoder_tile_ids_per_batch=inputs["decoder_tile_ids_per_batch"],
        decoder_num_blocks_cpu=inputs["decoder_num_blocks_cpu"],
        # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
        # adapted to cudagraph.
        decoder_num_blocks_device=inputs["decoder_num_blocks_device"],
        decoder_chunk_size_device=inputs["decoder_chunk_size_device"],
        max_len_tensor_cpu=inputs["max_len_tensor_cpu"],
        seq_lens_encoder=inputs["seq_lens_encoder"][:num_running_requests],
        seq_lens_decoder=inputs["seq_lens_decoder"][:num_running_requests],
        seq_lens_this_time=inputs["seq_lens_this_time"],
        batch_id_per_token=inputs["batch_id_per_token"],
        cu_seqlens_q=inputs["cu_seqlens_q"],
        cu_seqlens_k=inputs["cu_seqlens_k"],
        block_tables=inputs["block_tables"][:num_running_requests],
        caches=inputs["caches"],
        encoder_batch_ids=inputs["encoder_batch_ids"],
        encoder_tile_ids_per_batch=inputs["encoder_tile_ids_per_batch"],
        encoder_num_blocks_x_cpu=inputs["encoder_num_blocks_x_cpu"],
        kv_batch_ids=inputs["kv_batch_ids"],
        kv_tile_ids_per_batch=inputs["kv_tile_ids_per_batch"],
        kv_num_blocks_x_cpu=inputs["kv_num_blocks_x_cpu"],
        routing_replay_table=routing_replay_table,
    )

    dist_status = self.collect_distributed_status()
    decode_only = dist_status.if_only_decode
    parallel_config = self.fd_config.parallel_config
    if parallel_config.enable_chunked_moe:
        self.forward_meta.max_moe_num_chunk = dist_status.max_moe_num_chunk

    decode_graph_ok = self.use_cudagraph and decode_only

    # Update config about moe for better performance
    # TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
    if parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
        moe_phase = "decode" if decode_only else "prefill"
        self.fd_config.model_config.moe_phase.phase = moe_phase
        if self.speculative_decoding:
            self.proposer.fd_config.model_config.moe_phase.phase = moe_phase

    # Update Batch type for cuda graph for only_prefill_batch
    prefill_graph_ok = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()

    # When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph]
    if self.cudagraph_only_prefill:
        step_use_cudagraph = prefill_graph_ok
    else:
        step_use_cudagraph = decode_graph_ok and self.forward_meta.ids_remove_padding.shape[0] > 0
    self.forward_meta.step_use_cudagraph = step_use_cudagraph

    # Use static graph splitting to isolate incompatible operators from the CUDA Graph. This splits the graph
    # into subgraphs, allowing Prefill, Decode, and Mixed Batches to run compatible parts via CUDA Graph.
    split_graph_mode = (
        hasattr(self, "graph_opt_config")
        and self.use_cudagraph
        and self.graph_opt_config.graph_opt_level > 0
        and not self.graph_opt_config.full_cuda_graph
    )
    if split_graph_mode:
        self.forward_meta.step_use_cudagraph = True

    # Set forward_meta.is_dummy_or_profile_run to True to skip init_kv_signal_per_query for attention backends
    self.forward_meta.is_dummy_or_profile_run = is_dummy_or_profile_run

    # Initialize attention meta data
    for backend in self.attn_backends:
        backend.init_attention_metadata(self.forward_meta)

    # for zero size
    self.forward_meta.is_zero_size = self.forward_meta.ids_remove_padding.shape[0] == 0
    self.forward_meta.exist_prefill = self.exist_prefill()
def initialize_kv_cache(self, profile: bool = False) -> None:
    """
    Initialize kv cache

    For every hidden layer this either creates the key cache (plus a value or indexer
    cache, and block-wise-fp8 scale tensors when applicable) and registers each tensor
    under an IPC name, or attaches to tensors of the same names via `share_external_data`.
    The resulting per-layer tensors are stored in `self.cache_kvs_map` (by name) and, in
    order, in `self.share_inputs["caches"]`.

    Args:
        profile: When True the runner always creates its own cache tensors and does not
            set the cache-ready signal afterwards.
    """
    # cache_kvs = {}
    max_block_num = self.num_gpu_blocks

    # Get kv cache dtype
    cache_type = self.model_config.dtype

    kv_cache_quant_type = None

    # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
    # To rationalize the allocation of kvcache.
    from fastdeploy import envs

    self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
    self.dsa_cache = envs.FD_ATTENTION_BACKEND == "DSA_ATTN"
    # A configured kv-cache quantization switches the storage dtype to uint8.
    if (
        self.quant_config
        and hasattr(self.quant_config, "kv_cache_quant_type")
        and self.quant_config.kv_cache_quant_type is not None
    ):
        cache_type = "uint8"
        kv_cache_quant_type = self.quant_config.kv_cache_quant_type

    # Get kv cache shape
    if self.dsa_cache:
        # Determine dsa cache quant type
        kv_cache_quant_type = "uint8"
        cache_type = "uint8"
        # NOTE(changwenbin) Get dsa cache shape.
        key_cache_shape, value_cache_shape, indexer_cache_shape = self.attn_backends[0].get_kv_cache_shape(
            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
        )
    else:
        key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape(
            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
        )
        # No indexer cache outside the DSA backend.
        indexer_cache_shape = []
    if kv_cache_quant_type == "block_wise_fp8":
        # Scale tensors reuse the first three dimensions of the key cache shape.
        kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]
    local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

    # Check if gpu runner needs to create kv cache
    # 1. During profiling, it creates its own kv cache.
    # 2. If no need to profile, create kv cache if cache managers do not exist.
    create_cache_tensor = profile or not (
        self.fd_config.cache_config.num_cpu_blocks > 0
        or self.fd_config.cache_config.kvcache_storage_backend
        or self.fd_config.scheduler_config.splitwise_role != "mixed"
    )

    cache_ready_signal = self.cache_ready_signal
    if not create_cache_tensor:
        # Poll once per second until this rank's entry in the ready signal becomes 1.
        logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
        while cache_ready_signal.value[local_rank] != 1:
            time.sleep(1)
        logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")

    logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
    cache_kvs_list = []

    for i in range(self.model_config.num_hidden_layers):
        # init key cache
        key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
        key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
        # Value and indexer caches are mutually exclusive; only the relevant names are built.
        if value_cache_shape:
            val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
            value_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
        elif indexer_cache_shape:
            indexer_cache_name = f"indexer_caches_{i}_rank{local_rank}.device{self.device_id}"
        if create_cache_tensor:
            # Create zero-filled tensors and register each one under its IPC name.
            logger.info(
                f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}, indexer:{indexer_cache_shape}"
            )
            key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
            set_data_ipc(key_cache, key_cache_name)
            self.cache_kvs_map[key_cache_name] = key_cache
            if value_cache_shape:
                val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type)
                set_data_ipc(val_cache, val_cache_name)
                self.cache_kvs_map[val_cache_name] = val_cache
                cache_kvs_list.extend([key_cache, val_cache])
            elif indexer_cache_shape:
                indexer_cache = paddle.full(shape=indexer_cache_shape, fill_value=0, dtype=cache_type)
                set_data_ipc(indexer_cache, indexer_cache_name)
                self.cache_kvs_map[indexer_cache_name] = indexer_cache
                cache_kvs_list.extend([key_cache, indexer_cache])
            else:
                cache_kvs_list.append(key_cache)
            if kv_cache_quant_type == "block_wise_fp8":
                # Scale tensors use the default dtype and follow the caches in the list.
                key_cache_scales = paddle.full(
                    shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                )
                set_data_ipc(key_cache_scales, key_cache_scales_name)
                self.cache_kvs_map[key_cache_scales_name] = key_cache_scales
                if value_cache_shape:
                    val_cache_scales = paddle.full(
                        shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                    )
                    set_data_ipc(val_cache_scales, value_cache_scales_name)
                    self.cache_kvs_map[value_cache_scales_name] = val_cache_scales
                    cache_kvs_list.extend([key_cache_scales, val_cache_scales])
                else:
                    cache_kvs_list.append(key_cache_scales)
        else:
            # Attach to tensors registered elsewhere under the same names: an empty
            # placeholder is passed to share_external_data together with the expected shape.
            logger.info(
                f"..attaching kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}, indexer:{indexer_cache_shape}"
            )
            key_cache = paddle.empty(shape=[], dtype=cache_type)
            key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape)
            self.cache_kvs_map[key_cache_name] = key_cache
            if kv_cache_quant_type == "block_wise_fp8":
                key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
                key_cache_scales = share_external_data(
                    key_cache_scales, key_cache_scales_name, kv_cache_scale_shape
                )
                self.cache_kvs_map[key_cache_scales_name] = key_cache_scales
            if value_cache_shape:
                val_cache = paddle.empty(shape=[], dtype=cache_type)
                val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape)
                self.cache_kvs_map[val_cache_name] = val_cache
                cache_kvs_list.extend([key_cache, val_cache])
                if kv_cache_quant_type == "block_wise_fp8":
                    val_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
                    val_cache_scales = share_external_data(
                        val_cache_scales, value_cache_scales_name, kv_cache_scale_shape
                    )
                    self.cache_kvs_map[value_cache_scales_name] = val_cache_scales
                    cache_kvs_list.extend([key_cache_scales, val_cache_scales])
            elif indexer_cache_shape:
                indexer_cache = paddle.empty(shape=[], dtype=cache_type)
                indexer_cache = share_external_data(indexer_cache, indexer_cache_name, indexer_cache_shape)
                self.cache_kvs_map[indexer_cache_name] = indexer_cache
                cache_kvs_list.extend([key_cache, indexer_cache])
            else:
                cache_kvs_list.append(key_cache)
                if kv_cache_quant_type == "block_wise_fp8":
                    cache_kvs_list.append(key_cache_scales)

    self.share_inputs["caches"] = cache_kvs_list

    if not profile and create_cache_tensor:
        # This runner created the tensors itself, so it marks its rank as ready.
        cache_ready_signal.value[local_rank] = 1
        logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")

    paddle.device.cuda.empty_cache()
    logger.info("kv cache is initialized!")
{cache_ready_signal.value}") paddle.device.cuda.empty_cache() logger.info("kv cache is initialized!") def _initialize_attn_backend(self) -> None: """ Initialize attention backends """ assert ( len(self.attn_backends) == 0 ), f"attn_backends should be empty before initialization, got {len(self.attn_backends)} backends" num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size self.model_config.kv_num_heads = max( 1, int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size, ) head_dim = self.model_config.head_dim encoder_block_shape_q = 64 decoder_block_shape_q = 16 # Deterministic mode: use deterministic_split_kv_size to ensure batch-invariant attention if envs.FD_DETERMINISTIC_MODE: decoder_block_shape_q = envs.FD_DETERMINISTIC_SPLIT_KV_SIZE buffer_kwargs = dict( max_batch_size=self.scheduler_config.max_num_seqs, max_model_len=self.model_config.max_model_len, encoder_block_shape_q=encoder_block_shape_q, decoder_block_shape_q=decoder_block_shape_q, decoder_step_token_num=self.speculative_config.num_speculative_tokens + 1, num_heads=num_heads, kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, ) res_buffer = allocate_launch_related_buffer(**buffer_kwargs) self.share_inputs.update(res_buffer) if int(os.getenv("USE_TBO", "0")) == 1: for j in range(2): GLOBAL_ATTN_BUFFERS[j] = allocate_launch_related_buffer(**buffer_kwargs) # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( self.fd_config, kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim, encoder_block_shape_q=encoder_block_shape_q, decoder_block_shape_q=decoder_block_shape_q, ) self.attn_backends.append(attn_backend) def _dummy_pooler_run_task( self, hidden_states: paddle.Tensor, task: PoolingTask, ) -> PoolerOutput: num_tokens = hidden_states.shape[0] max_num_seqs = self.scheduler_config.max_num_seqs num_reqs = min(num_tokens, 
def _dummy_pooler_run_task(
    self,
    hidden_states: paddle.Tensor,
    task: PoolingTask,
) -> PoolerOutput:
    """
    Run the model's pooler once for `task` on dummy metadata built around `hidden_states`.

    The rows of `hidden_states` are split as evenly as possible over at most
    `scheduler_config.max_num_seqs` dummy requests; the last request takes the remainder.

    Raises:
        RuntimeError: Re-raised with a tuning hint when the pooler fails with an
            "out of memory" error; any other RuntimeError propagates unchanged.
    """
    num_tokens = hidden_states.shape[0]
    num_reqs = min(num_tokens, self.scheduler_config.max_num_seqs)

    # Even split, remainder goes to the final request.
    tokens_per_req, leftover = divmod(num_tokens, num_reqs)
    num_scheduled_tokens_list = [tokens_per_req] * num_reqs
    num_scheduled_tokens_list[-1] += leftover
    assert sum(num_scheduled_tokens_list) == num_tokens
    assert len(num_scheduled_tokens_list) == num_reqs

    device = hidden_states.place
    dummy_prompt_lens = paddle.to_tensor(num_scheduled_tokens_list, dtype="int64", place=paddle.CPUPlace())
    dummy_token_ids = paddle.zeros([num_reqs, tokens_per_req], dtype="int64", device=device)

    model = cast(FdModelForPooling, self.get_model())
    # Start from default params for the task, then let the pooler apply its own updates.
    dummy_pooling_params = PoolingParams(task=task)
    model.pooler.get_pooling_updates(task).apply(dummy_pooling_params)

    dummy_metadata = PoolingMetadata(
        prompt_lens=dummy_prompt_lens,
        prompt_token_ids=dummy_token_ids,
        pooling_params=[dummy_pooling_params] * num_reqs,
    )
    dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, device=device)

    try:
        return model.pooler(hidden_states=hidden_states, pooling_metadata=dummy_metadata)
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        raise RuntimeError(
            "CUDA out of memory occurred when warming up pooler "
            f"({task=}) with {num_reqs} dummy requests. Please try "
            "lowering `max_num_seqs` or `gpu_memory_utilization` when "
            "initializing the engine."
        ) from e
def _dummy_pooler_run(
    self,
    hidden_states: paddle.Tensor,
    model_output: paddle.Tensor,
) -> PoolerOutput:
    """
    Dummy-run the pooler and push its result through post-processing.

    Every task reported by `get_supported_pooling_tasks()` is run once so the byte
    size of its output can be measured. The task with the largest output is then
    run again and its result is handed to `post_process` with
    `skip_save_output=True`.

    Args:
        hidden_states: Hidden states passed to each dummy pooler task.
        model_output: Model output stored as `full_hidden_states` in ModelOutputData.

    Returns:
        The pooler output of the task whose measured output was the largest.
    """
    buffers = self.share_inputs
    parallel = self.parallel_config

    def _nbytes(item):
        # Entries without `numel` do not contribute to the measured size.
        return item.numel() * item.element_size() if hasattr(item, "numel") else 0

    def _spec_only(key):
        # Speculative-decoding buffers are only forwarded when the feature is on.
        return buffers[key] if self.speculative_decoding else None

    bytes_per_task = dict[PoolingTask, float]()
    for pooling_task in self.get_supported_pooling_tasks():
        task_output = self._dummy_pooler_run_task(hidden_states, pooling_task)
        bytes_per_task[pooling_task] = sum(map(_nbytes, task_output))
        # Drop the reference before the next task is run.
        del task_output

    largest_task = max(bytes_per_task, key=bytes_per_task.get)
    pooler_output = self._dummy_pooler_run_task(hidden_states, largest_task)

    model_output_data = ModelOutputData(
        next_tokens=buffers["next_tokens"],
        stop_flags=buffers["stop_flags"],
        step_idx=buffers["step_idx"],
        max_dec_len=buffers["max_dec_len"],
        seq_lens_this_time=buffers["seq_lens_this_time"],
        eos_token_id=buffers["eos_token_id"],
        not_need_stop=buffers["not_need_stop"],
        not_need_stop_device=buffers["not_need_stop_device"],
        input_ids=buffers["input_ids"],
        seq_lens_encoder=buffers["seq_lens_encoder"],
        seq_lens_decoder=buffers["seq_lens_decoder"],
        is_block_step=buffers["is_block_step"],
        full_hidden_states=model_output,
        msg_queue_id=parallel.msg_queue_id,
        mp_rank=parallel.tensor_parallel_rank,
        use_ep=parallel.use_ep,
        draft_tokens=_spec_only("draft_tokens"),
        actual_draft_token_num=_spec_only("actual_draft_token_num"),
        accept_tokens=_spec_only("accept_tokens"),
        accept_num=_spec_only("accept_num"),
        token_ids_all=buffers["token_ids_all"],
        stop_token_ids=buffers["stop_seqs"],
        stop_seqs_len=buffers["stop_seqs_len"],
        min_tokens=buffers["min_dec_len"],
        prompt_lens=buffers["prompt_lens"],
        index_to_batch_id=buffers["index_to_batch_id"],
        enable_pd_reorder=getattr(buffers, "enable_pd_reorder", False),
    )

    post_process(
        sampler_or_pooler_output=pooler_output,
        model_output=model_output_data,
        share_inputs=buffers,
        sampling_metadata=self.sampling_metadata,
        block_size=self.cache_config.block_size,
        speculative_decoding=self.speculative_decoding,
        skip_save_output=True,
        async_output_queue=self.async_output_queue,
        think_end_id=self.model_config.think_end_id,
        splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode",
    )
    self.exist_prefill_flag = False
    return pooler_output
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) paddle.distributed.broadcast( self.share_inputs["step_idx"], self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) paddle.distributed.broadcast( self.share_inputs["stop_flags"], self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) # 5. post process model_output_data = ModelOutputData( next_tokens=self.share_inputs["next_tokens"], stop_flags=self.share_inputs["stop_flags"], step_idx=self.share_inputs["step_idx"], max_dec_len=self.share_inputs["max_dec_len"], seq_lens_this_time=self.share_inputs["seq_lens_this_time"], eos_token_id=self.share_inputs["eos_token_id"], not_need_stop=self.share_inputs["not_need_stop"], not_need_stop_device=self.share_inputs["not_need_stop_device"], input_ids=self.share_inputs["input_ids"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"], is_block_step=self.share_inputs["is_block_step"], full_hidden_states=model_output, msg_queue_id=self.parallel_config.msg_queue_id, mp_rank=self.parallel_config.tensor_parallel_rank, use_ep=self.parallel_config.use_ep, draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), actual_draft_token_num=( self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None ), accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), token_ids_all=self.share_inputs["token_ids_all"], stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], min_tokens=self.share_inputs["min_dec_len"], prompt_lens=self.share_inputs["prompt_lens"], mask_rollback=self.share_inputs["mask_rollback"], 
def _dummy_run(
    self,
    num_tokens: int,
    batch_size: int,
    expected_decode_len: int = 1,
    in_capturing: bool = False,
    capture_prefill: bool = False,
    accept_all_drafts: bool = False,
    reject_all_drafts: bool = False,
) -> None:
    """
    Use dummy inputs to run before formal execution.

    Dummy prefill inputs are built once, then the model is stepped in a loop.
    A pooling model runs a single step through `_dummy_pooler_run`; otherwise each
    step goes through `rebuild_padding` and `_dummy_sampler_run` until no request
    has `seq_lens_this_time > 0` (or after one step when only prefill is captured).

    Args:
        num_tokens: Token count forwarded to `get_input_length_list` (and to the
            MTP proposer) when building the dummy inputs
        batch_size: Number of dummy requests; also forwarded to `_dummy_sampler_run`
        expected_decode_len: Expected number of tokens generated
        in_capturing: Is cuda graph in capturing state
        capture_prefill: Capture pure prefill for cuda graph
        accept_all_drafts: Target model will accept all draft tokens
        reject_all_drafts: Target model will reject all draft tokens

    Returns:
        None. The method only has side effects on the runner state.
    """
    input_length_list, max_dec_len_list, block_num = self.get_input_length_list(
        num_tokens=num_tokens,
        batch_size=batch_size,
        expected_decode_len=expected_decode_len,
        capture_prefill=capture_prefill,
    )
    self._dummy_prefill_inputs(
        input_length_list=input_length_list,
        max_dec_len_list=max_dec_len_list,
        block_num=block_num,
    )
    # The MTP proposer keeps its own inputs, so it gets matching dummy data.
    if self.spec_method == SpecMethod.MTP:
        self.proposer.dummy_prefill_inputs(
            num_tokens=num_tokens,
            batch_size=batch_size,
            expected_decode_len=expected_decode_len,
        )

    while True:
        # 1. Initialize forward meta and attention meta data
        self._prepare_inputs(is_dummy_or_profile_run=True)
        # 2. Padding inputs for cuda graph
        # CUDA graph stays enabled for this step only while capturing.
        self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
        self.padding_cudagraph_inputs()

        model_inputs = {}
        model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"]
        model_inputs["generated_modality"] = self.share_inputs["generated_modality"]
        if self.enable_mm:
            model_inputs["image_features"] = self.share_inputs["image_features"]

        # 3. Run model
        model_output = self.model(
            model_inputs,
            self.forward_meta,
        )
        if self.use_cudagraph:
            # Keep only the first `real_token_num` rows of the (padded) output.
            model_output = model_output[: self.real_token_num]

        if self.is_pooling_model:
            # Pooling models take a single dummy step.
            self._dummy_pooler_run(model_output, model_output)
            break
        else:
            if self.speculative_decoding:
                # Wait for the output token count before reading its host copy.
                self.output_token_num_event.synchronize()
                real_num = int(self._real_output_token_num_host)
                real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_num]
            else:
                real_batch_id_per_token_output = None
            hidden_states = rebuild_padding(
                model_output,
                self.share_inputs["cu_seqlens_q"],
                self.share_inputs["seq_lens_this_time"],
                self.share_inputs["seq_lens_decoder"],
                self.share_inputs["seq_lens_encoder"],
                real_batch_id_per_token_output,
                (self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None),
            )
            self._dummy_sampler_run(hidden_states, model_output, batch_size, accept_all_drafts, reject_all_drafts)

        # 7. Update 'infer_seed' and step_cuda()
        if not self.speculative_decoding:
            self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
            self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED

        # Stop once no request has tokens left to process in the next step.
        if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
            break

        if capture_prefill and self.graph_opt_config.graph_opt_level > 0:
            # only need to capture prefill
            break
Please check GraphOptimizationConfig") return time_before_capture = time.perf_counter() expected_decode_len = 1 capture_sizes = self.cudagraph_capture_sizes.copy() try: if self.fd_config.graph_opt_config.cudagraph_only_prefill: for num_tokens in sorted(capture_sizes, reverse=True): self._dummy_run( num_tokens=num_tokens, batch_size=self.scheduler_config.max_num_seqs, in_capturing=True, expected_decode_len=expected_decode_len, capture_prefill=True, ) logger.info( f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" ) elif self.speculative_decoding and self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX]: for capture_size in sorted(capture_sizes, reverse=True): expected_decode_len = (self.speculative_config.num_speculative_tokens + 1) * 2 self._dummy_run( num_tokens=self.fd_config.get_max_chunk_tokens(), batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)), in_capturing=True, expected_decode_len=expected_decode_len, accept_all_drafts=True, ) logger.info( f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{expected_decode_len}" ) else: for batch_size in sorted(capture_sizes, reverse=True): self._dummy_run( num_tokens=self.fd_config.get_max_chunk_tokens(), batch_size=batch_size, in_capturing=True, expected_decode_len=expected_decode_len, ) logger.info( f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}" ) except RuntimeError as e: if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up CUDAGraph " f"with the capture sizes {capture_sizes}. Please try " "lowering `max_num_seqs` or `gpu_memory_utilization` when " "initializing the engine." ) from e if "CUDA error(700)" in str(e): raise RuntimeError( "CUDA error(700), an illegal memory access was encountered, " "when warming up CUDAGraph. 
@sot_warmup_guard(True)
def capture_model_prefill_and_mixed(self) -> None:
    """
    Trigger CUDA Graph capture for prefill/mixed phase in static split graph mode.

    One capturing dummy run (batch size 1, prefill only) is issued per entry of
    `cudagraph_capture_sizes_prefill`, largest size first, and the total elapsed
    time is logged. When `use_cudagraph` is off only a skip message is logged.
    """
    if self.use_cudagraph:
        time_before_capture = time.perf_counter()
        # Arguments shared by every capture run; only `num_tokens` varies.
        fixed_run_args = {
            "batch_size": 1,
            "in_capturing": True,
            "expected_decode_len": 1,
            "capture_prefill": True,
        }
        descending_sizes = sorted(self.cudagraph_capture_sizes_prefill.copy(), reverse=True)
        for capture_size in descending_sizes:
            self._dummy_run(num_tokens=capture_size, **fixed_run_args)
            logger.info(f"Warm up the model (prefill/mixed) with num_tokens:{capture_size}")
        time_after_capture = time.perf_counter()
        logger.info(
            f"Cuda Graph capturing (Prefill and Mixed) took {time_after_capture - time_before_capture} seconds"
        )
    else:
        logger.info("Skipping CUDA graph capture. Please check GraphOptimizationConfig")


def vision_encoder_compile(self):
    """
    Compile the vision encoder layer with `paddle.jit.to_static` and warm it up.

    Nothing happens when `graph_opt_level` is 0 or the model type is not
    "paddleocr_vl". Otherwise `SiglipEncoder._run_encoder_layer` is replaced by its
    compiled version (backend "CINN" for `graph_opt_level >= 2`, default backend
    otherwise) and a dummy vision feature extraction is run.
    """
    # Currently only PaddleOCR-VL model is supported for vision encoder layer
    compile_disabled = self.graph_opt_config.graph_opt_level == 0
    if compile_disabled or self.model_config.model_type != "paddleocr_vl":
        return

    from fastdeploy.model_executor.models.paddleocr_vl.siglip import SiglipEncoder

    # Compile for paddleocr_vl vision encoder layer
    jit_backend = None
    if self.graph_opt_config.graph_opt_level >= 2:
        jit_backend = "CINN"
    SiglipEncoder._run_encoder_layer = paddle.jit.to_static(
        SiglipEncoder._run_encoder_layer,
        full_graph=False,
        backend=jit_backend,
    )

    # Warmup for paddleocr_vl vision encoder layer
    logger.info(f"Warmup for {self.model_config.model_type} compile...")
    self._dummy_run_extract_vision_features()
def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int):
    """
    Get indices for guided decoding.
    When Prefill is done, async compiled logits_processor must be joined.

    Args:
        model_forward_batch: Requests of the current step, or None.
        num_running_requests: Number of batch ids to inspect, starting from 0.

    Returns:
        Indices (as returned by `share_inputs.get_index_by_batch_id`) whose
        `step_idx` is 0, minus prefill tasks whose prompt extends past
        `prefill_end_index` while chunked prefill is enabled. An empty list when
        no guided backend is configured.
    """
    if self.guided_backend is None:
        return []

    slot_of = self.share_inputs.get_index_by_batch_id
    candidate_slots = (slot_of(batch_id) for batch_id in range(0, num_running_requests))
    done_slots = [slot for slot in candidate_slots if self.share_inputs["step_idx"][slot] == 0]

    if model_forward_batch is None:
        return done_slots

    for request in model_forward_batch:
        if request.task_type.value != RequestType.PREFILL.value:
            continue
        # in chunk prefill
        if not self.cache_config.enable_chunked_prefill:
            continue
        if not (hasattr(request, "prefill_end_index") and hasattr(request, "prompt_token_ids")):
            continue
        request_slot = self.share_inputs.get_index_by_batch_id(request.idx)
        prompt_unfinished = len(request.prompt_token_ids) > request.prefill_end_index
        if prompt_unfinished and request_slot in done_slots:
            # Part of the prompt is still to be prefilled, so it is not done yet.
            done_slots.remove(request_slot)

    return done_slots


def _execute_empty_mtp_input(self, forward_meta) -> None:
    """
    run ep inference forward with empty input.

    Calls the proposer model's `empty_input_forward` once per configured
    `speculative_config.num_model_steps`.
    """
    step_count = self.fd_config.speculative_config.num_model_steps
    for _ in range(step_count):
        self.proposer.model.empty_input_forward(forward_meta)
def execute_model_normal(
    self,
    model_forward_batch: Optional[List[Request]] = None,
    num_running_requests: int = None,
) -> None:
    """
    Run one step in order: preprocess, forward, postprocess, save output.

    Args:
        model_forward_batch: Requests forwarded to `_preprocess` / `_postprocess`.
        num_running_requests: Forwarded to `_preprocess` / `_postprocess`.
    """
    inputs, prefill_done_idxs, _ = self._preprocess(model_forward_batch, num_running_requests)
    forward_output = self._execute(inputs)

    # Count the slots that carry tokens in this step, using the CPU-side copy.
    active_slots = self.share_inputs["seq_lens_this_time_cpu"].numpy() > 0
    real_bsz = active_slots.sum().item()

    if forward_output is not None and real_bsz > 0:
        output_data, sampled, copy_done_event = self._postprocess(
            forward_output, prefill_done_idxs, model_forward_batch, num_running_requests, real_bsz
        )
        if output_data is None:
            return
        # synchronizes the async DtoH copies of sampled_token_ids.
        copy_done_event.synchronize()
        self._save_model_output(output_data, sampled)
        return

    # No output to process: with MTP under EP, still run the empty forward.
    needs_empty_mtp_forward = (
        self.fd_config.speculative_config.method == SpecMethod.MTP
        and hasattr(self.proposer.model, "empty_input_forward")
        and self.parallel_config.use_ep
    )
    if needs_empty_mtp_forward:
        self._execute_empty_mtp_input(self.forward_meta)
def _preprocess(
    self,
    model_forward_batch: Optional[List[Request]] = None,
    num_running_requests: Optional[int] = None,
    cached_token_num: int = -1,
    cached_real_bsz: int = -1,
) -> tuple:
    """
    Prepare everything the forward pass and the sampler need for this step.

    Args:
        model_forward_batch: Requests of the current step; logged by the
            deterministic logger and used to compute the prefill-done indices.
        num_running_requests: Forwarded to `_get_p_done_idxs_gd`.
        cached_token_num: Forwarded to `_prepare_inputs`.
        cached_real_bsz: Forwarded to `_prepare_inputs`.

    Returns:
        A 3-tuple `(model_inputs, p_done_idxs, token_num_event)`.
        `model_inputs` and `p_done_idxs` are both None when the launch token count
        is 0 and EP is not used; `token_num_event` is always the second value
        returned by `_prepare_inputs`.
    """
    if self.deterministic_logger is not None:
        self.deterministic_logger.log_batch_start(model_forward_batch)

    # Reorder inputs to split prefill and decode tokens
    self._process_reorder()

    # Prepare inputs of model and sampler.
    current_launch_token_num, token_num_event = self._prepare_inputs(cached_token_num, cached_real_bsz)
    self.current_launch_token_num = current_launch_token_num

    # NOTE(sunxin):
    # If current_launch_token_num is 0, it means the current worker is in an idle state,
    # and no further processing is required in TP mode.
    # However, in EP (Expert Parallelism) mode, there is data on other runner,
    # the current runner is required to execute part of the model.
    if current_launch_token_num == 0 and not self.parallel_config.use_ep:
        return None, None, token_num_event

    p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
    self.sampler.pre_process(p_done_idxs)

    if self.fd_config.routing_replay_config.enable_routing_replay:
        self.routing_replay_manager.pending_update_positions = self.routing_replay_manager.get_token_positions(
            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
            seq_lens_this_time=self.share_inputs["seq_lens_this_time_buffer"],
        )

    # Update state of logits processor
    for proc in self.sampling_metadata.logits_processors:
        proc.update_state(self.share_inputs)

    # Padding inputs for cuda graph
    self.padding_cudagraph_inputs()

    model_inputs = {
        "ids_remove_padding": self.share_inputs["ids_remove_padding"],
        "generated_modality": self.share_inputs["generated_modality"],
    }
    if self.enable_mm:
        model_inputs["image_features"] = self.share_inputs["image_features"]

    return model_inputs, p_done_idxs, token_num_event


def _execute(self, model_inputs: Optional[Dict[str, paddle.Tensor]]) -> Optional[paddle.Tensor]:
    """
    Run the model forward pass on the prepared inputs.

    Args:
        model_inputs: Inputs built by `_preprocess`; None or empty means there is
            nothing to run.

    Returns:
        The model output, truncated to the first `real_token_num` entries when
        CUDA graph is in use, or None when `model_inputs` is None or empty.
    """
    if not model_inputs:
        return None

    model_output = self.model(
        model_inputs,
        self.forward_meta,
    )
    if self.use_cudagraph:
        model_output = model_output[: self.real_token_num]
    return model_output
self.output_token_num_event.synchronize() real_output_token_num = int(self._real_output_token_num_host) real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_output_token_num] prompt_logprobs_list = self._get_prompt_logprobs_list(model_output) if self.is_pooling_model: pooler_output = self._pool(model_output, num_running_requests) model_output_data = ModelOutputData( next_tokens=self.share_inputs["next_tokens"], stop_flags=self.share_inputs["stop_flags"], step_idx=self.share_inputs["step_idx"], max_dec_len=self.share_inputs["max_dec_len"], seq_lens_this_time=self.share_inputs["seq_lens_this_time"], eos_token_id=self.share_inputs["eos_token_id"], not_need_stop=self.share_inputs["not_need_stop"], not_need_stop_device=self.share_inputs["not_need_stop_device"], input_ids=self.share_inputs["input_ids"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"], is_block_step=self.share_inputs["is_block_step"], full_hidden_states=model_output, msg_queue_id=self.parallel_config.msg_queue_id, mp_rank=self.parallel_config.tensor_parallel_rank, use_ep=self.parallel_config.use_ep, draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), actual_draft_token_num=( self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None ), accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), token_ids_all=self.share_inputs["token_ids_all"], stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], min_tokens=self.share_inputs["min_dec_len"], prompt_lens=self.share_inputs["prompt_lens"], index_to_batch_id=self.share_inputs["index_to_batch_id"], enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False), ) post_process( sampler_or_pooler_output=pooler_output, 
model_output=model_output_data, share_inputs=self.share_inputs, sampling_metadata=self.sampling_metadata, block_size=self.cache_config.block_size, save_each_rank=self.parallel_config.use_ep, speculative_decoding=self.speculative_decoding, skip_save_output=False, async_output_queue=self.async_output_queue, enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0, routing_replay_manager=self.routing_replay_manager, ) return None, None, None else: hidden_states = rebuild_padding( model_output, self.share_inputs["cu_seqlens_q"], self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_encoder"], (real_batch_id_per_token_output if self.speculative_decoding else None), (self.share_inputs["cu_seqlens_q_output"] if self.speculative_decoding else None), ) # 4. Compute logits, Sample if self.deterministic_logger is not None: # Log MD5 of hidden_states (model output) self.deterministic_logger.log_tensor_md5s( {"hidden_states": hidden_states}, forward_batch_reqs_list=self.forward_batch_reqs_list, stage="hidden_states", ) logits = self.model.compute_logits(hidden_states, self.forward_meta) if self.deterministic_logger is not None: # Log MD5 of logits (before sampling) self.deterministic_logger.log_tensor_md5s( {"logits": logits}, forward_batch_reqs_list=self.forward_batch_reqs_list, stage="logits" ) if not self.speculative_decoding: set_value_by_flags_and_idx( self.share_inputs["token_ids_all"], self.share_inputs["input_ids"], self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_encoder"], self.share_inputs["seq_lens_decoder"], self.share_inputs["prompt_lens"], self.share_inputs["step_idx"], self.share_inputs["stop_flags"], ) sampler_output = self.sampler( logits, self.sampling_metadata, p_done_idxs, ) if self.deterministic_logger is not None: # Log MD5 of sampling results self.deterministic_logger.log_tensor_md5s( {"sampled_token_ids": sampler_output.sampled_token_ids}, 
forward_batch_reqs_list=self.forward_batch_reqs_list, stage="sampled_tokens", ) if ( self.enable_logprob and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and sampler_output.logprobs_tensors is None ): sampler_output.logprobs_tensors = LogprobsTensors( logprob_token_ids=sampler_output.sampled_token_ids, logprobs=paddle.empty_like(sampler_output.sampled_token_ids, device="cpu", dtype="float32"), selected_token_ranks=paddle.empty( [sampler_output.sampled_token_ids.shape[0]], device="cpu", dtype="int64" ), ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( sampler_output.sampled_token_ids, self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) else: sampler_output = self.sampler( logits, self.sampling_metadata, self.model_config.max_model_len, self.share_inputs, real_output_token_num, self.increment_value, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( self.share_inputs["accept_tokens"], self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) paddle.distributed.broadcast( self.share_inputs["accept_num"], self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) paddle.distributed.broadcast( self.share_inputs["step_idx"], self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) paddle.distributed.broadcast( self.share_inputs["stop_flags"], self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) # 5. 
Post Process model_output_data = ModelOutputData( next_tokens=self.share_inputs["next_tokens"], stop_flags=self.share_inputs["stop_flags"], step_idx=self.share_inputs["step_idx"], max_dec_len=self.share_inputs["max_dec_len"], seq_lens_this_time=self.share_inputs["seq_lens_this_time"], eos_token_id=self.share_inputs["eos_token_id"], not_need_stop=self.share_inputs["not_need_stop"], not_need_stop_device=self.share_inputs["not_need_stop_device"], input_ids=self.share_inputs["input_ids"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"], is_block_step=self.share_inputs["is_block_step"], full_hidden_states=model_output, msg_queue_id=self.parallel_config.msg_queue_id, mp_rank=self.parallel_config.tensor_parallel_rank, use_ep=self.parallel_config.use_ep, draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), actual_draft_token_num=( self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None ), token_ids_all=self.share_inputs["token_ids_all"], accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], min_tokens=self.share_inputs["min_dec_len"], prompt_lens=self.share_inputs["prompt_lens"], mask_rollback=self.share_inputs["mask_rollback"], prompt_logprobs_list=prompt_logprobs_list, index_to_batch_id=self.share_inputs["index_to_batch_id"], enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False), ) if self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill": skip_save_output = True else: skip_save_output = False post_process( sampler_or_pooler_output=sampler_output, model_output=model_output_data, share_inputs=self.share_inputs, sampling_metadata=self.sampling_metadata, 
block_size=self.cache_config.block_size, save_each_rank=self.parallel_config.use_ep, speculative_decoding=self.speculative_decoding, skip_save_output=skip_save_output, async_output_queue=self.async_output_queue, think_end_id=self.model_config.think_end_id, splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode", enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0, routing_replay_manager=self.routing_replay_manager, ) if self.guided_backend is not None and sampler_output is not None: self.sampler.post_process(sampler_output.sampled_token_ids) # 5.1. Async cpy post_process_event = paddle.device.cuda.create_event() if envs.FD_USE_GET_SAVE_OUTPUT_V1: # If one query is preempted, there is no sampled token for it, we use token_id PREEMPTED_TOKEN_ID to signal server, abort is finished. paddle.assign( paddle.where( self.share_inputs["last_preempted_idx"][: sampler_output.sampled_token_ids.shape[0]] == 1, PREEMPTED_TOKEN_ID, sampler_output.sampled_token_ids, ), sampler_output.sampled_token_ids, ) # if not self.speculative_decoding: self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False) if self.speculative_decoding: self.share_inputs["accept_tokens_cpu"].copy_(self.share_inputs["accept_tokens"], False) self.share_inputs["accept_num_cpu"].copy_(self.share_inputs["accept_num"], False) self.share_inputs["seq_lens_decoder_cpu"].copy_(self.share_inputs["seq_lens_decoder"], False) self.share_inputs["prompt_lens_cpu"].copy_(self.share_inputs["prompt_lens"], False) post_process_event.record() # 6. Speculative decode -- proposer run (method="naive" has proposer=None, skip) # For naive mode: seq_lens_this_time is already reset to 1 inside # unified_update_model_status kernel. For MTP/Ngram, the proposer # will overwrite it with (draft_count + 1) below. 
def _save_model_output(
    self,
    model_output_data,
    sampler_output,
):
    """
    Hand the step results to the save routine matching the decoding mode.

    With speculative decoding, `save_output_specualate` is used and saving is
    skipped for an MTP runner whose splitwise role is "prefill"; otherwise
    `save_output_normal` is used together with the async output queue.

    Args:
        model_output_data: Output data produced by `_postprocess`.
        sampler_output: Sampler output produced by `_postprocess`.
    """
    if not self.speculative_decoding:
        save_output_normal(
            model_output=model_output_data,
            sampler_output=sampler_output,
            share_inputs=self.share_inputs,
            async_output_queue=self.async_output_queue,
            save_each_rank=self.parallel_config.use_ep,
        )
        return

    is_mtp = self.spec_method == SpecMethod.MTP
    skip_save_output = is_mtp and self.scheduler_config.splitwise_role == "prefill"
    save_output_specualate(
        sampler_output=sampler_output,
        model_output=model_output_data,
        share_inputs=self.share_inputs,
        save_each_rank=self.parallel_config.use_ep,
        skip_save_output=skip_save_output,
    )
def _execute_empty_input(self, forward_meta) -> None:
    """
    In certain scenarios, such as during EP, the runner needs to execute partial modules of the model without input data.
    This requires the model to implement the `empty_input_forward` method.

    Args:
        forward_meta: Forward metadata passed through to `empty_input_forward`.

    Raises:
        ValueError: If the model does not provide `empty_input_forward`.
    """
    if not hasattr(self.model, "empty_input_forward"):
        # The attribute name is quoted on both sides so the message reads correctly.
        raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward'")
    self.model.empty_input_forward(forward_meta)
Args: num_gpu_blocks: """ self.num_gpu_blocks = num_gpu_blocks # Reset block table and kv cache with global block num self.initialize_kv_cache() # Reset free list free_list = list( range( self.num_gpu_blocks - 1, int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1, -1, ) ) self.free_list_len = len(free_list) self.share_inputs.update( { "free_list": paddle.to_tensor(free_list, dtype="int32"), "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"), } ) if self.spec_method == SpecMethod.MTP: self.proposer.update_mtp_block_num(num_gpu_blocks) def cal_theortical_kvcache(self): """ Calculate the total block memory required at the model level TODO(gongshaotian): Move to Attention Backend """ """ Byte of dtype: - default(bf16): 2 - cache_int8: 1 - cache_int4: """ cache_quant_dtype = None if ( self.quant_config and hasattr(self.quant_config, "kv_cache_quant_type") and self.quant_config.kv_cache_quant_type is not None ): cache_quant_dtype = self.quant_config.kv_cache_quant_type if cache_quant_dtype is not None: # int8, int8_zp, fp8, fp8_zp byte_of_dtype = 1 else: # default byte_of_dtype = 2 hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads # NOTE(liuzichang): Implement multi-layer MTP architecture in the future num_layers = ( self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio if self.spec_method == SpecMethod.MTP else self.model_config.num_hidden_layers ) # NOTE:(changwenbin) Determie whether it is Multi-Head Latent Attention, # To rationalize the allocation of kvcache. 
if self.fd_config.cache_config.use_mla_cache: required_memory = ( byte_of_dtype * (self.fd_config.model_config.kv_lora_rank + self.fd_config.model_config.qk_rope_head_dim) * (self.cache_config.block_size) * num_layers ) # compress_kv + k_pe elif self.dsa_cache: required_memory = ( 1 * ( self.fd_config.model_config.kv_lora_rank + self.fd_config.model_config.kv_lora_rank // 128 * 4 + 2 * self.fd_config.model_config.qk_rope_head_dim # indexer + self.fd_config.model_config.index_head_dim + self.fd_config.model_config.index_head_dim // 128 * 4 ) * (self.cache_config.block_size) * num_layers ) else: required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v return required_memory def clear_cache(self, profile=False): """Clear cached data from shared inputs and forward metadata""" create_cache_tensor = profile or not ( self.fd_config.cache_config.num_cpu_blocks > 0 or self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size if not create_cache_tensor: for name, tensor in self.cache_kvs_map.items(): unset_data_ipc(tensor, name, True, False) self.cache_ready_signal.value[local_rank] = 0 self.cache_kvs_map.clear() self.share_inputs.pop("caches", None) if self.forward_meta is not None: self.forward_meta.clear_caches() paddle.device.cuda.empty_cache() def clear_parameters(self, pid): """Dynamic model loader use to clear parameters use for RL""" # Clear CUDAGraph if self.use_cudagraph: self.model.clear_graph_opt_backend() # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) if self.spec_method == SpecMethod.MTP: self.proposer.model.clear_graph_opt_backend() self.proposer.clear_mtp_cache() self.clear_cache() paddle.device.cuda.empty_cache() self.dynamic_weight_manager._log_memory("dynamic weight manager 
clear all memory") def clear_requests(self): """Dynamic model loader use to clear requests use for RL""" self.share_inputs["stop_flags"][:] = True # prompt_logprobs self.prompt_logprobs_reqs.clear() self.in_progress_prompt_logprobs.clear() self.forward_batch_reqs_list = [None for _ in range(self.scheduler_config.max_num_seqs)] # Routing Replay if self.routing_replay_manager: self.routing_replay_manager.clear_all_request() def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" # Update parameters self.dynamic_weight_manager.update_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) # Reset share_inputs self.share_inputs.reset_share_inputs() if self.spec_method == SpecMethod.MTP: self.proposer.model_inputs.reset_model_inputs() self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) self.initialize_kv_cache() # Recapture CUDAGraph if self.use_cudagraph: self.capture_model() # Send single self.dynamic_weight_manager.finalize_update(pid) self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") def update_weights(self, version: str = None, verify_checksum: bool = False): return self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum) def sleep(self, tags): logger.info(f">>> start offloading memory, tags: {tags}") start_time = time.perf_counter() # Clear weights, deepep_buffer, cudagraph, etc. 
if "weight" in tags.split(","): if self.is_weight_sleeping: logger.info("GPU model runner's weight is already sleeping, no need to sleep again!") return if self.use_cudagraph: self.model.clear_graph_opt_backend() if self.fd_config.parallel_config.enable_expert_parallel: self.dynamic_weight_manager.clear_deepep_buffer() self.dynamic_weight_manager.clear_model_weight() if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle: self.dynamic_weight_manager.clear_communication_group() self.is_weight_sleeping = True # Clear KV cache if "kv_cache" in tags.split(","): if self.is_kvcache_sleeping: logger.info("GPU model runner's kv cache is already sleeping, no need to sleep again!") return if self.spec_method == SpecMethod.MTP: self.proposer.clear_mtp_cache() self.clear_cache() self.is_kvcache_sleeping = True paddle.device.cuda.empty_cache() logger.info(f"<<< finish offloading memory! time cost: {time.perf_counter()-start_time:.3f}s") print_gpu_memory_use(f"After offloading memory [{tags}]", self.local_rank, self.device_id) def wakeup(self, tags): if tags == "weight" and self.use_cudagraph and self.is_kvcache_sleeping: raise RuntimeError( "Waking up [weight] alone is not supported when CUDA Graph is enabled, " "as recapturing the graph requires the KV cache to be rebuilt first. " "Please wake up [kv_cache] first." 
) logger.info(f">>> start reloading memory, tags: {tags}") start_time = time.perf_counter() # Reset share_inputs to restore tensor shapes and values if self.spec_method == SpecMethod.MTP: self.proposer.model_inputs.reset_model_inputs() self.share_inputs.reset_share_inputs() # Reinitialize KV cache if "kv_cache" in tags.split(","): if not self.is_kvcache_sleeping: logger.info("GPU model runner's kv cache is not sleeping, no need to wakeup!") return if self.spec_method == SpecMethod.MTP: self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) self.initialize_kv_cache() self.is_kvcache_sleeping = False # Reload weights, deepep_buffer, cudagraph, etc. if "weight" in tags.split(","): if not self.is_weight_sleeping: logger.info("GPU model runner's weight is not sleeping, no need to wakeup!") return if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle: self.dynamic_weight_manager.restart_communication_group() if self.fd_config.parallel_config.enable_expert_parallel: self.dynamic_weight_manager.recreate_deepep_buffer() self.dynamic_weight_manager.reload_model_weights() if self.use_cudagraph: self.capture_model() self.is_weight_sleeping = False logger.info(f"<<< finish reloading memory! time cost: {time.perf_counter()-start_time:.3f}s") print_gpu_memory_use(f"After reloading memory [{tags}]", self.local_rank, self.device_id) def padding_cudagraph_inputs(self) -> None: """ Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. """ # In init_attention_metadata, the decode buffer has already been cleared # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. 
if self.use_cudagraph: self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] return def _init_image_preprocess(self) -> None: image_preprocess = AdaptiveImageProcessor.from_pretrained(str(self.model_config.model)) image_preprocess.image_mean_tensor = paddle.to_tensor(image_preprocess.image_mean, dtype="float32").reshape( [1, 3, 1, 1] ) image_preprocess.image_std_tensor = paddle.to_tensor(image_preprocess.image_std, dtype="float32").reshape( [1, 3, 1, 1] ) image_preprocess.rescale_factor = paddle.to_tensor(image_preprocess.rescale_factor, dtype="float32") image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze([-2, -1]).repeat_interleave( self.model_config.vision_config.patch_size**2 * 1, -1 ) image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze([-2, -1]).repeat_interleave( self.model_config.vision_config.patch_size**2 * 1, -1 ) self.image_preprocess = image_preprocess def _preprocess_mm_task(self, one: dict) -> None: """process batch""" input_ids = one["input_ids"][np.newaxis, :] input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64) token_type_ids = one["token_type_ids"][np.newaxis, :] token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64) if "images" in one and one["images"] is not None: image_type_ids = one["image_type_ids"][np.newaxis, :] images = one["images"] image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64) images = paddle.to_tensor(images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16") grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64") else: image_type_ids = None images = None grid_thw = None if one["position_ids"] is not None: position_ids = paddle.to_tensor(one["position_ids"], dtype="int64") else: position_ids = None result = dict( input_ids=input_ids, image_type_ids=image_type_ids, token_type_ids=token_type_ids, position_ids=position_ids, grid_thw=grid_thw, images=images, ) return result def 
extract_vision_features_ernie(self, vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor: """ vision feature extractor for ernie-vl """ assert len(vision_inputs["images_lst"]) > 0, "at least one image needed" grid_thw = paddle.to_tensor(vision_inputs["grid_thw_lst"], dtype=paddle.int64) # ernie-vl has images norm images = paddle.concat(vision_inputs["images_lst"]).cast("float32") images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor images = images / self.image_preprocess.image_std_tensor images = images.cast("bfloat16") with paddle.amp.auto_cast( True, custom_black_list=self.amp_black, custom_white_list=self.amp_white, level="O2", dtype=self.model_config.dtype, ): image_features = self.model.vision_model.extract_feature(images, grid_thw) if self.parallel_config.tensor_parallel_size > 1: S, C = image_features.shape image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2]) image_features = ScatterOp.apply(image_features, axis=-1) # mp 切 Fea image_features = image_features.reshape([S, -1]) # ernie-vl has resampler_model image_features = self.model.resampler_model( image_features, grid_thw, ) return image_features def extract_vision_features_qwen(self, vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor: assert len(vision_inputs["images_lst"]) > 0, "at least one image needed" grid_thw = paddle.to_tensor(vision_inputs["grid_thw_lst"], dtype=paddle.int64) images = paddle.concat(vision_inputs["images_lst"]).cast("bfloat16") with paddle.amp.auto_cast( True, custom_black_list=self.amp_black, custom_white_list=self.amp_white, level="O2", dtype=self.model_config.dtype, ): image_features = self.model.visual.extract_feature(images, grid_thw) return image_features def extract_vision_features_paddleocr(self, inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor: if envs.FD_ENABLE_MAX_PREFILL: inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"]) images = 
paddle.concat(inputs["images_lst"]).cast("bfloat16") grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64") position_ids = paddle.to_tensor(inputs["vit_position_ids_lst"], dtype="int64") cu_seqlens = paddle.cumsum(paddle.to_tensor(inputs["cu_seqlens"])).cast("int32") else: assert inputs["images"] is not None grid_thw = inputs["grid_thw"] images = inputs["images"] position_ids = [] cu_seqlens = [0] for idx, thw in enumerate(grid_thw): numel = np.prod(np.array(thw)) position_ids.append(paddle.arange(numel) % np.prod(thw[1:])) cu_seqlens.append(cu_seqlens[-1] + numel) position_ids = paddle.concat(position_ids, axis=0).to(images.place) cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place) with paddle.amp.auto_cast( True, custom_black_list=self.amp_black, custom_white_list=self.amp_white, level="O2", dtype=self.model_config.dtype, ): image_features = self.model.visual( pixel_values=images, image_grid_thw=grid_thw, position_ids=position_ids, interpolate_pos_encoding=True, cu_seqlens=cu_seqlens, use_rope=True, window_size=-1, ) image_features = self.model.projector(image_features, grid_thw) image_features = paddle.concat(image_features, axis=0) return image_features @paddle.no_grad() def extract_vision_features(self, multi_vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor: """extract_vision_features""" if "ernie" in self.model_config.model_type: return self.extract_vision_features_ernie(multi_vision_inputs) elif "qwen" in self.model_config.model_type: return self.extract_vision_features_qwen(multi_vision_inputs) elif "paddleocr" in self.model_config.model_type: return self.extract_vision_features_paddleocr(multi_vision_inputs) else: raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported") @paddle.no_grad() def _dummy_run_extract_vision_features(self): grid_thw_list = ([(1, 10, 88), (1, 10, 80)], [(1, 14, 62), (1, 20, 42), (1, 14, 60)]) for grid_thw in grid_thw_list: images = [] 
position_ids = [] cu_seqlens = [0] for idx, thw in enumerate(grid_thw): numel = np.prod(np.array(thw)) images.append(paddle.uniform(shape=[numel, 3, 14, 14], dtype="float32", min=0.0, max=1.0)) position_ids.append(paddle.arange(numel) % np.prod(thw[1:])) cu_seqlens.append(cu_seqlens[-1] + numel) images = paddle.concat(images, axis=0) position_ids = paddle.concat(position_ids, axis=0).to(images.place) cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place) with paddle.amp.auto_cast( True, custom_black_list=self.amp_black, custom_white_list=self.amp_white, level="O2", dtype=self.model_config.dtype, ): self.model.visual( pixel_values=images, image_grid_thw=grid_thw, position_ids=position_ids, interpolate_pos_encoding=True, cu_seqlens=cu_seqlens, use_rope=True, window_size=-1, ) @paddle.no_grad() def prepare_rope3d( self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int] ) -> list[paddle.Tensor]: """prepare_rope3d""" rope_emb_lst = get_rope_3d( position_ids=position_ids, rotary_dim=self.model_config.head_dim, partial_rotary_factor=1.0, base=self.model_config.rope_theta, max_position=self.model_config.max_model_len, freq_allocation=getattr(self.model_config, "freq_allocation", 20), rope_scaling=getattr(self.model_config, "rope_scaling", {}), model_type=self.model_config.model_type, max_len_lst=max_len_lst, cumsum_seqlens=cumsum_seqlens, ) return rope_emb_lst def _get_prompt_logprobs_list( self, hidden_states: paddle.Tensor, ) -> list[Optional[LogprobsTensors]]: if len(self.prompt_logprobs_reqs) > 0: assert ( not self.fd_config.cache_config.enable_prefix_caching ), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching." 
logprobs_mode = self.fd_config.model_config.logprobs_mode prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None] completed_prefill_reqs: list[Request] = [] for req_id, request in self.prompt_logprobs_reqs.items(): num_prompt_logprobs = request.sampling_params.prompt_logprobs if request.prompt_token_ids is None or num_prompt_logprobs is None: continue if num_prompt_logprobs == -1: num_prompt_logprobs = self.ori_vocab_size num_tokens = request.prefill_end_index - request.prefill_start_index num_prompt_tokens = len(request.prompt_token_ids) logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id) if not logprobs_tensors: logprobs_tensors = LogprobsTensors.empty_cpu(num_prompt_tokens - 1, num_prompt_logprobs + 1) self.in_progress_prompt_logprobs[req_id] = logprobs_tensors start_idx = request.prefill_start_index start_tok = start_idx + 1 num_remaining_tokens = num_prompt_tokens - start_tok batch_id = self.share_inputs.get_index_by_batch_id(request.idx) if num_tokens <= num_remaining_tokens: # This is a chunk, more tokens remain. # In the == case, there are no more prompt logprobs to produce # but we want to defer returning them to the next step where we # have new generated tokens to return. num_logits = num_tokens else: # This is the last chunk of prompt tokens to return. num_logits = num_remaining_tokens completed_prefill_reqs.append(request) prompt_logprobs_list[batch_id] = logprobs_tensors if num_logits <= 0: # This can happen for the final chunk if we prefilled exactly # (num_prompt_tokens - 1) tokens for this request in the prior # step. There are no more prompt logprobs to produce. 
continue offset = self.share_inputs["cu_seqlens_q"][batch_id] prompt_hidden_states = hidden_states[offset : offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states) prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits] prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64") if logprobs_mode == "raw_logprobs": raw_logprobs = self.sampler.compute_logprobs(logits) elif logprobs_mode == "raw_logits": raw_logprobs = logits token_ids, logprobs, ranks = self.sampler.gather_logprobs( raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor ) # Synchronize before using token_ids, logprobs and ranks to ensure async copy are completed. paddle.device.synchronize() chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False) logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False) logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False) for req in completed_prefill_reqs: del self.prompt_logprobs_reqs[req.request_id] del self.in_progress_prompt_logprobs[req.request_id] return prompt_logprobs_list def initialize_routing_replay_manager(self): """Initialize the routing replay manager after initialize the KVCache""" # Use updated block number self.routing_replay_manager = RoutingReplayManager( fd_config=self.fd_config, block_table=self.share_inputs["block_tables"], total_block_num=self.num_gpu_blocks, )