mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 00:17:25 +08:00)
[LLM] First commit the llm deployment code
@@ -0,0 +1,17 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from .config import SchedulerConfig
@@ -0,0 +1,164 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import redis

from fastdeploy.utils import llm_logger
from .global_scheduler import GlobalScheduler
from .local_scheduler import LocalScheduler


class LocalSchedulerConfig:
    """
    Configuration for the in-process LocalScheduler.
    """

    def __init__(self,
                 max_size: int = -1,
                 ttl: int = 900,
                 wait_response_timeout: float = 1,
                 **kwargs):
        self.max_size = max_size
        self.ttl = ttl
        self.wait_response_timeout = wait_response_timeout

    def check(self):
        """
        validate config values
        """
        assert self.wait_response_timeout > 0, \
            "LocalScheduler: `wait_response_timeout` must be greater than zero"
        assert self.ttl > self.wait_response_timeout, \
            "LocalScheduler: `ttl` must be greater than `wait_response_timeout`"

    def print(self):
        """
        print config
        """
        llm_logger.info("LocalScheduler Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info(
            "=============================================================")


class GlobalSchedulerConfig:
    """
    Configuration for the Redis-backed GlobalScheduler.
    """

    def __init__(self,
                 host: str = "127.0.0.1",
                 port: int = 6379,
                 db: int = 0,
                 password=None,
                 topic: str = "default",
                 ttl: int = 900,
                 wait_response_timeout: float = 1,
                 remote_write_time: int = 3,
                 **kwargs):
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.topic = topic
        self.ttl = ttl
        self.wait_response_timeout = wait_response_timeout
        self.remote_write_time = remote_write_time

    def check(self):
        """
        validate config values and verify the Redis connection
        """
        assert self.wait_response_timeout > 0, \
            "GlobalScheduler: `wait_response_timeout` must be greater than zero"
        assert self.remote_write_time > 0, \
            "GlobalScheduler: `remote_write_time` must be greater than zero"
        assert self.ttl > self.remote_write_time, \
            "GlobalScheduler: `ttl` must be greater than `remote_write_time`"
        assert self.ttl > self.wait_response_timeout, \
            "GlobalScheduler: `ttl` must be greater than `wait_response_timeout`"

        r = redis.Redis(self.host, self.port, self.db, self.password)
        try:
            if not r.ping():
                raise Exception("failed to connect to redis")
        finally:
            r.close()

    def print(self):
        """
        print config
        """
        llm_logger.info("GlobalScheduler Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info(
            "=============================================================")


class SchedulerConfig:
    """
    Factory that selects and holds a LocalSchedulerConfig or a
    GlobalSchedulerConfig, depending on `name`.
    """

    def __init__(self, name="local", **kwargs):
        self.name = name
        self.config = None

        if name == "local":
            self.config = LocalSchedulerConfig(**kwargs)

        if name == "global":
            self.config = GlobalSchedulerConfig(**kwargs)

    def check(self):
        """
        validate config
        """
        if self.name not in ["local", "global"]:
            raise Exception(
                "SchedulerConfig: `name` must be `local` or `global`")

        self.config.check()

    def print(self):
        """
        print config
        """
        self.config.print()

    def scheduler(self):
        """
        create the scheduler described by this config
        """
        if self.name == "global":
            return GlobalScheduler(host=self.config.host,
                                   port=self.config.port,
                                   db=self.config.db,
                                   password=self.config.password,
                                   topic=self.config.topic,
                                   ttl=self.config.ttl,
                                   remote_write_time=self.config.remote_write_time,
                                   wait_response_timeout=self.config.wait_response_timeout)

        return LocalScheduler(max_size=self.config.max_size,
                              ttl=self.config.ttl,
                              wait_response_timeout=self.config.wait_response_timeout)
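For context, a minimal usage sketch (not part of the commit): constructing a SchedulerConfig and creating the scheduler it describes. The keyword arguments mirror the constructor defaults above; the Redis address is illustrative.

# Local, in-process scheduler (the default): bounded queue, 15-minute TTL.
config = SchedulerConfig(name="local", max_size=1024, ttl=900,
                         wait_response_timeout=1)
config.check()                   # validates the values above
scheduler = config.scheduler()   # returns a LocalScheduler

# Redis-backed scheduler shared by several engine instances.
# host/port here are illustrative; check() will ping the server.
global_config = SchedulerConfig(name="global", host="127.0.0.1",
                                port=6379, topic="default")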
@@ -0,0 +1,75 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import time
import json

from fastdeploy.engine.request import Request, RequestOutput


class ScheduledRequest(object):
    """
    A Request wrapper that records when the request entered the scheduler
    and how many prompt tokens it carries.
    """

    def __init__(self, request: Request):
        self.raw: Request = request
        self.id = request.request_id
        self.scheduled_time = time.time()
        self.size = len(request.prompt_token_ids)

    def serialize(self) -> bytes:
        """serialize to bytes"""
        data = {
            "scheduled_time": self.scheduled_time,
            "raw": self.raw.to_dict()
        }
        serialized_data = json.dumps(data, ensure_ascii=False)
        return serialized_data.encode()

    @classmethod
    def unserialize(cls, serialized_data: bytes) -> 'ScheduledRequest':
        """deserialize bytes back into a ScheduledRequest"""
        data = json.loads(serialized_data)
        request = Request.from_dict(data["raw"])
        scheduled_request = cls(request)
        scheduled_request.scheduled_time = data["scheduled_time"]
        return scheduled_request


class ScheduledResponse(object):
    """
    A RequestOutput wrapper that exposes the fields the schedulers need
    for ordering and bookkeeping.
    """

    def __init__(self, response: RequestOutput):
        self.raw: RequestOutput = response
        self.id = response.request_id
        self.index = response.outputs.index
        self.finished = response.finished

    def serialize(self) -> bytes:
        """serialize to bytes"""
        data = self.raw.to_dict()
        serialized_data = json.dumps(data, ensure_ascii=False)
        return serialized_data.encode()

    @classmethod
    def unserialize(cls, serialized_data: bytes) -> 'ScheduledResponse':
        """deserialize bytes back into a ScheduledResponse"""
        data = json.loads(serialized_data)
        request_output = RequestOutput.from_dict(data)
        return cls(request_output)
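A quick sanity sketch of the wire format (illustrative only; it assumes Request.to_dict/from_dict round-trip faithfully, which the code above relies on, and `some_request` stands for any engine Request):

# Hypothetical round trip through the JSON wire format defined above.
payload = ScheduledRequest(some_request).serialize()      # UTF-8 JSON bytes
restored = ScheduledRequest.unserialize(payload)
assert restored.id == some_request.request_id
assert restored.scheduled_time <= time.time()   # original enqueue time survives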
@@ -0,0 +1,297 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import List, Optional, Dict, Tuple
import time

from redis import ConnectionPool

from fastdeploy.scheduler.storage import AdaptedRedis
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
from fastdeploy.scheduler.workers import Workers
from fastdeploy.utils import llm_logger


class GlobalScheduler(object):
    """
    A Redis-backed scheduler that lets multiple engine instances share one
    request queue and exchange responses through per-request queues.
    """

    def __init__(self,
                 host: str,
                 port: int,
                 db: int,
                 password: Optional[str],
                 topic: str,
                 ttl: int,
                 remote_write_time: int,
                 wait_response_timeout: float):
        self.topic = topic
        self.ttl = ttl
        self.remote_write_time = remote_write_time
        self.wait_response_timeout = 1.0 if wait_response_timeout < 1.0 else wait_response_timeout
        self.wait_request_timeout = 10

        connection_pool = ConnectionPool(
            host=host, port=port, db=db, password=password, max_connections=10)
        self.client = AdaptedRedis(connection_pool=connection_pool)

        self.put_request_workers = Workers(
            "put_request_worker", self._put_requests_worker, max_batch_size=5)
        self.put_request_workers.start(size=1)

        self.put_response_workers = Workers(
            "put_response_worker", self._put_results_worker, max_batch_size=50)
        self.put_response_workers.start(size=1)

        self.get_response_workers = Workers(
            "get_response_worker", self._get_results_worker, max_batch_size=1)
        self.get_response_workers.start(size=5)
        self.response_max_batch = 50

        llm_logger.info(f"Scheduler: redis version is {self.client.version}")

    def _request_queue_name(self):
        return f"{self.topic}.request"

    def _response_queue_name(self, id: str):
        return f"{self.topic}.response.{id}"

    def _unique_key_name(self, id: str):
        return f"{self.topic}.unique.{id}"

    @staticmethod
    def calc_required_blocks(token_num, block_size):
        """calculate the number of blocks a token count needs (ceiling division)"""
        return (token_num + block_size - 1) // block_size

    def _put_requests_worker(self, tasks: List[Tuple[str, Request]]) -> List[Tuple[str, Optional[str]]]:
        """
        add requests to the shared cache
        """
        requests: List[ScheduledRequest] = [
            ScheduledRequest(request) for _, request in tasks]

        # check the uniqueness of the request_id: SET NX on a marker key
        # succeeds only for the first writer of each id
        valid_requests: List[ScheduledRequest] = list()
        duplicated_ids: List[str] = list()
        for request in requests:
            unique_key = self._unique_key_name(request.id)
            if self.client.set(unique_key, "", ex=self.ttl, nx=True):
                valid_requests.append(request)
            else:
                duplicated_ids.append(request.id)

        # add to the request queue (RPUSH requires at least one value)
        if len(valid_requests) > 0:
            serialized_requests = [request.serialize()
                                   for request in valid_requests]
            self.client.rpush(self._request_queue_name(), *serialized_requests)
            llm_logger.info(
                f"Scheduler has put some requests: {[request.id for request in valid_requests]}")
            main_process_metrics.num_requests_waiting.inc(len(valid_requests))

        if len(duplicated_ids) > 0:
            llm_logger.warning(
                f"Scheduler has received some duplicated requests: {duplicated_ids}")

        results = [(request.id, None) for request in valid_requests]
        results += [(request_id, "duplicated request_id")
                    for request_id in duplicated_ids]
        return results

    def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
        """
        add requests to the scheduler
        """
        tasks: List[Tuple[str, Request]] = [
            (request.request_id, request) for request in requests]
        self.put_request_workers.put_tasks(tasks)
        return self.put_request_workers.get_results(10, 0.005)

    def get_requests(self, available_blocks, block_size, reserved_output_blocks,
                     max_num_batched_tokens, batch=1) -> List[Request]:
        """
        pull as many requests from the shared cache as the given resources
        allow, blocking briefly if the queue is empty
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
            llm_logger.debug(
                f"Scheduler's resources are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}")
            return []

        # split the requested batch into two non-blocking LPOP calls
        # (ceil half, then the rest) so a partially filled queue still
        # returns promptly
        batches = []
        piece = (batch + 1) // 2
        while batch > 0:
            batch -= piece
            if batch >= 0:
                batches.append(piece)
            else:
                batches.append(piece + batch)

        serialized_requests = []
        for bs in batches:
            bs_data = self.client.lpop(self._request_queue_name(), bs)
            if bs_data is None:
                break
            serialized_requests += bs_data

        if len(serialized_requests) == 0:
            # nothing queued: wait up to wait_request_timeout for one request
            blocked_data = self.client.blpop(
                self._request_queue_name(), self.wait_request_timeout)
            if blocked_data is None:
                return []
            serialized_requests = blocked_data[1:]  # blpop returns (key, value)

        required_total_blocks = 0
        current_prefill_tokens = 0
        remaining_requests = []
        requests: List[Request] = []
        for serialized_request in serialized_requests:
            if len(remaining_requests) > 0:
                remaining_requests.append(serialized_request)
                continue

            request: ScheduledRequest = ScheduledRequest.unserialize(
                serialized_request)
            if (time.time() - request.scheduled_time) > self.ttl:
                llm_logger.info(
                    f"Request has expired when getting a request from the scheduler: {[request.id]}")
                continue

            required_input_blocks = self.calc_required_blocks(
                request.size, block_size)
            current_prefill_tokens += request.size
            required_total_blocks += required_input_blocks + reserved_output_blocks
            if required_total_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
                remaining_requests.append(serialized_request)
                continue
            requests.append(request.raw)

        # push whatever did not fit back to the head of the queue
        if len(remaining_requests) > 0:
            self.client.lpush(self._request_queue_name(), *remaining_requests)

        if len(requests) > 0:
            llm_logger.info(
                f"Scheduler has pulled some requests: {[request.request_id for request in requests]}")
            main_process_metrics.num_requests_running.inc(len(requests))
            main_process_metrics.num_requests_waiting.dec(len(requests))
        return requests

    def _put_results_worker(self, tasks: List[Tuple[str, RequestOutput]]):
        """
        write responses back to the shared cache
        """
        responses: List[ScheduledResponse] = [
            ScheduledResponse(result) for _, result in tasks]
        sorted_responses = sorted(
            responses, key=lambda response: (response.id, response.index))

        finished_responses = [
            response.id for response in responses if response.finished]
        if len(finished_responses) > 0:
            llm_logger.info(
                f"Scheduler has received a finished response: {finished_responses}")

        group = dict()
        for response in sorted_responses:
            serialized_response = response.serialize()
            if response.id not in group:
                group[response.id] = [serialized_response]
                continue
            group[response.id].append(serialized_response)

        for response_id, serialized_responses in group.items():
            ttl = self.client.ttl(self._unique_key_name(
                response_id)) - self.remote_write_time
            if ttl <= 0:
                llm_logger.warning(
                    f"Scheduler has received an expired response: {[response_id]}")
                continue

            with self.client.pipeline() as pipe:
                pipe.multi()
                pipe.rpush(self._response_queue_name(response_id),
                           *serialized_responses)
                pipe.expire(self._response_queue_name(response_id), ttl)
                pipe.execute()

    def put_results(self, results: List[RequestOutput]):
        """
        add results to the shared cache
        """
        tasks: List[Tuple[str, RequestOutput]] = [
            (result.request_id, result) for result in results]
        self.put_response_workers.put_tasks(tasks)

    def _get_results_worker(self, tasks: List[Tuple[str, str]]) -> List[Tuple[str, List[ScheduledResponse]]]:
        """
        read responses for one request from the shared cache, blocking
        briefly if none are queued
        """
        if len(tasks) != 1:
            raise ValueError(
                f"Tasks size of _get_results_worker must be 1. ({len(tasks)})")

        task_id, request_id = tasks[0]
        key = self._response_queue_name(request_id)
        size = self.client.llen(key)
        size = min(size, self.response_max_batch)

        serialized_responses = None
        if size > 0:
            serialized_responses = self.client.lpop(key, size)

        if serialized_responses is None or len(serialized_responses) == 0:
            blocked_data = self.client.blpop(key, self.wait_response_timeout)
            if blocked_data is None:
                return []
            serialized_responses = blocked_data[1:]  # blpop returns (key, value)

        output = [(task_id, [])]
        for serialized_response in serialized_responses:
            response = ScheduledResponse.unserialize(serialized_response)
            output[0][1].append(response)
        return output

    def get_results(self, request_ids: List[str]) -> Dict[str, List[RequestOutput]]:
        """
        collect the currently available results for the given request ids,
        blocking briefly.
        """
        tasks = [(request_id, request_id) for request_id in request_ids]
        self.get_response_workers.put_tasks(tasks, deduplication=True)
        batch_responses: List[Tuple[str, List[ScheduledResponse]]] = self.get_response_workers.get_results(
            10, self.wait_response_timeout)

        results = dict()
        for _, responses in batch_responses:
            for response in responses:
                if response.id not in results:
                    results[response.id] = []
                results[response.id].append(response)
                if response.finished:
                    llm_logger.info(
                        f"Scheduler has pulled a finished response: {[response.id]}")

        for request_id in list(results.keys()):
            results[request_id] = sorted(
                results[request_id], key=lambda response: (response.id, response.index))
            results[request_id] = [
                result.raw for result in results[request_id]]
        return results
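For orientation, a hedged sketch of the intended flow (names mirror the class above; the Redis instance, `requests`, and `outputs` are assumed to exist):

# Illustrative only: one process pushes work, another pulls and answers.
scheduler = GlobalScheduler(host="127.0.0.1", port=6379, db=0, password=None,
                            topic="default", ttl=900, remote_write_time=3,
                            wait_response_timeout=1.0)

# API side: enqueue requests; each entry is (request_id, error-or-None).
status = scheduler.put_requests(requests)

# Worker side: pull what fits into the KV cache, run the model, write back.
batch = scheduler.get_requests(available_blocks=1024, block_size=64,
                               reserved_output_blocks=32,
                               max_num_batched_tokens=8192, batch=8)
# ... run inference on `batch`, producing RequestOutput objects ...
# scheduler.put_results(outputs)

# API side again: poll for whatever responses have arrived so far.
# ready = scheduler.get_results([r.request_id for r in requests])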
@@ -0,0 +1,222 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Dict, List, Optional, Tuple
import threading
import time

from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse


class LocalScheduler(object):
    """
    An in-process scheduler that keeps requests and responses in local
    dictionaries, guarded by a single mutex.
    """

    def __init__(self,
                 max_size: int,
                 ttl: int,
                 wait_response_timeout: float):
        self.max_size = max_size
        self.ttl = ttl
        self.mutex = threading.Lock()
        self.ids_read_cursor = 0
        self.ids: List[str] = list()

        self.requests: Dict[str, ScheduledRequest] = dict()
        self.responses: Dict[str, List[ScheduledResponse]] = dict()

        self.wait_request_timeout = 10
        self.wait_response_timeout = wait_response_timeout

        self.requests_not_empty = threading.Condition(self.mutex)
        self.responses_not_empty = threading.Condition(self.mutex)

    def _recycle(self, request_id: Optional[str] = None):
        """
        release the memory held by finished or expired requests
        """
        if request_id is not None:
            self.requests.pop(request_id, None)
            self.responses.pop(request_id, None)
            self.ids.remove(request_id)
            self.ids_read_cursor -= 1
            return

        if self.max_size <= 0:
            return

        if len(self.requests) <= self.max_size:
            return

        now = time.time()
        expired_ids = []
        for request_id in self.ids:
            request = self.requests[request_id]
            if now - request.scheduled_time < self.ttl:
                break
            expired_ids.append(request.id)

        for expired_id in expired_ids:
            self.requests.pop(expired_id, None)
            self.responses.pop(expired_id, None)
        # expired ids always form a prefix of self.ids
        del self.ids[:len(expired_ids)]

        if len(expired_ids) > 0:
            if len(expired_ids) - 1 >= self.ids_read_cursor:
                self.ids_read_cursor = 0
            else:
                self.ids_read_cursor -= len(expired_ids)

    def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
        """submit requests to the scheduler

        Args:
            requests: List[Request]
        """
        with self.mutex:
            self._recycle()
            if self.max_size > 0 and len(self.requests) + len(requests) > self.max_size:
                msg = f"Exceeding the max length of the local scheduler (max_size={self.max_size})"
                return [(request.request_id, msg) for request in requests]

            valid_ids = []
            duplicated_ids = []
            for request in requests:
                if request.request_id in self.requests:
                    duplicated_ids.append(request.request_id)
                else:
                    scheduled_request = ScheduledRequest(request)
                    self.requests[scheduled_request.id] = scheduled_request
                    valid_ids.append(scheduled_request.id)

            self.ids += valid_ids
            self.requests_not_empty.notify_all()

            llm_logger.info(
                f"Scheduler has put some requests: {valid_ids}")
            main_process_metrics.num_requests_waiting.inc(len(valid_ids))

            if len(duplicated_ids) > 0:
                llm_logger.warning(
                    f"Scheduler has received some duplicated requests: {duplicated_ids}")

            results = [(request_id, None) for request_id in valid_ids]
            results += [(request_id, "duplicated request_id")
                        for request_id in duplicated_ids]
            return results

    def calc_required_blocks(self, token_num, block_size):
        """calculate the number of blocks a token count needs (ceiling division)"""
        return (token_num + block_size - 1) // block_size

    def get_requests(self, available_blocks, block_size,
                     reserved_output_blocks, max_num_batched_tokens, batch=1) -> List[Request]:
        """get as many requests from the local cache as the given resources allow

        Args:
            available_blocks: int
            block_size: int
            reserved_output_blocks: int
            max_num_batched_tokens: int
            batch: int
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
            llm_logger.debug(
                f"Scheduler's resources are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}")
            return []

        with self.requests_not_empty:
            # wait until unread ids exist; on timeout the slice is empty
            batch_ids = self.requests_not_empty.wait_for(
                lambda: self.ids[self.ids_read_cursor:
                                 self.ids_read_cursor + batch], self.wait_request_timeout)

            required_total_blocks = 0
            current_prefill_tokens = 0
            requests: List[Request] = []
            for request_id in batch_ids:
                request = self.requests[request_id]
                required_input_blocks = self.calc_required_blocks(
                    request.size, block_size)
                current_prefill_tokens += request.size
                required_total_blocks += required_input_blocks + reserved_output_blocks
                if required_total_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
                    break
                requests.append(request.raw)
            self.ids_read_cursor += len(requests)

        if len(requests) > 0:
            llm_logger.info(
                f"Scheduler has pulled some requests: {[request.request_id for request in requests]}")
            main_process_metrics.num_requests_waiting.dec(len(requests))
            main_process_metrics.num_requests_running.inc(len(requests))
        return requests

    def put_results(self, results: List[RequestOutput]):
        """put results into the local cache"""
        responses: List[ScheduledResponse] = [
            ScheduledResponse(result) for result in results]

        finished_responses = [
            response.id for response in responses if response.finished]
        if len(finished_responses) > 0:
            llm_logger.info(
                f"Scheduler has received a finished response: {finished_responses}")

        with self.mutex:
            for response in responses:
                if response.id not in self.requests:
                    llm_logger.warning(
                        f"Scheduler has received an expired response: {[response.id]}")
                    continue

                if response.id not in self.responses:
                    self.responses[response.id] = [response]
                    continue
                self.responses[response.id].append(response)
            self.responses_not_empty.notify_all()

    def get_results(self, request_ids: List[str]) -> Dict[str, List[RequestOutput]]:
        """get results from the local cache"""
        def _get_results():
            responses = dict()
            for request_id in request_ids:
                if request_id not in responses:
                    responses[request_id] = []
                responses[request_id] += self.responses.pop(request_id, [])
            return responses

        with self.responses_not_empty:
            responses = self.responses_not_empty.wait_for(
                _get_results, self.wait_response_timeout)

            results = dict()
            for request_id, resps in responses.items():
                finished = False
                results[request_id] = []
                for resp in resps:
                    results[request_id].append(resp.raw)
                    finished |= resp.finished

                if finished:
                    self._recycle(request_id)
                    llm_logger.info(
                        f"Scheduler has pulled a finished response: {[request_id]}")
        return results
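A hedged usage sketch of the local path (the Request/RequestOutput objects are assumed to come from the engine):

# Illustrative only: single-process put/get cycle.
scheduler = LocalScheduler(max_size=512, ttl=900, wait_response_timeout=1.0)
scheduler.put_requests(requests)   # [(request_id, None), ...] on success

batch = scheduler.get_requests(available_blocks=256, block_size=64,
                               reserved_output_blocks=16,
                               max_num_batched_tokens=4096, batch=4)
# ... run the model on `batch`, then feed the outputs back:
# scheduler.put_results(outputs)
# ready = scheduler.get_results([r.request_id for r in requests])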
@@ -0,0 +1,96 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Optional, List
import re

import redis
from redis.typing import Number
from packaging import version


# emulate `LPOP key count` (added in Redis 6.2) on older servers:
# read the first `count` elements, then trim them off
LUA_SCRIPT_LPOP = """
local key = KEYS[1]
local count = tonumber(ARGV[1])
local elements = redis.call('LRANGE', key, 0, count - 1)
local elementsCount = #elements
if elementsCount > 0 then
    redis.call('LTRIM', key, count, -1)
end
return elements
"""


class AdaptedRedis(redis.Redis):
    """
    AdaptedRedis class: adapts commands to the version of the Redis server
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self._old_version = False
        self._parse_version()
        self._warm_up()

    def _parse_version(self):
        """
        read and parse the server version
        """
        server_info = self.info(section='server')
        version_string = server_info['redis_version']

        match = re.search(r'^(\d+\.\d+\.\d+)', version_string)
        if match:
            redis_version = match.group(1)
        else:
            redis_version = "0.0.0"

        current_version = version.parse(redis_version)
        target_version = version.parse("6.2.28")

        if current_version <= target_version:
            self._old_version = True

        self.version = redis_version

    def _warm_up(self):
        """
        preload the lua scripts needed for older servers
        """
        if self._old_version:
            self._lpop = self.register_script(LUA_SCRIPT_LPOP)

    def lpop(self, name: str, count: Optional[int] = None):
        """
        like redis LPOP, with a lua fallback for servers that do not
        support the `count` argument
        """
        if self._old_version and count is not None:
            return self._lpop(keys=[name], args=[count])
        return super().lpop(name, count)

    def blpop(self, keys: List, timeout: Optional[Number] = 0):
        """
        like redis BLPOP; older servers only accept integer timeouts,
        so sub-second timeouts are rounded up
        """
        if self._old_version:
            if 0 < timeout < 1:
                timeout = 1
            timeout = int(timeout)
            return super().blpop(keys=keys, timeout=timeout)

        if 0 < timeout < 0.01:
            timeout = 0.01
        return super().blpop(keys=keys, timeout=timeout)
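A quick sketch of why the fallback matters (illustrative; assumes a reachable local Redis and default bytes responses):

# On Redis servers older than 6.2, a plain r.lpop(key, 3) raises a
# "wrong number of arguments" error; AdaptedRedis routes the call through
# the Lua script instead, so callers never branch on the server version.
r = AdaptedRedis(host="127.0.0.1", port=6379)
r.rpush("demo.queue", b"a", b"b", b"c", b"d")
print(r.lpop("demo.queue", 3))                # [b'a', b'b', b'c'] either way
print(r.blpop(["demo.queue"], timeout=0.5))   # (b'demo.queue', b'd'), or None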
@@ -0,0 +1,166 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Callable, List, Tuple, Any, Dict, Optional
import functools
import threading

from fastdeploy.utils import llm_logger


class Workers:
    """
    A small thread pool that feeds batches of (id, payload) tasks to a
    work function and collects the function's results.
    """

    def __init__(self,
                 name: str,
                 work: Callable[[List[Tuple[str, Any]]], Optional[List[Tuple[str, Any]]]],
                 max_batch_size: int = 1):
        self.name = name
        self.work = work
        self.max_batch_size = max_batch_size

        self.mutex = threading.Lock()
        self.pool = []

        self.tasks_not_empty = threading.Condition(self.mutex)
        self.results_not_empty = threading.Condition(self.mutex)

        self.tasks: List[Tuple[str, Any]] = []
        self.results: List[Tuple[str, Any]] = []
        self.running_tasks: Dict[int, List[Tuple[str, Any]]] = dict()

        self.not_stop = threading.Condition(self.mutex)
        self.stop = False
        self.stopped = 0

    def _stop(self, func: Callable):
        """
        decorator that makes a wait predicate fire immediately once the
        pool is stopping
        """
        @functools.wraps(func)
        def wrapper():
            if self.stop:
                return True
            return func()
        return wrapper

    def _worker(self, number: int):
        """
        worker thread loop
        """
        with self.mutex:
            self.running_tasks[number] = []

        @self._stop
        def _get_tasks():
            self.running_tasks[number] = []
            # take an even share of the pending tasks (ceiling division
            # across the pool), capped at max_batch_size
            batch = min((len(self.tasks) + len(self.pool) - 1) //
                        len(self.pool), self.max_batch_size)
            tasks = self.tasks[:batch]
            del self.tasks[:batch]
            self.running_tasks[number] = tasks
            return tasks

        while True:
            with self.tasks_not_empty:
                tasks = self.tasks_not_empty.wait_for(_get_tasks)
                if self.stop:
                    self.stopped += 1
                    if self.stopped == len(self.pool):
                        self.not_stop.notify_all()
                    return

            results = []
            try:
                results = self.work(tasks)
            except Exception as e:
                llm_logger.error(f"Worker {self.name} execution error: {e}")

            if results is not None and len(results) > 0:
                with self.mutex:
                    self.results += results
                    self.results_not_empty.notify_all()

    def start(self, size: int):
        """
        start the thread pool
        """
        with self.mutex:
            remain = size - len(self.pool)
            if remain <= 0:
                return

            for i in range(remain):
                t = threading.Thread(target=self._worker, args=(i,))
                t.daemon = True
                t.start()
                self.pool.append(t)

    def terminate(self):
        """
        terminate the thread pool
        """
        with self.mutex:
            self.stop = True
            self.tasks_not_empty.notify_all()
            self.results_not_empty.notify_all()

            self.not_stop.wait_for(lambda: self.stopped == len(self.pool))

            self.pool = []
            self.tasks = []
            self.results = []
            self.running_tasks = dict()
            self.stop = False
            self.stopped = 0

    def get_results(self, max_size: int, timeout: float) -> List[Tuple[str, Any]]:
        """
        get up to max_size results from the pool, waiting at most timeout
        """
        @self._stop
        def _get_results():
            results = self.results[:max_size]
            del self.results[:max_size]
            return results

        with self.results_not_empty:
            results = self.results_not_empty.wait_for(_get_results, timeout)
            if self.stop:
                return []
            return results

    def put_tasks(self, tasks: List[Tuple[str, Any]], deduplication: bool = False):
        """
        put tasks into the pool, optionally skipping ids that are already
        pending or running
        """
        if len(tasks) == 0:
            return

        with self.mutex:
            if not deduplication:
                self.tasks += tasks
            else:
                task_set = {t[0] for t in self.tasks}
                for _, running in self.running_tasks.items():
                    task_set.update([t[0] for t in running])

                for task in tasks:
                    if task[0] in task_set:
                        continue
                    self.tasks.append(task)
            self.tasks_not_empty.notify_all()
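A hedged, self-contained sketch of how the schedulers above drive this pool (the work function here is illustrative, not from the commit):

# Illustrative only: a work function receives a batch of (id, payload)
# tuples and returns (id, result) tuples, mirroring the Workers contract.
def _echo_worker(tasks):
    return [(task_id, payload.upper()) for task_id, payload in tasks]

workers = Workers("echo", _echo_worker, max_batch_size=8)
workers.start(size=2)
workers.put_tasks([("req-1", "a"), ("req-2", "b")])
print(workers.get_results(max_size=10, timeout=1.0))
# e.g. [('req-1', 'A'), ('req-2', 'B')] (order may vary across threads)
workers.terminate()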