[LLM] First commit the llm deployment code

This commit is contained in:
jiangjiajun
2025-06-09 19:20:15 +08:00
parent 980c0a1d2c
commit 684703fd72
11814 changed files with 127294 additions and 1293102 deletions
fastdeploy/scheduler/__init__.py  +17
@@ -0,0 +1,17 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .config import SchedulerConfig
fastdeploy/scheduler/config.py  +164
@@ -0,0 +1,164 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import redis
from fastdeploy.utils import llm_logger
from .global_scheduler import GlobalScheduler
from .local_scheduler import LocalScheduler
class LocalSchedulerConfig:
"""
LocalSchedulerConfig class
"""
def __init__(self,
max_size: int = -1,
ttl: int = 900,
wait_response_timeout: float = 1,
**kwargs
):
self.max_size = max_size
self.ttl = ttl
self.wait_response_timeout = wait_response_timeout
def check(self):
"""
check config
"""
assert self.wait_response_timeout > 0, \
"LocalScheduler: `wait_response_timeout` must be greater than zero"
assert self.ttl > self.wait_response_timeout, \
"LocalScheduler: `ttl` must be greater than `wait_response_timeout`"
def print(self):
"""
print config
"""
llm_logger.info("LocalScheduler Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
class GlobalSchedulerConfig:
"""
GlobalSchedulerConfig class
"""
def __init__(self,
host: str = "127.0.0.1",
port: int = 6379,
db: int = 0,
password=None,
topic: str = "default",
ttl: int = 900,
wait_response_timeout: float = 1,
remote_write_time: int = 3,
**kwargs
):
self.host = host
self.port = port
self.db = db
self.password = password
self.topic = topic
self.ttl = ttl
self.wait_response_timeout = wait_response_timeout
self.remote_write_time = remote_write_time
def check(self):
"""
check config
"""
assert self.wait_response_timeout > 0, \
"GlobalScheduler: `wait_response_timeout` must be greater than zero"
assert self.remote_write_time > 0, \
"GlobalScheduler: `remote_write_time` must be greater than zero"
assert self.ttl > self.remote_write_time, \
"GlobalScheduler: `ttl` must be greater than `remote_write_time`"
assert self.ttl > self.wait_response_timeout, \
"GlobalScheduler: `ttl` must be greater than `wait_response_timeout`"
r = redis.Redis(host=self.host, port=self.port,
                db=self.db, password=self.password)
try:
    if not r.ping():
        raise Exception("GlobalScheduler: failed to connect to redis")
finally:
r.close()
def print(self):
"""
print config
"""
llm_logger.info("GlobalScheduler Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
class SchedulerConfig:
"""
SchedulerConfig class
"""
def __init__(self, name="local", **kwargs):
self.name = name
self.config = None
if name == "local":
self.config = LocalSchedulerConfig(**kwargs)
if name == "global":
self.config = GlobalSchedulerConfig(**kwargs)
def check(self):
"""
check config
"""
if self.name not in ["local", "global"]:
raise Exception(
"SchedulerConfig: `name` must be `local` or `global`")
self.config.check()
def print(self):
"""
print config
"""
self.config.print()
def scheduler(self):
"""
create scheduler by config
"""
if self.name == "global":
return GlobalScheduler(host=self.config.host,
port=self.config.port,
db=self.config.db,
password=self.config.password,
topic=self.config.topic,
ttl=self.config.ttl,
remote_write_time=self.config.remote_write_time,
wait_response_timeout=self.config.wait_response_timeout)
return LocalScheduler(max_size=self.config.max_size,
ttl=self.config.ttl,
wait_response_timeout=self.config.wait_response_timeout
)
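A minimal usage sketch of SchedulerConfig (illustrative, not part of this commit); the keyword values and the Redis endpoint are assumptions:

from fastdeploy.scheduler import SchedulerConfig

# In-process scheduling: bounded queue, 15-minute TTL.
local_cfg = SchedulerConfig(name="local", max_size=1000, ttl=900)
local_cfg.check()                          # validates timeouts against the TTL
local_scheduler = local_cfg.scheduler()    # LocalScheduler instance

# Shared scheduling through Redis under `topic`; check() pings the server.
global_cfg = SchedulerConfig(name="global", host="127.0.0.1", port=6379,
                             db=0, topic="default", ttl=900)
global_cfg.check()
global_scheduler = global_cfg.scheduler()  # GlobalScheduler instance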
fastdeploy/scheduler/data.py  +75
@@ -0,0 +1,75 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
import json
from fastdeploy.engine.request import Request, RequestOutput
class ScheduledRequest(object):
"""
ScheduledRequest class
"""
def __init__(self, request: Request):
self.raw: Request = request
self.id = request.request_id
self.scheduled_time = time.time()
self.size = len(request.prompt_token_ids)
def serialize(self) -> bytes:
"""serialize to bytes"""
data = {
"scheduled_time": self.scheduled_time,
"raw": self.raw.to_dict()
}
serialized_data = json.dumps(data, ensure_ascii=False)
return serialized_data.encode()
@classmethod
def unserialize(cls, serialized_data: bytes) -> 'ScheduledRequest':
"""unserialize to Request"""
data = json.loads(serialized_data)
request = Request.from_dict(data["raw"])
scheduled_request = cls(request)
scheduled_request.scheduled_time = data["scheduled_time"]
return scheduled_request
class ScheduledResponse(object):
"""
ScheduledResponse class
"""
def __init__(self, response: RequestOutput):
self.raw: RequestOutput = response
self.id = response.request_id
self.index = response.outputs.index
self.finished = response.finished
def serialize(self) -> bytes:
"""serialize to bytes"""
data = self.raw.to_dict()
serialized_data = json.dumps(data, ensure_ascii=False)
return serialized_data.encode()
@classmethod
def unserialize(cls, serialized_data: bytes) -> 'ScheduledResponse':
"""unserialize to RequestOutput"""
data = json.loads(serialized_data)
request_output = RequestOutput.from_dict(data)
scheduled_response = ScheduledResponse(request_output)
return scheduled_response
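A small round-trip sketch for the wrappers above (illustrative); `request` stands for a hypothetical fastdeploy.engine.request.Request instance, and the example leans on the to_dict()/from_dict() methods the code above already calls:

scheduled = ScheduledRequest(request)             # records scheduled_time and prompt size
payload = scheduled.serialize()                   # JSON bytes, safe to push into Redis
restored = ScheduledRequest.unserialize(payload)  # rebuilds the Request, keeps scheduled_time
assert restored.id == scheduled.id
assert restored.scheduled_time == scheduled.scheduled_time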
fastdeploy/scheduler/global_scheduler.py  +297
@@ -0,0 +1,297 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import List, Optional, Dict, Tuple
import time
from redis import ConnectionPool
from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.workers import Workers
from fastdeploy.utils import llm_logger
class GlobalScheduler(object):
"""
GlobalScheduler class
"""
def __init__(self,
host: str,
port: int,
db: int,
password: Optional[str],
topic: str,
ttl: int,
remote_write_time: int,
wait_response_timeout: float
):
self.topic = topic
self.ttl = ttl
self.remote_write_time = remote_write_time
self.wait_response_timeout = 1.0 if wait_response_timeout < 1.0 else wait_response_timeout
self.wait_request_timeout = 10
connection_pool = ConnectionPool(
host=host, port=port, db=db, password=password, max_connections=10)
self.client = AdaptedRedis(connection_pool=connection_pool)
self.put_request_workers = Workers(
"put_request_worker", self._put_requests_worker, max_batch_size=5)
self.put_request_workers.start(size=1)
self.put_response_workers = Workers(
"put_response_worker", self._put_results_worker, max_batch_size=50)
self.put_response_workers.start(size=1)
self.get_response_workers = Workers(
"get_response_worker", self._get_results_worker, max_batch_size=1)
self.get_response_workers.start(size=5)
self.response_max_batch = 50
llm_logger.info(f"Scheduler: redis version is {self.client.version}")
def _request_queue_name(self):
return f"{self.topic}.request"
def _response_queue_name(self, id: str):
return f"{self.topic}.response.{id}"
def _unique_key_name(self, id: str):
return f"{self.topic}.unique.{id}"
@staticmethod
def calc_required_blocks(token_num, block_size):
"""calculate required blocks for given token number"""
return (token_num + block_size - 1) // block_size
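# Worked example (illustrative, not in the original commit): calc_required_blocks
# is plain ceiling division. With block_size=64, a 130-token prompt needs
# (130 + 64 - 1) // 64 = 3 blocks.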
def _put_requests_worker(self, tasks: List[Tuple[str, Request]]) -> List[Tuple[str, Optional[str]]]:
"""
add requests to shared cache
"""
requests: List[ScheduledRequest] = [
ScheduledRequest(request) for _, request in tasks]
# check the uniqueness of the request_id
valid_requests: List[ScheduledRequest] = list()
duplicated_ids: List[str] = list()
for request in requests:
unique_key = self._unique_key_name(request.id)
if self.client.set(unique_key, "", ex=self.ttl, nx=True):
valid_requests.append(request)
else:
duplicated_ids.append(request.id)
# add to request queue (no-op when every request in the batch was a duplicate)
if valid_requests:
    serialized_requests = [request.serialize()
                           for request in valid_requests]
    self.client.rpush(self._request_queue_name(), *serialized_requests)
    llm_logger.info(
        f"Scheduler has put some requests: {[request.id for request in valid_requests]}")
    main_process_metrics.num_requests_waiting.inc(len(valid_requests))
if len(duplicated_ids) > 0:
llm_logger.warning(
f"Scheduler has received some duplicated requests: {duplicated_ids}")
results = [(request.id, None) for request in valid_requests]
results += [(request_id, "duplicated request_id")
for request_id in duplicated_ids]
return results
def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
"""
add requests to scheduler
"""
tasks: List[Tuple[str, Request]] = [
(request.request_id, request) for request in requests]
self.put_request_workers.put_tasks(tasks)
return self.put_request_workers.get_results(10, 0.005)
def get_requests(self, available_blocks, block_size, reserved_output_blocks,
max_num_batched_tokens, batch=1) -> List[Request]:
"""
Pull requests from the shared cache, blocking briefly when the queue is empty.
"""
if available_blocks <= reserved_output_blocks or batch < 1:
llm_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}")
return []
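# Note (illustrative): the loop below splits `batch` into at most two LPOP
# chunks, e.g. batch=5 -> [3, 2]; if the first pop drains the queue, the
# remaining chunk is skipped via the break.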
batches = []
piece = (batch + 1) // 2
while batch > 0:
batch -= piece
if batch >= 0:
batches.append(piece)
else:
batches.append(piece + batch)
serialized_requests = []
for bs in batches:
bs_data = self.client.lpop(self._request_queue_name(), bs)
if bs_data is None:
break
serialized_requests += bs_data
if len(serialized_requests) == 0:
blocked_data = self.client.blpop(
self._request_queue_name(), self.wait_request_timeout)
if blocked_data is None:
return []
serialized_requests = blocked_data[1:]
required_total_blocks = 0
current_prefill_tokens = 0
remaining_request = []
requests: List[Request] = []
for serialized_request in serialized_requests:
if len(remaining_request) > 0:
remaining_request.append(serialized_request)
continue
request: ScheduledRequest = ScheduledRequest.unserialize(
serialized_request)
if (time.time() - request.scheduled_time) > self.ttl:
llm_logger.info(
f"Request has expired when getting a request from the scheduler: {[request.id]}")
continue
required_input_blocks = self.calc_required_blocks(
request.size, block_size)
current_prefill_tokens += request.size
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
remaining_request.append(serialized_request)
continue
requests.append(request.raw)
if len(remaining_request) > 0:
self.client.lpush(self._request_queue_name(), *remaining_request)
if len(requests) > 0:
llm_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
main_process_metrics.num_requests_running.inc(len(requests))
main_process_metrics.num_requests_waiting.dec(len(requests))
return requests
def _put_results_worker(self, tasks: List[Tuple[str, RequestOutput]]):
"""
push responses to the shared cache
"""
responses: List[ScheduledResponse] = [
ScheduledResponse(result) for _, result in tasks]
sorted_responses = sorted(
    responses, key=lambda response: (response.id, response.index))
finished_responses = [
response.id for response in responses if response.finished]
if len(finished_responses) > 0:
llm_logger.info(
f"Scheduler has received a finished response: {finished_responses}")
group = dict()
for response in sorted_responses:
serialized_response = response.serialize()
if response.id not in group:
group[response.id] = [serialized_response]
continue
group[response.id].append(serialized_response)
for response_id, serialized_group in group.items():
    ttl = self.client.ttl(self._unique_key_name(
        response_id)) - self.remote_write_time
    if ttl <= 0:
        llm_logger.warning(
            f"Scheduler has received an expired response: {[response_id]}")
continue
with self.client.pipeline() as pipe:
pipe.multi()
pipe.rpush(self._response_queue_name(response_id), *serialized_group)
pipe.expire(self._response_queue_name(response_id), ttl)
pipe.execute()
def put_results(self, results: List[RequestOutput]):
"""
add results to shared cache
"""
tasks: List[Tuple[str, RequestOutput]] = [
(result.request_id, result) for result in results]
self.put_response_workers.put_tasks(tasks)
def _get_results_worker(self, tasks: List[Tuple[str, str]]) -> List[Tuple[str, List[ScheduledResponse]]]:
"""
Pull results for a single request id from the shared cache, blocking briefly when none are queued.
"""
if len(tasks) != 1:
raise ValueError(
f"Tasks size of _get_results_worker must be 1. ({len(tasks)})")
task_id, request_id = tasks[0]
key = self._response_queue_name(request_id)
size = self.client.llen(key)
size = min(size, self.response_max_batch)
serialized_responses = None
if size > 0:
serialized_responses = self.client.lpop(key, size)
if serialized_responses is None or len(serialized_responses) == 0:
blocked_data = self.client.blpop(key, self.wait_response_timeout)
if blocked_data is None:
return []
serialized_responses = blocked_data[1:]
output = [(task_id, [])]
for serialized_response in serialized_responses:
response = ScheduledResponse.unserialize(serialized_response)
output[0][1].append(response)
return output
def get_results(self, request_ids: List[str]) -> Dict[str, RequestOutput]:
"""
Get any results currently available for the given request ids, blocking briefly when none are.
"""
tasks = [(request_id, request_id) for request_id in request_ids]
self.get_response_workers.put_tasks(tasks, deduplication=True)
batch_responses: List[Tuple[str, List[ScheduledResponse]]] = self.get_response_workers.get_results(
10, self.wait_response_timeout)
results = dict()
for _, responses in batch_responses:
for response in responses:
if response.id not in results:
results[response.id] = []
results[response.id].append(response)
if response.finished:
llm_logger.info(
f"Scheduler has pulled a finished response: {[response.id]}")
request_ids = list(results.keys())
for request_id in request_ids:
results[request_id] = sorted(
    results[request_id], key=lambda response: (response.id, response.index))
results[request_id] = [
result.raw for result in results[request_id]]
return results
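A sketch of how the GlobalScheduler API fits together (illustrative, not part of this commit); `engine_requests`, `outputs` and the block-accounting numbers are assumptions:

scheduler = GlobalScheduler(host="127.0.0.1", port=6379, db=0, password=None,
                            topic="default", ttl=900, remote_write_time=3,
                            wait_response_timeout=1.0)

# API side: enqueue requests; each (request_id, error) pair flags duplicates.
for request_id, error in scheduler.put_requests(engine_requests):
    if error is not None:
        llm_logger.warning(f"{request_id}: {error}")

# Engine side: pull only as much work as the available blocks can hold.
pulled = scheduler.get_requests(available_blocks=1024, block_size=64,
                                reserved_output_blocks=32,
                                max_num_batched_tokens=8192, batch=8)

# Engine side pushes outputs; API side polls for them by request id.
scheduler.put_results(outputs)
ready = scheduler.get_results([r.request_id for r in engine_requests])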
fastdeploy/scheduler/local_scheduler.py  +222
@@ -0,0 +1,222 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Dict, List, Optional, Tuple
import threading
import time
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
class LocalScheduler(object):
"""
LocalScheduler Class
"""
def __init__(self,
max_size: int,
ttl: int,
wait_response_timeout: float):
self.max_size = max_size
self.ttl = ttl
self.mutex = threading.Lock()
self.ids_read_cursor = 0
self.ids: List[str] = list()
self.requests: Dict[str, ScheduledRequest] = dict()
self.responses: Dict[str, List[ScheduledResponse]] = dict()
self.wait_request_timeout = 10
self.wait_response_timeout = wait_response_timeout
self.requests_not_empty = threading.Condition(self.mutex)
self.responses_not_empty = threading.Condition(self.mutex)
def _recycle(self, request_id: Optional[str] = None):
"""
recycle memory
"""
if request_id is not None:
self.requests.pop(request_id, None)
self.responses.pop(request_id, None)
self.ids.remove(request_id)
self.ids_read_cursor -= 1
return
if self.max_size <= 0:
return
if len(self.requests) <= self.max_size:
return
now = time.time()
expired_ids = []
for request_id in self.ids:
request = self.requests[request_id]
if (now - request.scheduled_time < self.ttl):
break
expired_ids.append(request.id)
for expired_id in expired_ids:
    self.requests.pop(expired_id, None)
    self.responses.pop(expired_id, None)
del self.ids[:len(expired_ids)]
if len(expired_ids) > 0:
if len(expired_ids) - 1 >= self.ids_read_cursor:
self.ids_read_cursor = 0
else:
self.ids_read_cursor -= len(expired_ids)
def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
""" submit requests to scheduler
Args:
requests: List[Request]
"""
with self.mutex:
self._recycle()
if self.max_size > 0 and len(self.requests) + len(requests) > self.max_size:
msg = f"Exceeding the max length of the local scheduler (max_size={self.max_size})"
return [(request.request_id, msg) for request in requests]
valid_ids = []
duplicated_ids = []
for request in requests:
if request.request_id in self.requests:
duplicated_ids.append(request.request_id)
else:
scheduled_request = ScheduledRequest(request)
self.requests[scheduled_request.id] = scheduled_request
valid_ids.append(scheduled_request.id)
self.ids += valid_ids
self.requests_not_empty.notify_all()
llm_logger.info(
f"Scheduler has put some requests: {valid_ids}")
main_process_metrics.num_requests_waiting.inc(len(valid_ids))
if len(duplicated_ids) > 0:
llm_logger.warning(
f"Scheduler has received some duplicated requests: {duplicated_ids}")
results = [(request_id, None) for request_id in valid_ids]
results += [(request_id, "duplicated request_id")
for request_id in duplicated_ids]
return results
def calc_required_blocks(self, token_num, block_size):
"""calculate required blocks for given token number"""
return (token_num + block_size - 1) // block_size
def get_requests(self, available_blocks, block_size,
reserved_output_blocks, max_num_batched_tokens, batch=1) -> List[Request]:
"""get requests from local cache
Args:
available_blocks: int
block_size: int
reserved_output_blocks: int
max_num_batched_tokens: int
batch: int
"""
if available_blocks <= reserved_output_blocks or batch < 1:
llm_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}")
return []
with self.requests_not_empty:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor:
self.ids_read_cursor + batch], self.wait_request_timeout)
required_total_blocks = 0
current_prefill_tokens = 0
requests: List[Request] = []
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(
request.size, block_size)
current_prefill_tokens += request.size
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
break
requests.append(request.raw)
self.ids_read_cursor += len(requests)
if len(requests) > 0:
llm_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
main_process_metrics.num_requests_waiting.dec(len(requests))
main_process_metrics.num_requests_running.inc(len(requests))
return requests
def put_results(self, results: List[RequestOutput]):
"""put results into local cache"""
responses: List[ScheduledResponse] = [
ScheduledResponse(result) for result in results]
finished_responses = [
response.id for response in responses if response.finished]
if len(finished_responses) > 0:
llm_logger.info(
f"Scheduler has received a finished response: {finished_responses}")
with self.mutex:
for response in responses:
if response.id not in self.requests:
llm_logger.warning(
f"Scheduler has received a expired response: {[response.id]}")
continue
if response.id not in self.responses:
self.responses[response.id] = [response]
continue
self.responses[response.id].append(response)
self.responses_not_empty.notify_all()
def get_results(self, request_ids: List[str]) -> Dict[str, List[RequestOutput]]:
"""get results from local cache"""
def _get_results():
responses = dict()
for request_id in request_ids:
if request_id not in responses:
responses[request_id] = []
responses[request_id] += self.responses.pop(request_id, [])
return responses
with self.responses_not_empty:
responses = self.responses_not_empty.wait_for(
_get_results, self.wait_response_timeout)
results = dict()
for request_id, resps in responses.items():
finished = False
results[request_id] = []
for resp in resps:
results[request_id].append(resp.raw)
finished |= resp.finished
if finished:
self._recycle(request_id)
llm_logger.info(
f"Scheduler has pulled a finished response: {[request_id]}")
return results
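LocalScheduler exposes the same put/get interface in-process, coordinated by one mutex and two condition variables; a minimal threading sketch (illustrative; `run_model` is a hypothetical stand-in for the engine):

import threading

scheduler = LocalScheduler(max_size=1000, ttl=900, wait_response_timeout=1.0)

def engine_loop():
    while True:
        # Blocks up to wait_request_timeout (10s) on requests_not_empty.
        for request in scheduler.get_requests(available_blocks=1024, block_size=64,
                                              reserved_output_blocks=32,
                                              max_num_batched_tokens=8192, batch=8):
            scheduler.put_results([run_model(request)])  # wakes responses_not_empty

threading.Thread(target=engine_loop, daemon=True).start()
# Caller side: blocks up to wait_response_timeout, then returns whatever arrived.
outputs = scheduler.get_results(["req-0", "req-1"])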
fastdeploy/scheduler/storage.py  +96
@@ -0,0 +1,96 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional, List
from redis.typing import Number
import redis
from packaging import version
import re
LUA_SCRIPT_LPOP = """
local key = KEYS[1]
local count = tonumber(ARGV[1])
local elements = redis.call('LRANGE', key, 0, count - 1)
local elementsCount = #elements
if elementsCount > 0 then
redis.call('LTRIM', key, count, -1)
end
return elements
"""
class AdaptedRedis(redis.Redis):
"""
AdaptedRedis class: Adapt to different versions of Redis
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._old_version = False
self._parse_version()
self._warm_up()
def _parse_version(self):
"""
parse version
"""
server_info = self.info(section='server')
version_string = server_info['redis_version']
match = re.search(r'^(\d+\.\d+\.\d+)', version_string)
if match:
redis_version = match.group(1)
else:
redis_version = "0.0.0"
current_version = version.parse(redis_version)
target_version = version.parse("6.2.28")
if current_version <= target_version:
self._old_version = True
self.version = redis_version
def _warm_up(self):
"""
preload some lua scripts
"""
if self._old_version:
self._lpop = self.register_script(LUA_SCRIPT_LPOP)
def lpop(self, name: str, count: Optional[int] = None):
"""
similar to redis lpop
"""
if self._old_version and count is not None:
return self._lpop(keys=[name], args=[count])
return super().lpop(name, count)
def blpop(self, keys: List, timeout: Optional[Number] = 0):
"""
similar to redis blpop
"""
if self._old_version:
if timeout > 0 and timeout < 1:
timeout = 1
timeout = int(timeout)
return super().blpop(keys=keys, timeout=timeout)
if timeout > 0 and timeout < 0.01:
timeout = 0.01
return super().blpop(keys=keys, timeout=timeout)
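A small sketch of the adapter in use (illustrative; assumes a reachable Redis server at the default address):

from redis import ConnectionPool

pool = ConnectionPool(host="127.0.0.1", port=6379, db=0)
client = AdaptedRedis(connection_pool=pool)
print(client.version)                        # parsed server version string

client.rpush("demo.queue", b"a", b"b", b"c")
items = client.lpop("demo.queue", 2)         # falls back to the Lua script on old servers
head = client.blpop(["demo.queue"], 0.5)     # sub-second timeout coerced to 1s on old servers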
fastdeploy/scheduler/workers.py  +166
@@ -0,0 +1,166 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Callable, List, Tuple, Any, Dict, Optional
import functools
import threading
from fastdeploy.utils import llm_logger
class Workers:
"""
Workers class
"""
def __init__(self,
name: str,
work: Callable[[List[Tuple[str, Any]]], Optional[List[Tuple[str, Any]]]],
max_batch_size: int = 1):
self.name = name
self.work = work
self.max_batch_size = max_batch_size
self.mutex = threading.Lock()
self.pool = []
self.tasks_not_empty = threading.Condition(self.mutex)
self.results_not_empty = threading.Condition(self.mutex)
self.tasks: List[Tuple[str, Any]] = []
self.results: List[Tuple[str, Any]] = []
self.running_tasks: Dict[int, List[Tuple[str, Any]]] = dict()
self.not_stop = threading.Condition(self.mutex)
self.stop = False
self.stopped = 0
def _stop(self, func: Callable):
"""
a stop decorator
"""
@functools.wraps(func)
def wrapper():
if self.stop:
return True
return func()
return wrapper
def _worker(self, number: int):
"""
worker thread
"""
with self.mutex:
self.running_tasks[number] = []
@self._stop
def _get_tasks():
self.running_tasks[number] = []
batch = min((len(self.tasks) + len(self.pool) - 1) //
len(self.pool), self.max_batch_size)
tasks = self.tasks[:batch]
del self.tasks[:batch]
self.running_tasks[number] = tasks
return tasks
while True:
with self.tasks_not_empty:
tasks = self.tasks_not_empty.wait_for(_get_tasks)
if self.stop:
self.stopped += 1
if self.stopped == len(self.pool):
self.not_stop.notify_all()
return
results = []
try:
results = self.work(tasks)
except Exception as e:
llm_logger.info(f"Worker {self.name} execute error: {e}")
if results is not None and len(results) > 0:
with self.mutex:
self.results += results
self.results_not_empty.notify_all()
def start(self, size: int):
"""
start the worker thread pool
"""
with self.mutex:
remain = size - len(self.pool)
if remain <= 0:
return
for i in range(remain):
t = threading.Thread(target=self._worker, args=(i,))
t.daemon = True
t.start()
self.pool.append(t)
def terminate(self):
"""
terminate the worker thread pool
"""
with self.mutex:
self.stop = True
self.tasks_not_empty.notify_all()
self.results_not_empty.notify_all()
self.not_stop.wait_for(lambda: self.stopped == len(self.pool))
self.pool = []
self.tasks = []
self.results = []
self.running_tasks = dict()
self.stop = False
self.stopped = 0
def get_results(self, max_size: int, timeout: float) -> List[Tuple[str, Any]]:
"""
get results from thread pool.
"""
@self._stop
def _get_results():
results = self.results[:max_size]
del self.results[:max_size]
return results
with self.results_not_empty:
results = self.results_not_empty.wait_for(_get_results, timeout)
if self.stop:
return []
return results
def put_tasks(self, tasks: List[Tuple[str, Any]], deduplication: bool = False):
"""
put tasks into thread pool.
"""
if len(tasks) == 0:
return
with self.mutex:
if not deduplication:
self.tasks += tasks
else:
task_set = set([t[0] for t in self.tasks])
for _, running in self.running_tasks.items():
task_set.update([t[0] for t in running])
for task in tasks:
if task[0] in task_set:
continue
self.tasks.append(task)
self.tasks_not_empty.notify_all()
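A minimal sketch of the Workers pool (illustrative, not part of this commit); `echo_work` is a hypothetical work function:

def echo_work(tasks):
    # Each call receives up to max_batch_size (task_id, payload) tuples.
    return [(task_id, str(payload).upper()) for task_id, payload in tasks]

workers = Workers("echo", echo_work, max_batch_size=4)
workers.start(size=2)                                  # two daemon worker threads
workers.put_tasks([("t1", "hello"), ("t2", "world")], deduplication=True)
print(workers.get_results(max_size=10, timeout=1.0))   # e.g. [('t1', 'HELLO'), ('t2', 'WORLD')]
workers.terminate()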