""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import copy import threading import time import traceback from collections import deque from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import Union import numpy as np import paddle from fastdeploy import envs from fastdeploy.cache_manager.multimodal_cache_manager import ( EncoderCacheManager, ProcessorCacheManager, ) from fastdeploy.config import ErnieArchitectures from fastdeploy.engine.request import ( ImagePosition, Request, RequestOutput, RequestStatus, RequestType, ) from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.input.utils import IDS_TYPE_FLAG from fastdeploy.inter_communicator import IPCSignal from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.multimodal.hasher import MultimodalHasher from fastdeploy.platforms import current_platform from fastdeploy.utils import download_from_bos, init_bos_client, llm_logger @dataclass class ScheduledDecodeTask: """ Task for allocating new blocks to decode. """ idx: int request_id: str block_tables: list[int] task_type: RequestType = RequestType.DECODE @dataclass class ScheduledPreemptTask: """ Task for terminating inference to recycle resource. """ idx: int request_id: str task_type: RequestType = RequestType.PREEMPTED @dataclass class ScheduledExtendBlocksTask: """ Task for allocating new blocks to extend. """ idx: int request_id: str extend_block_tables: list[int] task_type: RequestType = RequestType.EXTEND class SignalConsumer: """ A class that consumes a signal value up to a specified limit. This class maintains an internal signal value and allows controlled consumption of that signal. The signal can be watched at any time, but can only be consumed a limited number of times before being reset to zero. """ def __init__(self, signal, consume_limit): """ Initialize the SignalConsumer with a signal value and consumption limit. Args: signal: The initial signal value to be consumed. consume_limit (int): The maximum number of times the signal can be consumed before being reset to 0. Must be a positive integer. Raises: AssertionError: If consume_limit is not greater than 0. """ assert consume_limit > 0 self._signal = signal self._consume_limit = consume_limit def watch(self): """ Get the current signal value without consuming it. This method allows reading the signal value any number of times without affecting the consumption limit or the signal value itself. Returns: The current signal value. """ return self._signal def consume(self): """ Consume the signal value, decrementing the consumption limit. This method returns the current signal value and decrements the consumption counter. When the consumption limit reaches zero, the signal is automatically reset to 0. The consumption happens in a finally block to ensure the limit is decremented even if an exception occurs while processing the signal. Returns: The current signal value before consumption. Note: After the consumption limit is reached, this method will continue to return 0 on subsequent calls. """ try: return self._signal finally: if self._consume_limit > 0: self._consume_limit -= 1 if self._consume_limit == 0: self._signal = 0 class ResourceManagerV1(ResourceManager): """ Resource manager for scheduler v1. In scheduler v1, all gpu blocks are managed by PrefixCacheManager. Tasks sent to worker are divided into 3 types, PREFILL、DECODE and PREEMPTED. For prefill task, the worker infer with one step and then stopped for this query if not all prompt tokens are computed. For decode task, the work continues to decode until allocated blocks are exhausted. For preempted task, the work reset all inputs to terminate the inference. """ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id=0): super(ResourceManagerV1, self).__init__( max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id ) # req_id -> Request self.config = config self.requests: dict[str, Request] = {} # Priority queues for requests. self.waiting: deque[Request] = deque() self.running: list[Request] = [] self.preallocated_reqs: dict[str, Request] = {} self.enable_max_prefill = envs.FD_ENABLE_MAX_PREFILL self.finish_execution_pool = ThreadPoolExecutor(max_workers=1) self.lock = threading.Lock() self.to_be_rescheduled_request_id_set = set() main_process_metrics.max_batch_size.set(max_num_seqs) self.using_extend_tables_req_id = set() self.reuse_block_num_map = dict() # need block nums need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32) self.need_block_num_signal = IPCSignal( name="need_block_num_signal", array=need_block_num_data, dtype=np.int32, suffix=self.config.parallel_config.local_engine_worker_queue_port, create=True, ) self.need_block_num_map = dict() self.encoder_cache = None if config.model_config.enable_mm and config.cache_config.max_encoder_cache > 0: self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache) self.processor_cache = None if config.model_config.enable_mm and config.cache_config.max_processor_cache > 0: max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024) self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes) self.bos_client = None self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4) def allocated_slots(self, request: Request): return len(request.block_tables) * self.config.cache_config.block_size def get_new_block_nums(self, request: Request, num_new_tokens: int): block_num = ( request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1 ) // self.config.cache_config.block_size - len(request.block_tables) if self.config.speculative_config.method is not None: block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq) return block_num def _prepare_prefill_task(self, request, new_token_num): request.prefill_start_index = request.num_computed_tokens request.prefill_end_index = request.num_computed_tokens + new_token_num request.task_type = RequestType.PREFILL return request def _prepare_decode_task(self, request): return ScheduledDecodeTask(idx=request.idx, request_id=request.request_id, block_tables=request.block_tables) def _prepare_preempt_task(self, request): return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id) def reschedule_preempt_task(self, request_id, process_func=None): with self.lock: if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests: request = self.requests[request_id] if process_func is not None: process_func(request) self.waiting.appendleft(request) self.to_be_rescheduled_request_id_set.remove(request_id) def _info_each_block(self): """ print each req block """ for req in self.running: llm_logger.debug( f"req idx {req.idx} occupy {len(req.block_tables)} block_tables and {len(req.extend_block_tables)} extend_block_tables" ) def _can_preempt(self): """ cannot preempt request which use extend block """ for req in self.running: if not req.use_extend_tables: return True return False def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs): """ If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out. """ can_schedule = False while self._can_preempt(): if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks): preempted_req = self.running.pop() if preempted_req.use_extend_tables: self.running.insert(0, preempted_req) continue preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 if self.config.scheduler_config.splitwise_role == "decode": self.tasks_list[preempted_req.idx] = None self.stop_flags[preempted_req.idx] = True if preempted_req.request_id in self.requests: del self.requests[preempted_req.request_id] if preempted_req.request_id in self.req_dict: del self.req_dict[preempted_req.request_id] self._free_blocks(preempted_req) llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}") else: self._free_blocks(preempted_req) preempted_req.cached_block_num = 0 self.to_be_rescheduled_request_id_set.add(preempted_req.request_id) llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}") preempted_reqs.append(preempted_req) scheduled_reqs.append(self._prepare_preempt_task(preempted_req)) llm_logger.debug( f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}" ) llm_logger.debug(self.info()) self._info_each_block() if preempted_req == request: # No more request to preempt. can_schedule = False break else: # The request can be scheduled. can_schedule = True break return can_schedule def _update_mm_hashes(self, request): if request.multimodal_inputs is None: return inputs = request.multimodal_inputs if ( inputs.get("images", None) is not None and inputs.get("image_patch_id", None) is not None and inputs.get("grid_thw", None) is not None and len(inputs["grid_thw"]) != 0 ): grid_thw = [] new_mm_positions, new_mm_hashes = [], [] image_st = 0 for idx, one in enumerate(inputs["grid_thw"]): t, h, w = one[0], one[1], one[2] if t == 1: grid_thw.append(one) new_mm_positions.append(inputs["mm_positions"][idx]) new_mm_hashes.append(inputs["mm_hashes"][idx]) image_st += h * w else: grid_thw.extend([[2, h, w]] * (t // 2)) token_st = inputs["mm_positions"][idx].offset for _ in range(t // 2): new_mm_positions.append(ImagePosition(token_st, h * w // 4)) # videos are split into patches every 2 frames, need to rehash new_mm_hashes.append( MultimodalHasher.hash_features(inputs["images"][image_st : image_st + 2 * h * w]) ) image_st += 2 * h * w token_st += h * w // 4 inputs["mm_positions"] = new_mm_positions inputs["mm_hashes"] = new_mm_hashes elif inputs.get("mm_positions", None) is None or inputs.get("mm_hashes", None) is None: inputs["mm_positions"] = [] inputs["mm_hashes"] = [] def _is_mm_request(self, request): inputs = request.multimodal_inputs if inputs is None or len(inputs) == 0: return False if ( (inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0) or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0) or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0) ): return True elif ( inputs.get("images", None) is not None and inputs.get("image_patch_id", None) is not None and inputs.get("grid_thw", None) is not None ): return True return False def _get_num_new_tokens(self, request, token_budget): # TODO: set condition to new _get_num_new_tokens num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) request.with_image = False if not self.config.model_config.enable_mm: return num_new_tokens inputs = request.multimodal_inputs if inputs.get("patch_idx", None) is not None and inputs.get("patch_map", None) is not None: pre_end_idx = request.num_computed_tokens new_end_idx = pre_end_idx + num_new_tokens prompt_token_ids_len = len(request.prompt_token_ids) if not inputs.get("tts", False): assert prompt_token_ids_len == len(inputs["patch_idx"]), ( prompt_token_ids_len, len(inputs["patch_idx"]), ) def _compute_audio_prefix_count(end_idx, end_patch_idx): audio_prefix_count = 0 pre_patch_end_idx = 0 for patch_idx in range(end_patch_idx + 1): patch_map = inputs["patch_map"][patch_idx] modal_id = patch_map["modal_id"] if modal_id == IDS_TYPE_FLAG["audio"]: if patch_idx != end_patch_idx: audio_prefix_count += patch_map["end_idx"] - pre_patch_end_idx else: audio_prefix_count += end_idx - pre_patch_end_idx pre_patch_end_idx = patch_map["end_idx"] return audio_prefix_count # start if pre_end_idx >= prompt_token_ids_len: start_patch_idx = inputs["patch_idx"][-1] else: start_patch_idx = inputs["patch_idx"][pre_end_idx] start_patch_map = inputs["patch_map"][start_patch_idx] request.image_start = start_patch_map["image_num"] request.video_start = start_patch_map["video_num"] request.audio_start = _compute_audio_prefix_count(pre_end_idx, start_patch_idx) # end if new_end_idx >= prompt_token_ids_len: end_patch_idx = inputs["patch_idx"][-1] else: end_patch_idx = inputs["patch_idx"][new_end_idx] if request.prompt_token_ids[new_end_idx] in [ inputs["image_end_id"], inputs["video_end_id"], inputs["audio_end_id"], ]: end_patch_idx -= 1 end_patch_map = inputs["patch_map"][end_patch_idx] end_modal_id = end_patch_map["modal_id"] if end_modal_id == IDS_TYPE_FLAG["image"]: new_end_idx = end_patch_map["end_idx"] # 当前模态结束位置 if end_modal_id == IDS_TYPE_FLAG["video"] and "can_split_idx_list" in inputs: can_split_idx_list = inputs["can_split_idx_list"] for i in range(len(can_split_idx_list)): if can_split_idx_list[i] >= new_end_idx: new_end_idx = can_split_idx_list[i] break num_new_tokens = new_end_idx - pre_end_idx request.image_end = end_patch_map["image_num"] request.video_end = end_patch_map["video_num"] request.audio_end = _compute_audio_prefix_count(new_end_idx, end_patch_idx) elif ( inputs.get("images", None) is not None and inputs.get("image_patch_id", None) is not None and inputs.get("grid_thw", None) is not None ): input_ids_lst = request.prompt_token_ids + request.output_token_ids input_ids = paddle.to_tensor(input_ids_lst, dtype="int64") image_patch_id = inputs["image_patch_id"] if request.multimodal_img_boundaries is None: grid_thw = [] for idx, one in enumerate(inputs["grid_thw"]): t, h, w = one[0], one[1], one[2] if t == 1: grid_thw.append(one) else: grid_thw.extend([[2, h, w]] * (t // 2)) grid_thw = paddle.to_tensor(grid_thw, dtype="int64") if current_platform.is_xpu(): from fastdeploy.model_executor.ops.xpu import get_img_boundaries elif current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import ( get_img_boundaries, ) else: from fastdeploy.model_executor.ops.gpu import get_img_boundaries request.multimodal_img_boundaries = get_img_boundaries( task_input_ids=input_ids, grid_thw=grid_thw, image_patch_id=image_patch_id ).numpy() grid_thw = grid_thw.numpy().reshape([-1, 3]) inputs["grid_thw"] = grid_thw grid_thw = inputs["grid_thw"] img_boundaries_idx = request.multimodal_img_boundaries[0] img_num_per_boundary = request.multimodal_img_boundaries[1] ori_prompt_len = img_boundaries_idx[-1].item() pre_end_idx = request.num_computed_tokens new_end_idx = pre_end_idx + num_new_tokens if new_end_idx < ori_prompt_len and input_ids[new_end_idx - 1] == image_patch_id: boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item() if boundary_idx == len(img_boundaries_idx): new_end_idx = ori_prompt_len else: new_end_idx = img_boundaries_idx[boundary_idx].item() elif new_end_idx >= ori_prompt_len and paddle.sum(input_ids[pre_end_idx:new_end_idx] == image_patch_id): new_end_idx = ori_prompt_len num_new_tokens = new_end_idx - pre_end_idx image_mask = input_ids[pre_end_idx:new_end_idx] == image_patch_id request.with_image = image_mask.any() if request.with_image: pre_boundary_idx = np.searchsorted(img_boundaries_idx, pre_end_idx, side="left").item() if pre_boundary_idx == len(img_boundaries_idx): request.num_image_start = img_num_per_boundary[-1] else: pre_boundary_idx = ( pre_boundary_idx if pre_end_idx == img_boundaries_idx[pre_boundary_idx] else pre_boundary_idx - 1 ) request.num_image_start = img_num_per_boundary[pre_boundary_idx] new_boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item() if new_boundary_idx == len(img_boundaries_idx): request.num_image_end = img_num_per_boundary[-1] else: new_boundary_idx = ( new_boundary_idx if new_end_idx == img_boundaries_idx[new_boundary_idx] else new_boundary_idx - 1 ) request.num_image_end = img_num_per_boundary[new_boundary_idx] request.image_type_ids_start = np.sum(grid_thw[: request.num_image_start, 0]) request.image_type_ids_end = np.sum(grid_thw[: request.num_image_end, 0]) request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1)) request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1)) if self.encoder_cache: cur_mm_hashes = inputs["mm_hashes"][request.num_image_start : request.num_image_end] cur_mm_positions = inputs["mm_positions"][request.num_image_start : request.num_image_end] request.evict_mm_hashes = self.encoder_cache.apply_cache(cur_mm_hashes, cur_mm_positions) # Compatible with scenarios without images and videos. return num_new_tokens def exist_mm_prefill(self, scheduled_reqs): for request in scheduled_reqs: if request.task_type == RequestType.PREFILL and self._is_mm_request(request): return True return False def exist_prefill(self, scheduled_reqs): for request in scheduled_reqs: if request.task_type == RequestType.PREFILL: return True return False def cache_output_tokens(self, request): if self.config.cache_config.enable_prefix_caching and self.config.cache_config.enable_output_caching: with self.lock: self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_total_tokens - 1 ) def schedule(self): """ Try to pull a batch of requests from the waiting queue and schedule them. """ with self.lock: scheduled_reqs: list[Request] = [] preempted_reqs: list[Request] = [] error_reqs: list[tuple[str, str]] = [] token_budget = self.config.scheduler_config.max_num_batched_tokens # First, schedule the RUNNING requests. req_index = 0 num_decoding_req_nums = 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] need_block_num = self.need_block_num_signal.value[request.idx] if need_block_num != 0: self.need_block_num_map[request.request_id] = SignalConsumer(need_block_num, 1) self.need_block_num_signal.value[request.idx] = 0 if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding if ( self.config.scheduler_config.splitwise_role == "prefill" ): # do not need to schedule for decoding req_index += 1 continue if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens request.num_computed_tokens = request.num_total_tokens - 1 if ( self.allocated_slots(request) - request.num_total_tokens <= self.config.cache_config.prealloc_dec_block_slot_num_threshold ): # Allocation for next decoding blocks if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num): llm_logger.debug( f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}" ) request.block_tables.extend( self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num) ) # Prepare decoding task scheduled_reqs.append(self._prepare_decode_task(request)) else: # Not enough blocks to allocate, trigger preemption can_schedule = self._trigger_preempt( request, self.config.cache_config.enc_dec_block_num, preempted_reqs, scheduled_reqs ) if not can_schedule: break # Allocation for next decoding blocks request.block_tables.extend( self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num) ) # Prepare decoding task scheduled_reqs.append(self._prepare_decode_task(request)) num_decoding_req_nums += 1 token_budget -= 1 if ( request.use_extend_tables and request.request_id not in self.using_extend_tables_req_id and self.need_block_num_map[request.request_id].watch() > 0 ): def _allocate_decode_and_extend(): allocate_block_num = self.need_block_num_map[request.request_id].consume() # Prepare decoding task request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(allocate_block_num)) scheduled_reqs.append(self._prepare_decode_task(request)) # Prepare extend task reuse_block_num = request.num_total_tokens // self.config.cache_config.block_size llm_logger.info( f"req {request.request_id} at batch id {request.idx} with reuse_block_num {reuse_block_num} is going to enable extend tables," f"need_block_num {allocate_block_num}" ) self.using_extend_tables_req_id.add(request.request_id) self.reuse_block_num_map[request.request_id] = reuse_block_num request.extend_block_tables = request.block_tables[:reuse_block_num] # copy prompt cache request.extend_block_tables.extend( self.cache_manager.allocate_gpu_blocks(allocate_block_num) ) scheduled_reqs.append( ScheduledExtendBlocksTask( idx=request.idx, request_id=request.request_id, extend_block_tables=request.extend_block_tables, ) ) llm_logger.debug(f"extend blocks is {request.extend_block_tables}") if self.cache_manager.can_allocate_gpu_blocks( 2 * self.need_block_num_map[request.request_id].watch() ): _allocate_decode_and_extend() else: llm_logger.info( f"{request.idx} using extend block need {2 * self.need_block_num_map[request.request_id].watch()} blocks but got not enough blocks, ready to preempt" ) can_schedule = self._trigger_preempt( request, 2 * self.need_block_num_map[request.request_id].watch(), preempted_reqs, scheduled_reqs, ) if can_schedule: _allocate_decode_and_extend() else: break else: # need to prefill llm_logger.debug( f"scheduler prefill task: {request} request.need_prefill_tokens {request.need_prefill_tokens} request.num_computed_tokens {request.num_computed_tokens}" ) num_new_tokens = self._get_num_new_tokens(request, token_budget) num_new_block = self.get_new_block_nums(request, num_new_tokens) # Allocate blocks to prefill if self.cache_manager.can_allocate_gpu_blocks(num_new_block): request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) # Prepare prefill task scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) else: # Not enough blocks to allocate, trigger preemption can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs) if not can_schedule: break request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) # Prepare prefill task scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if self.config.cache_config.enable_prefix_caching: self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens ) req_index += 1 # Second, schedule the WAITING requests. if not preempted_reqs: skip_requests: list[Request] = [] while self.waiting and token_budget > 0: if len(self.running) == self.max_num_seqs: break request = self.waiting[0] if ( ErnieArchitectures.is_ernie5_arch(self.config.model_config.architectures) and self._is_mm_request(request) and self.exist_mm_prefill(scheduled_reqs) ) or (paddle.is_compiled_with_xpu() and self.exist_prefill(scheduled_reqs)): break if request.status == RequestStatus.WAITING: result = self.waiting_async_process(request) if result is None: error_reqs.append((request.request_id, request.error_message)) self.waiting.popleft() continue elif result is True: # skip current request, try next request skip_requests.append(request) self.waiting.popleft() continue self._update_mm_hashes(request) # Enable prefix caching if self.config.cache_config.enable_prefix_caching: if self.cache_manager.num_cpu_blocks > 0: if not self.cache_manager.can_allocate_gpu_blocks( (request.need_prefill_tokens + self.config.cache_config.block_size - 1) // self.config.cache_config.block_size ): # to prevent block allocation for matching in hierarchical cache and cause dead lock break success = self.get_prefix_cached_blocks(request) if not success: self._free_blocks(request) break # Allocate blocks for the tokens that does not hit cache num_new_tokens = self._get_num_new_tokens(request, token_budget) num_new_block = self.get_new_block_nums(request, num_new_tokens) if self.cache_manager.can_allocate_gpu_blocks(num_new_block): if not request.get("skip_allocate", False): extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block) request.block_tables.extend(extra_gpu_block_ids) if ( self.config.cache_config.enable_prefix_caching and self.config.cache_config.kvcache_storage_backend and num_new_tokens >= self.config.cache_config.block_size ): matched_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids) num_new_tokens -= len(matched_block_ids) * self.config.cache_config.block_size self.waiting.popleft() self.running.append(request) scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if self.config.cache_config.enable_prefix_caching: self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens ) request.status = RequestStatus.RUNNING if self.config.scheduler_config.splitwise_role == "mixed": allocated_position = self.get_available_position() request.idx = allocated_position self.tasks_list[allocated_position] = request self.stop_flags[allocated_position] = False self.req_dict[request.request_id] = allocated_position else: if self.config.cache_config.enable_prefix_caching: self._free_blocks(request) break elif request.status == RequestStatus.PREEMPTED: request.need_prefill_tokens = ( request.num_total_tokens ) # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct if self.config.cache_config.enable_prefix_caching: if self.cache_manager.num_cpu_blocks > 0: if not self.cache_manager.can_allocate_gpu_blocks( (request.need_prefill_tokens + self.config.cache_config.block_size - 1) // self.config.cache_config.block_size ): # to prevent block allocation for matching in hierarchical cache and cause dead lock break success = self.get_prefix_cached_blocks(request) if not success: self._free_blocks(request) break # Allocate blocks for the tokens that does not hit cache num_new_tokens = self._get_num_new_tokens(request, token_budget) num_new_block = self.get_new_block_nums(request, num_new_tokens) if self.cache_manager.can_allocate_gpu_blocks(num_new_block): if not request.get("skip_allocate", False): extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block) request.block_tables.extend(extra_gpu_block_ids) if ( self.config.cache_config.enable_prefix_caching and self.config.cache_config.kvcache_storage_backend and num_new_tokens >= self.config.cache_config.block_size ): matched_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids) num_new_tokens -= len(matched_block_ids) * self.config.cache_config.block_size self.waiting.popleft() self.running.append(request) scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if self.config.cache_config.enable_prefix_caching: self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens ) request.status = RequestStatus.RUNNING else: if self.config.cache_config.enable_prefix_caching: self._free_blocks(request) break else: llm_logger.error("Unknown request status type") for req in skip_requests: # move waiting request to end of the deque self.waiting.append(req) if scheduled_reqs: llm_logger.debug(f"schedued_reqs: {scheduled_reqs}") self.update_metrics() return scheduled_reqs, error_reqs def waiting_async_process(self, request: Request) -> None: """ Check if async preprocessing is complete for a request. Args: request: The request to check Returns: None: If an error occurred during preprocessing True: If preprocessing is still in progress (request should be skipped) False: If preprocessing is complete (request can be scheduled) """ for future in request.async_process_futures: if future.done(): if request.get("error_message") is not None: return None else: return True request.async_process_futures = [] return False def apply_async_preprocess(self, request: Request) -> None: request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request)) def _has_features_info(self, task): inputs = task.multimodal_inputs if inputs is None or len(inputs) == 0: return False if ( (inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0) or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0) or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0) ): return True return False def _download_features(self, request: Request) -> None: """ download multimodal features from bos Note: 1. this function will be add features for request.multimodal_inputs 2. this function maybe update request.error_message and request.error_code Args: request (Request): request object """ def download_bos_features(bos_client, features_urls): result_list = [] for status, feature in download_from_bos(self.bos_client, features_urls, retry=1): if status: llm_logger.info(f"request {request.request_id} async download feature: {len(feature)}") result_list.append(feature) else: error_msg = f"request {request.request_id} download features error: {feature}" llm_logger.error(error_msg) return error_msg return result_list if not self._has_features_info(request): return None if self.bos_client is None: try: self.bos_client = init_bos_client() except Exception as e: error_msg = f"request {request.request_id} init bos client error: {str(e)}" llm_logger.error(error_msg) request.error_message = error_msg request.error_code = 540 return None inputs = request.multimodal_inputs if inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0: result = download_bos_features(self.bos_client, inputs["video_feature_urls"]) if isinstance(result, str): # download error request.error_message = result request.error_code = 530 return None inputs["video_features"] = result if inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0: result = download_bos_features(self.bos_client, inputs["image_feature_urls"]) if isinstance(result, str): # download error request.error_message = result request.error_code = 530 return None inputs["image_features"] = result if inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0: result = download_bos_features(self.bos_client, inputs["audio_feature_urls"]) if isinstance(result, str): # download error request.error_message = result request.error_code = 530 return None inputs["audio_features"] = result def get_available_position(self) -> int: position = 0 while position < self.max_num_seqs: if self.stop_flags[position] is True: return position position += 1 raise RuntimeError("No available position is available for new request") def get_real_bsz(self) -> int: for i in range(self.max_num_seqs - 1, -1, -1): if not self.stop_flags[i]: self.real_bsz = i + 1 break return self.real_bsz def get_prefix_cached_blocks(self, request: Request): """ set prefix cached information for the given request """ try: cache_prepare_time = time.time() (common_block_ids, matched_token_num, hit_info) = self.cache_manager.request_match_blocks( request, self.config.cache_config.block_size ) matched_block_num = len(common_block_ids) no_cache_block_num = self.cache_manager.get_required_block_num( request.need_prefill_tokens - matched_token_num, self.config.cache_config.block_size, ) request.num_cached_tokens = matched_token_num request.metrics.gpu_cache_token_num = hit_info["gpu_match_token_num"] request.metrics.cpu_cache_token_num = hit_info["cpu_match_token_num"] request.cache_info = [matched_block_num, no_cache_block_num] request.block_tables = common_block_ids request.skip_allocate = False # Report the number of cached tokens to Prometheus metrics main_process_metrics.prefix_cache_token_num.inc(matched_token_num) main_process_metrics.prefix_gpu_cache_token_num.inc(request.metrics.gpu_cache_token_num) main_process_metrics.prefix_cpu_cache_token_num.inc(request.metrics.gpu_cache_token_num) if matched_token_num == request.need_prefill_tokens: request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size request.skip_allocate = True else: request.num_computed_tokens = matched_token_num request.metrics.gpu_cpu_cache_prepare_time = time.time() - cache_prepare_time return True except Exception as e: llm_logger.error(f"prefix match blocks error: {e}, {str(traceback.format_exc())} waiting reschedule...") return False def get_storage_cached_blocks(self, request: Request, extra_gpu_block_ids: list = []): """ Match and prefetch the cached blocks from the storage backend. TODO: merge this function into get_prefix_cached_blocks """ try: tic = time.time() req_id = request.request_id llm_logger.debug(f"get_storage_cached_blocks start process req {req_id}") matched_block_ids = self.cache_manager.request_match_storage_blocks(request, extra_gpu_block_ids) llm_logger.debug( f"matched {len(matched_block_ids)} blocks from storage for req_id:{req_id}, " f"cost_time: {time.time() - tic:.6f}s" ) matched_token_num = len(matched_block_ids) * self.config.cache_config.block_size request.metrics.storage_cache_token_num = matched_token_num request.num_computed_tokens += matched_token_num if request.num_computed_tokens == request.need_prefill_tokens: request.num_computed_tokens = request.num_computed_tokens - self.config.cache_config.block_size request.metrics.storage_cache_prepare_time = time.time() - tic request.cache_info[0] += len(matched_block_ids) # matched_block_num request.cache_info[1] -= len(matched_block_ids) # no_cache_block_num main_process_metrics.prefix_cache_token_num.inc(matched_token_num) # TODO: main_process_metrics.prefix_storage_cache_token_num.inc(matched_token_num) return matched_block_ids except Exception as e: llm_logger.error( f"get_storage_cached_blocks process req {req_id}, error: {e}, {str(traceback.format_exc())} " ) return [] def add_request(self, request: Request) -> None: with self.lock: self.apply_async_preprocess(request) self.waiting.append(request) self.requests[request.request_id] = request def pre_recycle_resource(self, request_id: str): """ Recycle resource in P or D before finished due to unexpected error. """ with self.lock: if request_id not in self.requests: return req = self.requests[request_id] self.tasks_list[req.idx] = None self.stop_flags[req.idx] = True self._free_blocks(req) del self.requests[request_id] if request_id in self.req_dict: del self.req_dict[request_id] def add_request_in_p(self, requests: list[Request]): with self.lock: for request in requests: self.running.append(request) def preallocate_resource_in_p(self, request: Request): """ In P/D aggregated deployment, preallocate resource for P. If can allocate, allocate resources and return True If can not, return False """ assert self.config.scheduler_config.splitwise_role == "prefill", "Only P instance can call this method" with self.lock: if self.available_batch() == 0: return False request.need_prefill_tokens = len(request.prompt_token_ids) need_prealloc_prefill_blocks = ( request.need_prefill_tokens + self.config.cache_config.block_size - 1 ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num if self.config.cache_config.enable_prefix_caching: # Enable prefix caching if self.cache_manager.num_cpu_blocks > 0: if not self.cache_manager.can_allocate_gpu_blocks( need_prealloc_prefill_blocks ): # to prevent block allocation for matching in hierarchical cache and cause dead lock return False success = self.get_prefix_cached_blocks(request) if not success: self._free_blocks(request) return False need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0] if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks): extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks) if self.config.cache_config.enable_prefix_caching: self.get_storage_cached_blocks(request, extra_gpu_block_ids) request.block_tables.extend(extra_gpu_block_ids) allocated_position = self.get_available_position() request.idx = allocated_position self.tasks_list[request.idx] = request self.stop_flags[request.idx] = False self.requests[request.request_id] = request self.req_dict[request.request_id] = allocated_position return True else: self._free_blocks(request) return False else: if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks): request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks)) request.num_computed_tokens = 0 allocated_position = self.get_available_position() request.idx = allocated_position self.tasks_list[request.idx] = request self.stop_flags[request.idx] = False self.requests[request.request_id] = request self.req_dict[request.request_id] = allocated_position return True return False def preallocate_resource_in_d(self, request: Request): """ In P/D aggregated deployment, D should preallocate resource for P. If can allocate, allocate resources and return True If can not, return False """ assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method" if request.reasoning_max_tokens is not None: request.reasoning_max_tokens -= 1 request.need_prefill_tokens = len(request.prompt_token_ids) need_prealloc_prefill_blocks = ( request.need_prefill_tokens + self.config.cache_config.block_size - 1 ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num with self.lock: if len(self.waiting) > 0: return False if self.available_batch() == 0: return False if not self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks): return False request.block_tables = self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks) request.num_computed_tokens = request.need_prefill_tokens request.disaggregate_info["block_tables"] = request.block_tables allocated_position = self.get_available_position() request.idx = allocated_position self.tasks_list[request.idx] = request self.stop_flags[request.idx] = False self.requests[request.request_id] = request self.req_dict[request.request_id] = allocated_position return True def has_resource_for_prefilled_req(self, request_id: str): """ Check whether there are enough slot and gpu resource for the prefilled request, of which the cache is saved in cpu buffer. """ assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method" assert request_id in self.preallocated_reqs, "request_id must be in preallocate" need_blocks_num = len(self.preallocated_reqs[request_id].disaggregate_info["block_tables"]) return self.available_batch() > 0 and self.cache_manager.can_allocate_gpu_blocks(need_blocks_num) def add_prefilled_request(self, request_output: RequestOutput): """ In P/D aggregated deployment, D should continue to decode after receiving first token and cache from P. NOTE: GPU resources should be checked in advance to ensure they are sufficient for the prefilled request. """ assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method" if request_output.request_id not in self.requests: llm_logger.error(f"Request {request_output.request_id} not found in requests") return request = self.requests[request_output.request_id] # update request and insert to running request.output_token_ids.append(request_output.outputs.token_ids[0]) request.num_cached_tokens = request_output.num_cached_tokens if ( self.config.speculative_config.method in ["mtp"] and self.config.scheduler_config.splitwise_role == "decode" ): request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids) request.need_prefill_tokens = len(request.prompt_token_ids) + 1 request_output.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time request_output.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time request.metrics = request_output.metrics self.running.append(request) def _free_blocks(self, request: Request): if self.config.cache_config.enable_prefix_caching: self.cache_manager.release_block_ids(request) self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cached_block_num :]) else: self.cache_manager.recycle_gpu_blocks(request.block_tables) request.block_tables = [] if request.request_id in self.using_extend_tables_req_id: reuse_block_num = self.reuse_block_num_map[request.request_id] self.using_extend_tables_req_id.remove(request.request_id) self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:]) llm_logger.info( f"req {request.request_id} recycle extend blocks {request.extend_block_tables[reuse_block_num:]}" ) request.extend_block_tables = [] del self.reuse_block_num_map[request.request_id] del self.need_block_num_map[request.request_id] def finish_requests_async(self, request_ids: Union[str, Iterable[str]]): return self.finish_execution_pool.submit(self.finish_requests, request_ids) def finish_requests(self, request_ids: Union[str, Iterable[str]]): llm_logger.info(f"recycle resources for requests: {request_ids}") try: if isinstance(request_ids, str): request_ids = (request_ids,) else: request_ids = set(request_ids) need_postprocess_reqs = [] with self.lock: for req_id in request_ids: request = self.requests.get(req_id) if request is None: continue if request in self.waiting: llm_logger.error(f"request {request.request_id} scheduled into waiting list, after finished") continue if request in self.running: self.running.remove(request) request.status = RequestStatus.FINISHED need_postprocess_reqs.append(request) if request.request_id in self.to_be_rescheduled_request_id_set: # finished after preempted, blocks have been recycled. self.to_be_rescheduled_request_id_set.remove(request.request_id) self.tasks_list[request.idx] = None self.stop_flags[request.idx] = True del self.requests[req_id] if req_id in self.req_dict: del self.req_dict[req_id] # Do not block the main thread here for req in need_postprocess_reqs: self.cache_manager.write_cache_to_storage(req) with self.lock: for req in need_postprocess_reqs: try: self._free_blocks(req) except Exception as e: llm_logger.warning(f"release block failed {req.request_id}: {e}") except Exception as e: llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}") finally: self.update_metrics() def clear_data(self): self.waiting: deque[Request] = deque() self.to_be_rescheduled_request_id_set = set() def update_metrics(self): # Update metrics num_tasks = sum([1 if task else 0 for task in self.tasks_list]) num_blocks_used_by_tasks = sum([len(task.block_tables) if task else 0 for task in self.tasks_list]) main_process_metrics.available_gpu_block_num.set(self.total_block_number() - num_blocks_used_by_tasks) main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch()) main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc()) main_process_metrics.num_requests_running.set(len(self.running)) main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running))