polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions
+163 -150
View File
@@ -14,24 +14,25 @@
# limitations under the License.
"""
from typing import List, Optional, Dict, Tuple
import traceback
import random
import threading
import time
import random
import traceback
import uuid
from typing import Dict, List, Optional, Tuple
import crcmod
from redis import ConnectionPool
from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.workers import Workers, Task
from fastdeploy.utils import scheduler_logger
from fastdeploy.scheduler import utils
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.scheduler.workers import Task, Workers
from fastdeploy.utils import scheduler_logger
class GlobalScheduler(object):
class GlobalScheduler:
"""
A distributed task scheduler that manages request/response queues using Redis.
@@ -42,20 +43,21 @@ class GlobalScheduler(object):
- Maintaining worker health checks
"""
def __init__(self,
host: str,
port: int,
db: int,
password: Optional[str],
topic: str,
ttl: int,
min_load_score: float,
load_shards_num: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
):
def __init__(
self,
host: str,
port: int,
db: int,
password: Optional[str],
topic: str,
ttl: int,
min_load_score: float,
load_shards_num: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
):
"""
Initialize the GlobalScheduler with Redis connection and scheduling parameters.
@@ -94,29 +96,25 @@ class GlobalScheduler(object):
self.blpop_response_timeout = 10
self.crc16_mutex = threading.Lock()
self.crc16 = crcmod.predefined.Crc('ccitt-false')
self.crc16 = crcmod.predefined.Crc("ccitt-false")
self.load_slot_for_getting_request = 0
self.load_offset = 0 # const
self.load_count = 50 # const
self.load_offset = 0 # const
self.load_count = 50 # const
self.load_lookup_num = 5 # const
self.keep_alive_duration = 30 # const
connection_pool = ConnectionPool(
host=host, port=port, db=db, password=password, max_connections=10)
connection_pool = ConnectionPool(host=host, port=port, db=db, password=password, max_connections=10)
self.client = AdaptedRedis(connection_pool=connection_pool)
self.name, self.shard = self._generate_scheduler_name_and_shard()
self.keep_alive_workers = threading.Thread(
target=self._keep_alive, daemon=True)
self.keep_alive_workers = threading.Thread(target=self._keep_alive, daemon=True)
self.keep_alive_workers.start()
self.put_requests_workers = Workers(
"put_requests_workers", self._put_requests_worker, 20)
self.put_requests_workers = Workers("put_requests_workers", self._put_requests_worker, 20)
self.put_requests_workers.start(1)
self.put_results_workers = Workers(
"put_results_workers", self._put_results_worker, 300)
self.put_results_workers = Workers("put_results_workers", self._put_results_worker, 300)
self.put_results_workers.start(1)
self.mutex = threading.Lock()
@@ -124,12 +122,10 @@ class GlobalScheduler(object):
self.local_responses: Dict[str, List[ScheduledResponse]] = dict()
self.stolen_requests: Dict[str, ScheduledRequest] = dict()
self.get_response_workers = threading.Thread(
target=self._get_results_worker, daemon=True)
self.get_response_workers = threading.Thread(target=self._get_results_worker, daemon=True)
self.get_response_workers.start()
scheduler_logger.info(
f"Scheduler: name={self.name} redis_version={self.client.version}")
scheduler_logger.info(f"Scheduler: name={self.name} redis_version={self.client.version}")
def _get_hash_slot(self, data: str) -> int:
"""
@@ -184,8 +180,8 @@ class GlobalScheduler(object):
4. Handles naming conflicts by appending incrementing suffixes
Returns:
Tuple[str, int]:
- str: Unique scheduler name
Tuple[str, int]:
- str: Unique scheduler name
- int: Assigned shard number (0 to load_shards_num-1)
Implementation Details:
@@ -202,21 +198,28 @@ class GlobalScheduler(object):
try:
_, name = utils.get_hostname_ip()
except Exception as e:
scheduler_logger.warning(
f"Scheduler encountered an error while resolving the IP address. {e}")
scheduler_logger.warning(f"Scheduler encountered an error while resolving the IP address. {e}")
name = str(uuid.uuid4())
size = len(name)
count = 1
while True:
if self.client.set(self._instance_name(name), "", ex=self.keep_alive_duration, nx=True):
if self.client.set(
self._instance_name(name),
"",
ex=self.keep_alive_duration,
nx=True,
):
break
name = f"{name[:size]}:{count}"
count += 1
shard = self._get_hash_slot(name) % self.load_shards_num
self.client.set(self._instance_name(name), self._load_table_name(shard=shard),
ex=self.keep_alive_duration)
self.client.set(
self._instance_name(name),
self._load_table_name(shard=shard),
ex=self.keep_alive_duration,
)
return name, shard
def _keep_alive(self):
@@ -227,8 +230,11 @@ class GlobalScheduler(object):
"""
while True:
try:
self.client.set(self._instance_name(
self.name), self._load_table_name(), ex=self.keep_alive_duration)
self.client.set(
self._instance_name(self.name),
self._load_table_name(),
ex=self.keep_alive_duration,
)
time.sleep(self.keep_alive_duration / 2)
except Exception as e:
scheduler_logger.error(f"Scheduler keep alive failed: {e}")
@@ -324,7 +330,7 @@ class GlobalScheduler(object):
mark = f"mark<{request_queue_name}>"
if not response.request_id.startswith(mark):
return
response.request_id = response.request_id[len(mark):]
response.request_id = response.request_id[len(mark) :]
def _put_requests_worker(self, tasks: List[Task]) -> List[Task]:
"""
@@ -341,7 +347,10 @@ class GlobalScheduler(object):
with self.mutex:
for task in tasks:
request = ScheduledRequest(
task.raw, self._request_queue_name(), self._response_queue_name())
task.raw,
self._request_queue_name(),
self._response_queue_name(),
)
task.raw = None
if request.request_id in self.local_responses:
@@ -353,18 +362,21 @@ class GlobalScheduler(object):
if len(requests) > 0:
serialized_requests = [request.serialize() for request in requests]
self.client.rpush(self._request_queue_name(), *
serialized_requests, ttl=self.ttl)
self.client.zincrby(self._load_table_name(),
len(serialized_requests), self.name,
rem_amount=0, ttl=self.ttl)
scheduler_logger.info(
f"Scheduler has enqueued some requests: {requests}")
self.client.rpush(self._request_queue_name(), *serialized_requests, ttl=self.ttl)
self.client.zincrby(
self._load_table_name(),
len(serialized_requests),
self.name,
rem_amount=0,
ttl=self.ttl,
)
scheduler_logger.info(f"Scheduler has enqueued some requests: {requests}")
if duplicate:
scheduler_logger.warning(
"Scheduler has received some duplicated requests: "
f"{[task for task in tasks if task.reason is not None]}")
f"{[task for task in tasks if task.reason is not None]}"
)
return tasks
def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
@@ -386,8 +398,14 @@ class GlobalScheduler(object):
results = self.put_requests_workers.get_results(10, 0.001)
return [(result.id, result.reason) for result in results]
def get_requests(self, available_blocks, block_size, reserved_output_blocks,
max_num_batched_tokens, batch=1) -> List[Request]:
def get_requests(
self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1,
) -> List[Request]:
"""
Get requests from the shared cache based on available resources.
@@ -406,7 +424,8 @@ class GlobalScheduler(object):
scheduler_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}")
f"max_num_batched_tokens={max_num_batched_tokens}"
)
return []
mini_batch = (batch + 1) // 2
@@ -424,37 +443,38 @@ class GlobalScheduler(object):
local_request_queue_name = self._request_queue_name()
serialized_requests: List[Tuple[str, bytes]] = []
for bs in batches:
elements = self.client.lpop(
local_request_queue_name, bs, ttl=self.ttl)
elements = self.client.lpop(local_request_queue_name, bs, ttl=self.ttl)
if elements is None:
break
self.client.zincrby(self._load_table_name(), -
len(elements), self.name, rem_amount=0, ttl=self.ttl)
serialized_requests += [(local_request_queue_name, element)
for element in elements]
self.client.zincrby(
self._load_table_name(),
-len(elements),
self.name,
rem_amount=0,
ttl=self.ttl,
)
serialized_requests += [(local_request_queue_name, element) for element in elements]
extend_scheduler_names = []
extend_scheduler_load_table_name = ""
if len(serialized_requests) == 0 and len(batches) > 0:
for _ in range(min(self.load_lookup_num, self.load_shards_num)):
extend_scheduler_load_table_name = self._load_table_name(
slot=self.load_slot_for_getting_request)
extend_scheduler_load_table_name = self._load_table_name(slot=self.load_slot_for_getting_request)
serialized_members = self.client.zrangebyscore(
extend_scheduler_load_table_name,
self.min_load_score,
float("+inf"),
start=self.load_offset,
num=self.load_count)
num=self.load_count,
)
self.load_slot_for_getting_request += 1
if len(serialized_members) > 0:
break
members = [member.decode("utf-8") for member in serialized_members]
if len(members) > 0:
extend_scheduler_names = random.sample(
members, k=min(10, len(members)))
extend_scheduler_names = [
name for name in extend_scheduler_names if name != self.name]
extend_scheduler_names = random.sample(members, k=min(10, len(members)))
extend_scheduler_names = [name for name in extend_scheduler_names if name != self.name]
# find lucky one
if len(extend_scheduler_names) > 0:
@@ -463,40 +483,43 @@ class GlobalScheduler(object):
elements = self.client.lpop(lucky_request_queue_name, batches[0])
if elements is not None and len(elements) > 0:
self.client.zincrby(extend_scheduler_load_table_name,
-len(elements), lucky, rem_amount=0, ttl=self.ttl)
serialized_requests += [(lucky_request_queue_name, element)
for element in elements]
self.client.zincrby(
extend_scheduler_load_table_name,
-len(elements),
lucky,
rem_amount=0,
ttl=self.ttl,
)
serialized_requests += [(lucky_request_queue_name, element) for element in elements]
scheduler_logger.info(
f"Scheduler {self.name} has stolen some requests from another lucky one. "
f"(name={lucky} num={len(serialized_requests)})")
f"(name={lucky} num={len(serialized_requests)})"
)
else:
exist_num = self.client.exists(self._instance_name(lucky))
if exist_num == 0:
if self.client.zrem(extend_scheduler_load_table_name, lucky):
scheduler_logger.info(
f"Scheduler {lucky} has been removed")
scheduler_logger.info(f"Scheduler {lucky} has been removed")
# blocked read
if len(serialized_requests) == 0:
request_queue_names = [local_request_queue_name]
request_queue_names += [
self._request_queue_name(name) for name in extend_scheduler_names]
request_queue_names += [self._request_queue_name(name) for name in extend_scheduler_names]
element = self.client.blpop(
request_queue_names, self.blpop_request_timeout)
element = self.client.blpop(request_queue_names, self.blpop_request_timeout)
if element is None:
return []
request_queue_name = element[0].decode("utf-8")
scheduler_name = self._scheduler_name_from_request_queue(
request_queue_name)
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
self.client.zincrby(load_table_name,
-1, scheduler_name, rem_amount=0, ttl=self.ttl)
scheduler_name = self._scheduler_name_from_request_queue(request_queue_name)
load_table_name = (
extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
)
self.client.zincrby(load_table_name, -1, scheduler_name, rem_amount=0, ttl=self.ttl)
serialized_requests.append((request_queue_name, element[1]))
if scheduler_name != self.name:
scheduler_logger.info(
f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})")
f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})"
)
long_partial_requests = 0
short_partial_requests = 0
@@ -506,41 +529,34 @@ class GlobalScheduler(object):
scheduled_requests: List[ScheduledRequest] = []
for request_queue_name, serialized_request in serialized_requests:
if len(remaining_request) > 0:
remaining_request.append(
(request_queue_name, serialized_request))
remaining_request.append((request_queue_name, serialized_request))
continue
request: ScheduledRequest = ScheduledRequest.unserialize(
serialized_request)
required_input_blocks = self.calc_required_blocks(
request.prompt_tokens_ids_len, block_size)
request: ScheduledRequest = ScheduledRequest.unserialize(serialized_request)
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
remaining_request.append(
(request_queue_name, serialized_request))
remaining_request.append((request_queue_name, serialized_request))
continue
if self.enable_chunked_prefill:
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
long_partial_requests += 1
if long_partial_requests > self.max_long_partial_prefills:
remaining_request.append(
(request_queue_name, serialized_request))
remaining_request.append((request_queue_name, serialized_request))
continue
else:
short_partial_requests += 1
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
remaining_request.append(
(request_queue_name, serialized_request))
remaining_request.append((request_queue_name, serialized_request))
continue
else:
if current_prefill_tokens > max_num_batched_tokens:
remaining_request.append(
(request_queue_name, serialized_request))
remaining_request.append((request_queue_name, serialized_request))
continue
scheduled_requests.append(request)
@@ -556,11 +572,9 @@ class GlobalScheduler(object):
self.stolen_requests[request.request_id] = request
continue
scheduler_logger.error(
f"Scheduler has received a duplicate request from others: {request}")
scheduler_logger.error(f"Scheduler has received a duplicate request from others: {request}")
requests: List[Request] = [
request.raw for request in scheduled_requests]
requests: List[Request] = [request.raw for request in scheduled_requests]
if len(remaining_request) > 0:
group: Dict[str, List] = dict()
for request_queue_name, serialized_request in remaining_request:
@@ -569,23 +583,26 @@ class GlobalScheduler(object):
group[request_queue_name].append(serialized_request)
for request_queue_name, serialized_requests in group.items():
self.client.lpush(request_queue_name, *
serialized_requests)
scheduler_name = self._scheduler_name_from_request_queue(
request_queue_name)
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
self.client.zincrby(load_table_name,
len(serialized_requests), scheduler_name, ttl=self.ttl)
self.client.lpush(request_queue_name, *serialized_requests)
scheduler_name = self._scheduler_name_from_request_queue(request_queue_name)
load_table_name = (
extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
)
self.client.zincrby(
load_table_name,
len(serialized_requests),
scheduler_name,
ttl=self.ttl,
)
scheduler_logger.info(
f"Scheduler has put remaining request into the queue: {len(remaining_request)}")
scheduler_logger.info(f"Scheduler has put remaining request into the queue: {len(remaining_request)}")
if len(requests) == 0:
scheduler_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}")
f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}"
)
if len(requests) > 0:
scheduler_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
scheduler_logger.info(f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
return requests
def _put_results_worker(self, tasks: List[Task]):
@@ -623,17 +640,15 @@ class GlobalScheduler(object):
if response.request_id in stolen_request_id_request_queue:
response_queue_name = stolen_request_id_response_queue[response.request_id]
request_queue_name = stolen_request_id_request_queue[response.request_id]
# request_queue_name = stolen_request_id_request_queue[response.request_id]
# self._unmark_response(response, request_queue_name)
if response_queue_name not in stolen_responses:
stolen_responses[response_queue_name] = []
stolen_responses[response_queue_name].append(
response.serialize())
stolen_responses[response_queue_name].append(response.serialize())
continue
scheduler_logger.error(
f"Scheduler has recieved a non-existent response from engine: {[response]}")
scheduler_logger.error(f"Scheduler has recieved a non-existent response from engine: {[response]}")
with self.mutex:
for request_id, responses in local_responses.items():
@@ -648,8 +663,7 @@ class GlobalScheduler(object):
self.local_response_not_empty.notify_all()
if len(finished_request_ids) > 0:
scheduler_logger.info(
f"Scheduler has received some finished responses: {finished_request_ids}")
scheduler_logger.info(f"Scheduler has received some finished responses: {finished_request_ids}")
for response_queue_name, responses in stolen_responses.items():
self.client.rpush(response_queue_name, *responses, ttl=self.ttl)
@@ -663,8 +677,7 @@ class GlobalScheduler(object):
Args:
results: List of RequestOutput objects to return
"""
tasks: List[Task] = [Task(result.request_id, result)
for result in results]
tasks: List[Task] = [Task(result.request_id, result) for result in results]
self.put_results_workers.add_tasks(tasks)
# ---- for test ----
@@ -684,20 +697,20 @@ class GlobalScheduler(object):
"""
while True:
try:
serialized_responses = self.client.lpop(
self._response_queue_name(), 300, ttl=self.ttl)
serialized_responses = self.client.lpop(self._response_queue_name(), 300, ttl=self.ttl)
if serialized_responses is None or len(serialized_responses) == 0:
element = self.client.blpop(
[self._response_queue_name()], self.blpop_response_timeout)
[self._response_queue_name()],
self.blpop_response_timeout,
)
if element is None or len(element) == 0:
continue
serialized_responses = [element[1]]
responses: Dict[str, List[ScheduledResponse]] = dict()
for serialized_response in serialized_responses:
response = ScheduledResponse.unserialize(
serialized_response)
response = ScheduledResponse.unserialize(serialized_response)
if response.request_id not in responses:
responses[response.request_id] = []
responses[response.request_id].append(response)
@@ -707,13 +720,15 @@ class GlobalScheduler(object):
if request_id not in self.local_responses:
scheduler_logger.error(
"Scheduler has received some non-existent response from the queue. "
f"response:{contents} queue:{self._response_queue_name()}")
f"response:{contents} queue:{self._response_queue_name()}"
)
continue
self.local_responses[request_id] += contents
self.local_response_not_empty.notify_all()
except Exception as e:
scheduler_logger.error(f"Scheduler get_results_worker exception: {e} "
f"traceback: {traceback.format_exc()}")
scheduler_logger.error(
f"Scheduler get_results_worker exception: {e} " f"traceback: {traceback.format_exc()}"
)
def get_results(self) -> Dict[str, List[RequestOutput]]:
"""
@@ -732,7 +747,7 @@ class GlobalScheduler(object):
4. Automatically cleans up completed request tracking
Returns:
Dict[str, List[RequestOutput]]:
Dict[str, List[RequestOutput]]:
A dictionary where:
- Key is the request ID
- Value is a list of RequestOutput objects for that request
@@ -765,8 +780,7 @@ class GlobalScheduler(object):
return responses
with self.local_response_not_empty:
responses: Dict[str, List[ScheduledResponse]] = self.local_response_not_empty.wait_for(
_get_results, 0.001)
responses: Dict[str, List[ScheduledResponse]] = self.local_response_not_empty.wait_for(_get_results, 0.001)
results: Dict[str, List[RequestOutput]] = dict()
for request_id, resps in responses.items():
@@ -778,8 +792,7 @@ class GlobalScheduler(object):
if finished:
del self.local_responses[request_id]
scheduler_logger.info(
f"Scheduler has pulled a finished response: {[request_id]}")
scheduler_logger.info(f"Scheduler has pulled a finished response: {[request_id]}")
return results
def reset(self):
@@ -800,14 +813,13 @@ class GlobalScheduler(object):
- Clears the local_responses dictionary tracking pending responses
- Clears the stolen_requests dictionary tracking requests taken from other schedulers
Note:
Note:
- Uses the scheduler's mutex to ensure thread safety
- Does not affect other scheduler instances in the cluster
- After reset, the scheduler will need to be reinitialized to be usable again
"""
with self.mutex:
self.client.delete(self._request_queue_name(),
self._response_queue_name())
self.client.delete(self._request_queue_name(), self._response_queue_name())
self.client.zrem(self._load_table_name(), self.name)
self.local_responses = dict()
self.stolen_requests = dict()
@@ -843,9 +855,10 @@ class GlobalScheduler(object):
self.load_shards_num = load_shards_num
if reallocate:
self.shard = self._get_hash_slot(
self.name) % self.load_shards_num
self.shard = self._get_hash_slot(self.name) % self.load_shards_num
scheduler_logger.info("Scheduler has reload config, "
f"load_shards_num({old_load_shards_num} => {self.load_shards_num}) "
f"shard({old_shard} => {self.shard})")
scheduler_logger.info(
"Scheduler has reload config, "
f"load_shards_num({old_load_shards_num} => {self.load_shards_num}) "
f"shard({old_shard} => {self.shard})"
)