Files
FastDeploy/fastdeploy/scheduler/dp_scheduler.py
T
chenjian 6727df8286 [Optimization] Optimize ttft for prefill pd (#6680)
* optimize ttft

* fix

* fix

* fix ci

* fix ci

* fix

* fix bug

* fix

* add comments

* fix ci

* fix

* fix ci

* fix format

* update according to review

* add comment

* fix

* fix format
2026-03-30 20:36:23 +08:00

228 lines
7.8 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import logging
import multiprocessing
import threading
import time
from typing import Dict, List, Optional
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
from fastdeploy.utils import get_logger
class DPLocalScheduler(LocalScheduler):
    """Local scheduler specialized for data-parallel (DP) deployments.

    Extends ``LocalScheduler`` with:
      * a ``splitwise_role`` ("prefill" or "decode") that alters recycle behavior,
      * per-step batched raw responses (``batch_responses_per_step``, defined on
        the parent) consumed by the DP result puller.
    """

    def __init__(
        self,
        max_size: int,
        ttl: int,
        enable_chunked_prefill: bool,
        max_num_partial_prefills: int,
        max_long_partial_prefills: int,
        long_prefill_token_threshold: int,
        splitwise_role: str = "prefill",
    ):
        """
        Args:
            max_size: Capacity threshold; recycling of expired requests only
                happens when it is positive and exceeded.
            ttl: Seconds a request may stay scheduled before it is expirable.
            enable_chunked_prefill: Forwarded to ``LocalScheduler``.
            max_num_partial_prefills: Forwarded to ``LocalScheduler``.
            max_long_partial_prefills: Forwarded to ``LocalScheduler``.
            long_prefill_token_threshold: Forwarded to ``LocalScheduler``.
            splitwise_role: "prefill" or "decode"; in decode role the request-id
                bookkeeping in ``self.ids`` is not maintained on recycle.
        """
        super().__init__(
            max_size,
            ttl,
            enable_chunked_prefill,
            max_num_partial_prefills,
            max_long_partial_prefills,
            long_prefill_token_threshold,
        )
        self.splitwise_role = splitwise_role
        # Fallback logger (the bare `logging` module); DPScheduler.start()
        # replaces this with a per-rank file logger.
        self.scheduler_logger = logging

    def put_results(self, results: List[RequestOutput]):
        """
        Add processing results back to the scheduler.

        Args:
            results: List of RequestOutput objects containing results.
        """
        responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]

        finished_responses = [response.request_id for response in responses if response.finished]
        if len(finished_responses) > 0:
            self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")

        with self.mutex:
            # Keep the raw per-step batch for the DP result puller, and fan the
            # responses out into the per-request response lists.
            self.batch_responses_per_step.append([response.raw for response in responses])
            for response in responses:
                if response.request_id not in self.responses:
                    self.responses[response.request_id] = [response]
                    continue
                self.responses[response.request_id].append(response)
            self.responses_not_empty.notify_all()

    def _recycle(self, request_id: Optional[str] = None):
        """
        Clean up expired or completed requests to free memory.

        Args:
            request_id: Optional specific request ID to remove.
                        If None, removes all expired requests.
        """
        if request_id is not None:
            self.requests.pop(request_id, None)
            self.responses.pop(request_id, None)
            if self.splitwise_role == "decode":
                # Decode role does not track ids/cursor for recycled requests.
                return
            # A finished request has already been read, so its slot sits before
            # the read cursor; removing it shifts the cursor back by one.
            # NOTE(review): raises ValueError if the id is not in self.ids —
            # same behavior as the original index()/pop() pair.
            self.ids.remove(request_id)
            self.ids_read_cursor -= 1
            return

        # Bulk expiry only applies when a positive capacity is exceeded.
        if self.max_size <= 0:
            return
        if len(self.requests) <= self.max_size:
            return

        # self.ids is in schedule order, so expired requests form a prefix:
        # stop at the first request that is still within its TTL.
        now = time.time()
        expired_ids = []
        for scheduled_id in self.ids:
            request = self.requests[scheduled_id]
            if now - request.schedule_time < self.ttl:
                break
            expired_ids.append(request.request_id)

        for expired_id in expired_ids:
            self.requests.pop(expired_id, None)
            self.responses.pop(expired_id, None)

        if len(expired_ids) > 0:
            # BUGFIX: the previous `self.ids.pop(i)` inside the loop advanced
            # the index while the list shrank, removing the wrong ids (every
            # other one). The expired ids are exactly the prefix collected
            # above, so drop that prefix in one slice delete.
            del self.ids[: len(expired_ids)]
            if len(expired_ids) - 1 >= self.ids_read_cursor:
                self.ids_read_cursor = 0
            else:
                self.ids_read_cursor -= len(expired_ids)

    def get_requests(
        self,
        available_blocks,
        block_size,
        reserved_output_blocks,
        max_num_batched_tokens,
        batch=1,
    ) -> List[Request]:
        """
        Retrieve requests from the scheduler based on available resources.

        Args:
            available_blocks: Number of available processing blocks (unused here).
            block_size: Size of each processing block (unused here).
            reserved_output_blocks: Blocks reserved for output (unused here).
            max_num_batched_tokens: Maximum tokens that can be batched (unused here).
            batch: Preferred batch size (unused here).

        Returns:
            List of Request objects ready for processing.
        """
        # DP scheduler is used in V1, there is no need to manage request fetching
        # in the scheduler, resource_manager_v1 will do that. Hence the resource
        # arguments are accepted for interface compatibility but ignored, and at
        # most one request is pulled per call (short 5ms wait to avoid blocking).
        requests: List[Request] = []
        with self.requests_not_empty:
            batch_ids = self.requests_not_empty.wait_for(
                lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1],
                0.005,
            )
            if batch_ids:
                for request_id in batch_ids:
                    request = self.requests[request_id]
                    requests.append(request.raw)
                    self.ids_read_cursor += 1

        if len(requests) > 0:
            self.scheduler_logger.info(
                f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
            )

        return requests
class DPScheduler:
    """Per-rank scheduler facade for data-parallel serving.

    Wraps a ``DPLocalScheduler`` and moves requests/results between the caller
    and the local scheduler through a pair of ``multiprocessing.Queue``s serviced
    by two background threads started in :meth:`start`.
    """

    def __init__(
        self,
        max_size: int,
        ttl: int,
        enable_chunked_prefill: bool,
        max_num_partial_prefills: int,
        max_long_partial_prefills: int,
        long_prefill_token_threshold: int,
        splitwise_role: str = "prefill",
    ):
        self._scheduler = DPLocalScheduler(
            max_size,
            ttl,
            enable_chunked_prefill,
            max_num_partial_prefills,
            max_long_partial_prefills,
            long_prefill_token_threshold,
            splitwise_role,
        )

    def start(self, dp_rank: int):
        """Create the queues, per-rank logger, and background worker threads."""
        self.dp_rank = dp_rank
        self.request_queues = multiprocessing.Queue()
        self.result_queues = multiprocessing.Queue()
        self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
        self._scheduler.scheduler_logger = self.scheduler_logger
        # BUGFIX: both loops run forever; as non-daemon threads they kept the
        # process alive on shutdown. Daemonize so the interpreter can exit.
        threading.Thread(target=self._put_requests_to_local, daemon=True).start()
        threading.Thread(target=self._get_response_from_local, daemon=True).start()

    def put_requests(self, requests: List[Dict]):
        """Enqueue incoming requests; each must carry a ``dp_rank`` attribute.

        Returns:
            List of ``(request_id, None)`` tuples acknowledging acceptance.
        """
        results = []
        for request in requests:
            if not hasattr(request, "dp_rank"):
                raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
            self.request_queues.put(request)
            results.append((request.request_id, None))
        return results

    def _put_requests_to_local(self):
        # Forever: drain the inbound queue into the local scheduler.
        while True:
            request = self.request_queues.get()
            self.scheduler_logger.info(f"Receive request from puller, request_id: {request.request_id}")
            self._scheduler.put_requests([request])

    def _get_response_from_local(self):
        # Forever: forward non-empty result batches to the outbound queue.
        # NOTE(review): assumes _scheduler.get_results() blocks or rate-limits
        # when empty; otherwise this loop busy-spins — verify in LocalScheduler.
        while True:
            results = self._scheduler.get_results()
            if len(results) == 0:
                continue
            self.result_queues.put(results)

    def get_requests(
        self,
        available_blocks,
        block_size,
        reserved_output_blocks,
        max_num_batched_tokens,
        batch=1,
    ) -> List[Request]:
        """Delegate request fetching to the wrapped local scheduler."""
        return self._scheduler.get_requests(
            available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch
        )

    def get_unhandled_request_num(self):
        """Number of requests currently held by the local scheduler."""
        return len(self._scheduler.requests)

    def put_results(self, results: List[RequestOutput]):
        """Push processing results back into the local scheduler."""
        self._scheduler.put_results(results)

    def get_results(self) -> Dict[str, List[RequestOutput]]:
        """Block until the next batch of results is available and return it."""
        return self.result_queues.get()