""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import logging import multiprocessing import threading import time from typing import Dict, List, Optional from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledResponse from fastdeploy.scheduler.local_scheduler import LocalScheduler from fastdeploy.utils import get_logger class DPLocalScheduler(LocalScheduler): def __init__( self, max_size: int, ttl: int, enable_chunked_prefill: bool, max_num_partial_prefills: int, max_long_partial_prefills: int, long_prefill_token_threshold: int, splitwise_role: str = "prefill", ): super().__init__( max_size, ttl, enable_chunked_prefill, max_num_partial_prefills, max_long_partial_prefills, long_prefill_token_threshold, ) self.splitwise_role = splitwise_role self.scheduler_logger = logging def put_results(self, results: List[RequestOutput]): """ Add processing results back to the scheduler. Args: results: List of RequestOutput objects containing results """ responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results] finished_responses = [response.request_id for response in responses if response.finished] if len(finished_responses) > 0: self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}") with self.mutex: self.batch_responses_per_step.append([response.raw for response in responses]) for response in responses: if response.request_id not in self.responses: self.responses[response.request_id] = [response] continue self.responses[response.request_id].append(response) self.responses_not_empty.notify_all() def _recycle(self, request_id: Optional[str] = None): """ Clean up expired or completed requests to free memory. Args: request_id: Optional specific request ID to remove. If None, removes all expired requests. """ if request_id is not None: self.requests.pop(request_id, None) self.responses.pop(request_id, None) if self.splitwise_role == "decode": return self.ids.pop(self.ids.index(request_id)) self.ids_read_cursor -= 1 return if self.max_size <= 0: return if len(self.requests) <= self.max_size: return now = time.time() expired_ids = [] for request_id in self.ids: request = self.requests[request_id] if now - request.schedule_time < self.ttl: break expired_ids.append(request.request_id) for i, expired_id in enumerate(expired_ids): self.requests.pop(expired_id, None) self.responses.pop(expired_id, None) self.ids.pop(i) if len(expired_ids) > 0: if len(expired_ids) - 1 >= self.ids_read_cursor: self.ids_read_cursor = 0 else: self.ids_read_cursor -= len(expired_ids) def get_requests( self, available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch=1, ) -> List[Request]: """ Retrieve requests from the scheduler based on available resources. Args: available_blocks: Number of available processing blocks block_size: Size of each processing block reserved_output_blocks: Blocks reserved for output max_num_batched_tokens: Maximum tokens that can be batched batch: Preferred batch size Returns: List of Request objects ready for processing """ # DP scheduler is used in V1, there is no need to manage request fetching in the scheduler, resource_manager_v1 will do that. requests: List[Request] = [] with self.requests_not_empty: batch_ids = self.requests_not_empty.wait_for( lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1], 0.005, ) if batch_ids: for request_id in batch_ids: request = self.requests[request_id] requests.append(request.raw) self.ids_read_cursor += 1 if len(requests) > 0: self.scheduler_logger.info( f"Scheduler has pulled some request: {[request.request_id for request in requests]}" ) return requests class DPScheduler: def __init__( self, max_size: int, ttl: int, enable_chunked_prefill: bool, max_num_partial_prefills: int, max_long_partial_prefills: int, long_prefill_token_threshold: int, splitwise_role: str = "prefill", ): self._scheduler = DPLocalScheduler( max_size, ttl, enable_chunked_prefill, max_num_partial_prefills, max_long_partial_prefills, long_prefill_token_threshold, splitwise_role, ) def start(self, dp_rank: int): self.dp_rank = dp_rank self.request_queues = multiprocessing.Queue() self.result_queues = multiprocessing.Queue() self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log") self._scheduler.scheduler_logger = self.scheduler_logger threading.Thread(target=self._put_requests_to_local).start() threading.Thread(target=self._get_response_from_local).start() def put_requests(self, requests: List[Dict]): results = [] for request in requests: if not hasattr(request, "dp_rank"): raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}") self.request_queues.put(request) results.append((request.request_id, None)) return results def _put_requests_to_local(self): while True: request = self.request_queues.get() self.scheduler_logger.info(f"Receive request from puller, request_id: {request.request_id}") self._scheduler.put_requests([request]) def _get_response_from_local(self): while True: results = self._scheduler.get_results() if len(results) == 0: continue self.result_queues.put(results) def get_requests( self, available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch=1, ) -> List[Request]: return self._scheduler.get_requests( available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch ) def get_unhandled_request_num(self): return len(self._scheduler.requests) def put_results(self, results: List[RequestOutput]): self._scheduler.put_results(results) def get_results(self) -> Dict[str, List[RequestOutput]]: return self.result_queues.get()