mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -14,24 +14,25 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
import traceback
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
import random
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import crcmod
|
||||
from redis import ConnectionPool
|
||||
from fastdeploy.scheduler.storage import AdaptedRedis
|
||||
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
|
||||
from fastdeploy.scheduler.workers import Workers, Task
|
||||
from fastdeploy.utils import scheduler_logger
|
||||
from fastdeploy.scheduler import utils
|
||||
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
|
||||
from fastdeploy.scheduler.storage import AdaptedRedis
|
||||
from fastdeploy.scheduler.workers import Task, Workers
|
||||
from fastdeploy.utils import scheduler_logger
|
||||
|
||||
|
||||
class GlobalScheduler(object):
|
||||
class GlobalScheduler:
|
||||
"""
|
||||
A distributed task scheduler that manages request/response queues using Redis.
|
||||
|
||||
@@ -42,20 +43,21 @@ class GlobalScheduler(object):
|
||||
- Maintaining worker health checks
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
host: str,
|
||||
port: int,
|
||||
db: int,
|
||||
password: Optional[str],
|
||||
topic: str,
|
||||
ttl: int,
|
||||
min_load_score: float,
|
||||
load_shards_num: int,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_partial_prefills: int,
|
||||
max_long_partial_prefills: int,
|
||||
long_prefill_token_threshold: int,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
db: int,
|
||||
password: Optional[str],
|
||||
topic: str,
|
||||
ttl: int,
|
||||
min_load_score: float,
|
||||
load_shards_num: int,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_partial_prefills: int,
|
||||
max_long_partial_prefills: int,
|
||||
long_prefill_token_threshold: int,
|
||||
):
|
||||
"""
|
||||
Initialize the GlobalScheduler with Redis connection and scheduling parameters.
|
||||
|
||||
@@ -94,29 +96,25 @@ class GlobalScheduler(object):
|
||||
self.blpop_response_timeout = 10
|
||||
|
||||
self.crc16_mutex = threading.Lock()
|
||||
self.crc16 = crcmod.predefined.Crc('ccitt-false')
|
||||
self.crc16 = crcmod.predefined.Crc("ccitt-false")
|
||||
self.load_slot_for_getting_request = 0
|
||||
self.load_offset = 0 # const
|
||||
self.load_count = 50 # const
|
||||
self.load_offset = 0 # const
|
||||
self.load_count = 50 # const
|
||||
self.load_lookup_num = 5 # const
|
||||
self.keep_alive_duration = 30 # const
|
||||
|
||||
connection_pool = ConnectionPool(
|
||||
host=host, port=port, db=db, password=password, max_connections=10)
|
||||
connection_pool = ConnectionPool(host=host, port=port, db=db, password=password, max_connections=10)
|
||||
self.client = AdaptedRedis(connection_pool=connection_pool)
|
||||
|
||||
self.name, self.shard = self._generate_scheduler_name_and_shard()
|
||||
|
||||
self.keep_alive_workers = threading.Thread(
|
||||
target=self._keep_alive, daemon=True)
|
||||
self.keep_alive_workers = threading.Thread(target=self._keep_alive, daemon=True)
|
||||
self.keep_alive_workers.start()
|
||||
|
||||
self.put_requests_workers = Workers(
|
||||
"put_requests_workers", self._put_requests_worker, 20)
|
||||
self.put_requests_workers = Workers("put_requests_workers", self._put_requests_worker, 20)
|
||||
self.put_requests_workers.start(1)
|
||||
|
||||
self.put_results_workers = Workers(
|
||||
"put_results_workers", self._put_results_worker, 300)
|
||||
self.put_results_workers = Workers("put_results_workers", self._put_results_worker, 300)
|
||||
self.put_results_workers.start(1)
|
||||
|
||||
self.mutex = threading.Lock()
|
||||
@@ -124,12 +122,10 @@ class GlobalScheduler(object):
|
||||
self.local_responses: Dict[str, List[ScheduledResponse]] = dict()
|
||||
self.stolen_requests: Dict[str, ScheduledRequest] = dict()
|
||||
|
||||
self.get_response_workers = threading.Thread(
|
||||
target=self._get_results_worker, daemon=True)
|
||||
self.get_response_workers = threading.Thread(target=self._get_results_worker, daemon=True)
|
||||
self.get_response_workers.start()
|
||||
|
||||
scheduler_logger.info(
|
||||
f"Scheduler: name={self.name} redis_version={self.client.version}")
|
||||
scheduler_logger.info(f"Scheduler: name={self.name} redis_version={self.client.version}")
|
||||
|
||||
def _get_hash_slot(self, data: str) -> int:
|
||||
"""
|
||||
@@ -184,8 +180,8 @@ class GlobalScheduler(object):
|
||||
4. Handles naming conflicts by appending incrementing suffixes
|
||||
|
||||
Returns:
|
||||
Tuple[str, int]:
|
||||
- str: Unique scheduler name
|
||||
Tuple[str, int]:
|
||||
- str: Unique scheduler name
|
||||
- int: Assigned shard number (0 to load_shards_num-1)
|
||||
|
||||
Implementation Details:
|
||||
@@ -202,21 +198,28 @@ class GlobalScheduler(object):
|
||||
try:
|
||||
_, name = utils.get_hostname_ip()
|
||||
except Exception as e:
|
||||
scheduler_logger.warning(
|
||||
f"Scheduler encountered an error while resolving the IP address. {e}")
|
||||
scheduler_logger.warning(f"Scheduler encountered an error while resolving the IP address. {e}")
|
||||
name = str(uuid.uuid4())
|
||||
|
||||
size = len(name)
|
||||
count = 1
|
||||
while True:
|
||||
if self.client.set(self._instance_name(name), "", ex=self.keep_alive_duration, nx=True):
|
||||
if self.client.set(
|
||||
self._instance_name(name),
|
||||
"",
|
||||
ex=self.keep_alive_duration,
|
||||
nx=True,
|
||||
):
|
||||
break
|
||||
name = f"{name[:size]}:{count}"
|
||||
count += 1
|
||||
|
||||
shard = self._get_hash_slot(name) % self.load_shards_num
|
||||
self.client.set(self._instance_name(name), self._load_table_name(shard=shard),
|
||||
ex=self.keep_alive_duration)
|
||||
self.client.set(
|
||||
self._instance_name(name),
|
||||
self._load_table_name(shard=shard),
|
||||
ex=self.keep_alive_duration,
|
||||
)
|
||||
return name, shard
|
||||
|
||||
def _keep_alive(self):
|
||||
@@ -227,8 +230,11 @@ class GlobalScheduler(object):
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
self.client.set(self._instance_name(
|
||||
self.name), self._load_table_name(), ex=self.keep_alive_duration)
|
||||
self.client.set(
|
||||
self._instance_name(self.name),
|
||||
self._load_table_name(),
|
||||
ex=self.keep_alive_duration,
|
||||
)
|
||||
time.sleep(self.keep_alive_duration / 2)
|
||||
except Exception as e:
|
||||
scheduler_logger.error(f"Scheduler keep alive failed: {e}")
|
||||
@@ -324,7 +330,7 @@ class GlobalScheduler(object):
|
||||
mark = f"mark<{request_queue_name}>"
|
||||
if not response.request_id.startswith(mark):
|
||||
return
|
||||
response.request_id = response.request_id[len(mark):]
|
||||
response.request_id = response.request_id[len(mark) :]
|
||||
|
||||
def _put_requests_worker(self, tasks: List[Task]) -> List[Task]:
|
||||
"""
|
||||
@@ -341,7 +347,10 @@ class GlobalScheduler(object):
|
||||
with self.mutex:
|
||||
for task in tasks:
|
||||
request = ScheduledRequest(
|
||||
task.raw, self._request_queue_name(), self._response_queue_name())
|
||||
task.raw,
|
||||
self._request_queue_name(),
|
||||
self._response_queue_name(),
|
||||
)
|
||||
task.raw = None
|
||||
|
||||
if request.request_id in self.local_responses:
|
||||
@@ -353,18 +362,21 @@ class GlobalScheduler(object):
|
||||
|
||||
if len(requests) > 0:
|
||||
serialized_requests = [request.serialize() for request in requests]
|
||||
self.client.rpush(self._request_queue_name(), *
|
||||
serialized_requests, ttl=self.ttl)
|
||||
self.client.zincrby(self._load_table_name(),
|
||||
len(serialized_requests), self.name,
|
||||
rem_amount=0, ttl=self.ttl)
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has enqueued some requests: {requests}")
|
||||
self.client.rpush(self._request_queue_name(), *serialized_requests, ttl=self.ttl)
|
||||
self.client.zincrby(
|
||||
self._load_table_name(),
|
||||
len(serialized_requests),
|
||||
self.name,
|
||||
rem_amount=0,
|
||||
ttl=self.ttl,
|
||||
)
|
||||
scheduler_logger.info(f"Scheduler has enqueued some requests: {requests}")
|
||||
|
||||
if duplicate:
|
||||
scheduler_logger.warning(
|
||||
"Scheduler has received some duplicated requests: "
|
||||
f"{[task for task in tasks if task.reason is not None]}")
|
||||
f"{[task for task in tasks if task.reason is not None]}"
|
||||
)
|
||||
return tasks
|
||||
|
||||
def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
|
||||
@@ -386,8 +398,14 @@ class GlobalScheduler(object):
|
||||
results = self.put_requests_workers.get_results(10, 0.001)
|
||||
return [(result.id, result.reason) for result in results]
|
||||
|
||||
def get_requests(self, available_blocks, block_size, reserved_output_blocks,
|
||||
max_num_batched_tokens, batch=1) -> List[Request]:
|
||||
def get_requests(
|
||||
self,
|
||||
available_blocks,
|
||||
block_size,
|
||||
reserved_output_blocks,
|
||||
max_num_batched_tokens,
|
||||
batch=1,
|
||||
) -> List[Request]:
|
||||
"""
|
||||
Get requests from the shared cache based on available resources.
|
||||
|
||||
@@ -406,7 +424,8 @@ class GlobalScheduler(object):
|
||||
scheduler_logger.debug(
|
||||
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
|
||||
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
|
||||
f"max_num_batched_tokens={max_num_batched_tokens}")
|
||||
f"max_num_batched_tokens={max_num_batched_tokens}"
|
||||
)
|
||||
return []
|
||||
|
||||
mini_batch = (batch + 1) // 2
|
||||
@@ -424,37 +443,38 @@ class GlobalScheduler(object):
|
||||
local_request_queue_name = self._request_queue_name()
|
||||
serialized_requests: List[Tuple[str, bytes]] = []
|
||||
for bs in batches:
|
||||
elements = self.client.lpop(
|
||||
local_request_queue_name, bs, ttl=self.ttl)
|
||||
elements = self.client.lpop(local_request_queue_name, bs, ttl=self.ttl)
|
||||
if elements is None:
|
||||
break
|
||||
self.client.zincrby(self._load_table_name(), -
|
||||
len(elements), self.name, rem_amount=0, ttl=self.ttl)
|
||||
serialized_requests += [(local_request_queue_name, element)
|
||||
for element in elements]
|
||||
self.client.zincrby(
|
||||
self._load_table_name(),
|
||||
-len(elements),
|
||||
self.name,
|
||||
rem_amount=0,
|
||||
ttl=self.ttl,
|
||||
)
|
||||
serialized_requests += [(local_request_queue_name, element) for element in elements]
|
||||
|
||||
extend_scheduler_names = []
|
||||
extend_scheduler_load_table_name = ""
|
||||
if len(serialized_requests) == 0 and len(batches) > 0:
|
||||
for _ in range(min(self.load_lookup_num, self.load_shards_num)):
|
||||
extend_scheduler_load_table_name = self._load_table_name(
|
||||
slot=self.load_slot_for_getting_request)
|
||||
extend_scheduler_load_table_name = self._load_table_name(slot=self.load_slot_for_getting_request)
|
||||
serialized_members = self.client.zrangebyscore(
|
||||
extend_scheduler_load_table_name,
|
||||
self.min_load_score,
|
||||
float("+inf"),
|
||||
start=self.load_offset,
|
||||
num=self.load_count)
|
||||
num=self.load_count,
|
||||
)
|
||||
self.load_slot_for_getting_request += 1
|
||||
if len(serialized_members) > 0:
|
||||
break
|
||||
|
||||
members = [member.decode("utf-8") for member in serialized_members]
|
||||
if len(members) > 0:
|
||||
extend_scheduler_names = random.sample(
|
||||
members, k=min(10, len(members)))
|
||||
extend_scheduler_names = [
|
||||
name for name in extend_scheduler_names if name != self.name]
|
||||
extend_scheduler_names = random.sample(members, k=min(10, len(members)))
|
||||
extend_scheduler_names = [name for name in extend_scheduler_names if name != self.name]
|
||||
|
||||
# find lucky one
|
||||
if len(extend_scheduler_names) > 0:
|
||||
@@ -463,40 +483,43 @@ class GlobalScheduler(object):
|
||||
|
||||
elements = self.client.lpop(lucky_request_queue_name, batches[0])
|
||||
if elements is not None and len(elements) > 0:
|
||||
self.client.zincrby(extend_scheduler_load_table_name,
|
||||
-len(elements), lucky, rem_amount=0, ttl=self.ttl)
|
||||
serialized_requests += [(lucky_request_queue_name, element)
|
||||
for element in elements]
|
||||
self.client.zincrby(
|
||||
extend_scheduler_load_table_name,
|
||||
-len(elements),
|
||||
lucky,
|
||||
rem_amount=0,
|
||||
ttl=self.ttl,
|
||||
)
|
||||
serialized_requests += [(lucky_request_queue_name, element) for element in elements]
|
||||
scheduler_logger.info(
|
||||
f"Scheduler {self.name} has stolen some requests from another lucky one. "
|
||||
f"(name={lucky} num={len(serialized_requests)})")
|
||||
f"(name={lucky} num={len(serialized_requests)})"
|
||||
)
|
||||
else:
|
||||
exist_num = self.client.exists(self._instance_name(lucky))
|
||||
if exist_num == 0:
|
||||
if self.client.zrem(extend_scheduler_load_table_name, lucky):
|
||||
scheduler_logger.info(
|
||||
f"Scheduler {lucky} has been removed")
|
||||
scheduler_logger.info(f"Scheduler {lucky} has been removed")
|
||||
|
||||
# blocked read
|
||||
if len(serialized_requests) == 0:
|
||||
request_queue_names = [local_request_queue_name]
|
||||
request_queue_names += [
|
||||
self._request_queue_name(name) for name in extend_scheduler_names]
|
||||
request_queue_names += [self._request_queue_name(name) for name in extend_scheduler_names]
|
||||
|
||||
element = self.client.blpop(
|
||||
request_queue_names, self.blpop_request_timeout)
|
||||
element = self.client.blpop(request_queue_names, self.blpop_request_timeout)
|
||||
if element is None:
|
||||
return []
|
||||
request_queue_name = element[0].decode("utf-8")
|
||||
scheduler_name = self._scheduler_name_from_request_queue(
|
||||
request_queue_name)
|
||||
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
|
||||
self.client.zincrby(load_table_name,
|
||||
-1, scheduler_name, rem_amount=0, ttl=self.ttl)
|
||||
scheduler_name = self._scheduler_name_from_request_queue(request_queue_name)
|
||||
load_table_name = (
|
||||
extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
|
||||
)
|
||||
self.client.zincrby(load_table_name, -1, scheduler_name, rem_amount=0, ttl=self.ttl)
|
||||
serialized_requests.append((request_queue_name, element[1]))
|
||||
if scheduler_name != self.name:
|
||||
scheduler_logger.info(
|
||||
f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})")
|
||||
f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})"
|
||||
)
|
||||
|
||||
long_partial_requests = 0
|
||||
short_partial_requests = 0
|
||||
@@ -506,41 +529,34 @@ class GlobalScheduler(object):
|
||||
scheduled_requests: List[ScheduledRequest] = []
|
||||
for request_queue_name, serialized_request in serialized_requests:
|
||||
if len(remaining_request) > 0:
|
||||
remaining_request.append(
|
||||
(request_queue_name, serialized_request))
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
|
||||
request: ScheduledRequest = ScheduledRequest.unserialize(
|
||||
serialized_request)
|
||||
required_input_blocks = self.calc_required_blocks(
|
||||
request.prompt_tokens_ids_len, block_size)
|
||||
request: ScheduledRequest = ScheduledRequest.unserialize(serialized_request)
|
||||
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
|
||||
|
||||
current_prefill_tokens += request.prompt_tokens_ids_len
|
||||
required_total_blocks += required_input_blocks + reserved_output_blocks
|
||||
|
||||
if required_total_blocks > available_blocks:
|
||||
remaining_request.append(
|
||||
(request_queue_name, serialized_request))
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
|
||||
if self.enable_chunked_prefill:
|
||||
if request.prompt_tokens_ids_len > self.long_prefill_token_threshold:
|
||||
long_partial_requests += 1
|
||||
if long_partial_requests > self.max_long_partial_prefills:
|
||||
remaining_request.append(
|
||||
(request_queue_name, serialized_request))
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
else:
|
||||
short_partial_requests += 1
|
||||
|
||||
if short_partial_requests + long_partial_requests > self.max_num_partial_prefills:
|
||||
remaining_request.append(
|
||||
(request_queue_name, serialized_request))
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
else:
|
||||
if current_prefill_tokens > max_num_batched_tokens:
|
||||
remaining_request.append(
|
||||
(request_queue_name, serialized_request))
|
||||
remaining_request.append((request_queue_name, serialized_request))
|
||||
continue
|
||||
|
||||
scheduled_requests.append(request)
|
||||
@@ -556,11 +572,9 @@ class GlobalScheduler(object):
|
||||
self.stolen_requests[request.request_id] = request
|
||||
continue
|
||||
|
||||
scheduler_logger.error(
|
||||
f"Scheduler has received a duplicate request from others: {request}")
|
||||
scheduler_logger.error(f"Scheduler has received a duplicate request from others: {request}")
|
||||
|
||||
requests: List[Request] = [
|
||||
request.raw for request in scheduled_requests]
|
||||
requests: List[Request] = [request.raw for request in scheduled_requests]
|
||||
if len(remaining_request) > 0:
|
||||
group: Dict[str, List] = dict()
|
||||
for request_queue_name, serialized_request in remaining_request:
|
||||
@@ -569,23 +583,26 @@ class GlobalScheduler(object):
|
||||
group[request_queue_name].append(serialized_request)
|
||||
|
||||
for request_queue_name, serialized_requests in group.items():
|
||||
self.client.lpush(request_queue_name, *
|
||||
serialized_requests)
|
||||
scheduler_name = self._scheduler_name_from_request_queue(
|
||||
request_queue_name)
|
||||
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
|
||||
self.client.zincrby(load_table_name,
|
||||
len(serialized_requests), scheduler_name, ttl=self.ttl)
|
||||
self.client.lpush(request_queue_name, *serialized_requests)
|
||||
scheduler_name = self._scheduler_name_from_request_queue(request_queue_name)
|
||||
load_table_name = (
|
||||
extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
|
||||
)
|
||||
self.client.zincrby(
|
||||
load_table_name,
|
||||
len(serialized_requests),
|
||||
scheduler_name,
|
||||
ttl=self.ttl,
|
||||
)
|
||||
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has put remaining request into the queue: {len(remaining_request)}")
|
||||
scheduler_logger.info(f"Scheduler has put remaining request into the queue: {len(remaining_request)}")
|
||||
if len(requests) == 0:
|
||||
scheduler_logger.debug(
|
||||
f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}")
|
||||
f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}"
|
||||
)
|
||||
|
||||
if len(requests) > 0:
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
|
||||
scheduler_logger.info(f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
|
||||
return requests
|
||||
|
||||
def _put_results_worker(self, tasks: List[Task]):
|
||||
@@ -623,17 +640,15 @@ class GlobalScheduler(object):
|
||||
|
||||
if response.request_id in stolen_request_id_request_queue:
|
||||
response_queue_name = stolen_request_id_response_queue[response.request_id]
|
||||
request_queue_name = stolen_request_id_request_queue[response.request_id]
|
||||
# request_queue_name = stolen_request_id_request_queue[response.request_id]
|
||||
# self._unmark_response(response, request_queue_name)
|
||||
|
||||
if response_queue_name not in stolen_responses:
|
||||
stolen_responses[response_queue_name] = []
|
||||
stolen_responses[response_queue_name].append(
|
||||
response.serialize())
|
||||
stolen_responses[response_queue_name].append(response.serialize())
|
||||
continue
|
||||
|
||||
scheduler_logger.error(
|
||||
f"Scheduler has recieved a non-existent response from engine: {[response]}")
|
||||
scheduler_logger.error(f"Scheduler has recieved a non-existent response from engine: {[response]}")
|
||||
|
||||
with self.mutex:
|
||||
for request_id, responses in local_responses.items():
|
||||
@@ -648,8 +663,7 @@ class GlobalScheduler(object):
|
||||
self.local_response_not_empty.notify_all()
|
||||
|
||||
if len(finished_request_ids) > 0:
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has received some finished responses: {finished_request_ids}")
|
||||
scheduler_logger.info(f"Scheduler has received some finished responses: {finished_request_ids}")
|
||||
|
||||
for response_queue_name, responses in stolen_responses.items():
|
||||
self.client.rpush(response_queue_name, *responses, ttl=self.ttl)
|
||||
@@ -663,8 +677,7 @@ class GlobalScheduler(object):
|
||||
Args:
|
||||
results: List of RequestOutput objects to return
|
||||
"""
|
||||
tasks: List[Task] = [Task(result.request_id, result)
|
||||
for result in results]
|
||||
tasks: List[Task] = [Task(result.request_id, result) for result in results]
|
||||
self.put_results_workers.add_tasks(tasks)
|
||||
|
||||
# ---- for test ----
|
||||
@@ -684,20 +697,20 @@ class GlobalScheduler(object):
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
serialized_responses = self.client.lpop(
|
||||
self._response_queue_name(), 300, ttl=self.ttl)
|
||||
serialized_responses = self.client.lpop(self._response_queue_name(), 300, ttl=self.ttl)
|
||||
|
||||
if serialized_responses is None or len(serialized_responses) == 0:
|
||||
element = self.client.blpop(
|
||||
[self._response_queue_name()], self.blpop_response_timeout)
|
||||
[self._response_queue_name()],
|
||||
self.blpop_response_timeout,
|
||||
)
|
||||
if element is None or len(element) == 0:
|
||||
continue
|
||||
serialized_responses = [element[1]]
|
||||
|
||||
responses: Dict[str, List[ScheduledResponse]] = dict()
|
||||
for serialized_response in serialized_responses:
|
||||
response = ScheduledResponse.unserialize(
|
||||
serialized_response)
|
||||
response = ScheduledResponse.unserialize(serialized_response)
|
||||
if response.request_id not in responses:
|
||||
responses[response.request_id] = []
|
||||
responses[response.request_id].append(response)
|
||||
@@ -707,13 +720,15 @@ class GlobalScheduler(object):
|
||||
if request_id not in self.local_responses:
|
||||
scheduler_logger.error(
|
||||
"Scheduler has received some non-existent response from the queue. "
|
||||
f"response:{contents} queue:{self._response_queue_name()}")
|
||||
f"response:{contents} queue:{self._response_queue_name()}"
|
||||
)
|
||||
continue
|
||||
self.local_responses[request_id] += contents
|
||||
self.local_response_not_empty.notify_all()
|
||||
except Exception as e:
|
||||
scheduler_logger.error(f"Scheduler get_results_worker exception: {e} "
|
||||
f"traceback: {traceback.format_exc()}")
|
||||
scheduler_logger.error(
|
||||
f"Scheduler get_results_worker exception: {e} " f"traceback: {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def get_results(self) -> Dict[str, List[RequestOutput]]:
|
||||
"""
|
||||
@@ -732,7 +747,7 @@ class GlobalScheduler(object):
|
||||
4. Automatically cleans up completed request tracking
|
||||
|
||||
Returns:
|
||||
Dict[str, List[RequestOutput]]:
|
||||
Dict[str, List[RequestOutput]]:
|
||||
A dictionary where:
|
||||
- Key is the request ID
|
||||
- Value is a list of RequestOutput objects for that request
|
||||
@@ -765,8 +780,7 @@ class GlobalScheduler(object):
|
||||
return responses
|
||||
|
||||
with self.local_response_not_empty:
|
||||
responses: Dict[str, List[ScheduledResponse]] = self.local_response_not_empty.wait_for(
|
||||
_get_results, 0.001)
|
||||
responses: Dict[str, List[ScheduledResponse]] = self.local_response_not_empty.wait_for(_get_results, 0.001)
|
||||
|
||||
results: Dict[str, List[RequestOutput]] = dict()
|
||||
for request_id, resps in responses.items():
|
||||
@@ -778,8 +792,7 @@ class GlobalScheduler(object):
|
||||
|
||||
if finished:
|
||||
del self.local_responses[request_id]
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has pulled a finished response: {[request_id]}")
|
||||
scheduler_logger.info(f"Scheduler has pulled a finished response: {[request_id]}")
|
||||
return results
|
||||
|
||||
def reset(self):
|
||||
@@ -800,14 +813,13 @@ class GlobalScheduler(object):
|
||||
- Clears the local_responses dictionary tracking pending responses
|
||||
- Clears the stolen_requests dictionary tracking requests taken from other schedulers
|
||||
|
||||
Note:
|
||||
Note:
|
||||
- Uses the scheduler's mutex to ensure thread safety
|
||||
- Does not affect other scheduler instances in the cluster
|
||||
- After reset, the scheduler will need to be reinitialized to be usable again
|
||||
"""
|
||||
with self.mutex:
|
||||
self.client.delete(self._request_queue_name(),
|
||||
self._response_queue_name())
|
||||
self.client.delete(self._request_queue_name(), self._response_queue_name())
|
||||
self.client.zrem(self._load_table_name(), self.name)
|
||||
self.local_responses = dict()
|
||||
self.stolen_requests = dict()
|
||||
@@ -843,9 +855,10 @@ class GlobalScheduler(object):
|
||||
self.load_shards_num = load_shards_num
|
||||
|
||||
if reallocate:
|
||||
self.shard = self._get_hash_slot(
|
||||
self.name) % self.load_shards_num
|
||||
self.shard = self._get_hash_slot(self.name) % self.load_shards_num
|
||||
|
||||
scheduler_logger.info("Scheduler has reload config, "
|
||||
f"load_shards_num({old_load_shards_num} => {self.load_shards_num}) "
|
||||
f"shard({old_shard} => {self.shard})")
|
||||
scheduler_logger.info(
|
||||
"Scheduler has reload config, "
|
||||
f"load_shards_num({old_load_shards_num} => {self.load_shards_num}) "
|
||||
f"shard({old_shard} => {self.shard})"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user