mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support block scheduler v1 for FD (#2928)
* Support FD block scheduler v1 * Support FD block scheduler v1 * Support FD block scheduler v1 * Fix according to copilot review * Fix according to review * Remove is_dummy * Fix bug when real_bsz=1 * Fix infer first token cost time --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,7 @@ import time
|
||||
import traceback
|
||||
import uuid
|
||||
import weakref
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -40,6 +41,7 @@ from fastdeploy.engine.args_utils import EngineArgs
|
||||
from fastdeploy.engine.expert_service import start_expert_service
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import (
|
||||
EngineCacheQueue,
|
||||
@@ -52,7 +54,7 @@ from fastdeploy.metrics.trace_util import start_span, start_span_request
|
||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
from fastdeploy.utils import EngineError, console_logger, llm_logger
|
||||
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
@@ -108,7 +110,18 @@ class LLMEngine:
|
||||
|
||||
self.start_queue_service()
|
||||
|
||||
self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role)
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.resource_manager = ResourceManagerV1(
|
||||
cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
|
||||
)
|
||||
if cfg.splitwise_role != "mixed":
|
||||
raise NotImplementedError(
|
||||
"Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now."
|
||||
)
|
||||
else:
|
||||
self.resource_manager = ResourceManager(
|
||||
cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
|
||||
)
|
||||
|
||||
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
|
||||
|
||||
@@ -203,7 +216,10 @@ class LLMEngine:
|
||||
|
||||
self.token_processor.tasks_queue = self.engine_worker_queue
|
||||
|
||||
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
|
||||
else:
|
||||
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
|
||||
self.insert_task_to_worker_thread.start()
|
||||
|
||||
if self.api_server_pid is not None:
|
||||
@@ -343,6 +359,56 @@ class LLMEngine:
|
||||
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
|
||||
llm_logger.error(err_msg)
|
||||
|
||||
def _scheduler_task_to_worker_v1(self):
|
||||
"""
|
||||
Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
|
||||
"""
|
||||
get_request_pool = ThreadPoolExecutor(max_workers=1)
|
||||
is_fetching = False
|
||||
|
||||
def _fetch_request():
|
||||
nonlocal is_fetching
|
||||
is_fetching = True
|
||||
num_prefill_batch = min(
|
||||
int(self.resource_manager.available_batch()),
|
||||
self.cfg.max_prefill_batch,
|
||||
)
|
||||
tasks = self.scheduler.get_requests(
|
||||
available_blocks=self.resource_manager.available_block_num(),
|
||||
block_size=self.cfg.cache_config.block_size,
|
||||
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
|
||||
max_num_batched_tokens=self.cfg.max_model_len,
|
||||
batch=num_prefill_batch,
|
||||
)
|
||||
# Fetch requests and add them to the scheduling queue
|
||||
for task in tasks:
|
||||
self.resource_manager.add_request(task)
|
||||
is_fetching = False
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
if self.engine_worker_queue.num_tasks() > 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if (
|
||||
len(self.resource_manager.waiting) == 0
|
||||
and (not is_fetching)
|
||||
and self.exist_prefill_task_signal.value[0] == 0
|
||||
):
|
||||
get_request_pool.submit(_fetch_request)
|
||||
# 2. Schedule requests
|
||||
tasks = self.resource_manager.schedule()
|
||||
# 3. Send to engine
|
||||
if tasks:
|
||||
self.resource_manager.get_real_bsz()
|
||||
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
|
||||
else:
|
||||
time.sleep(0.005)
|
||||
|
||||
except Exception as e:
|
||||
err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
|
||||
llm_logger.error(err_msg)
|
||||
|
||||
def _insert_zmq_task_to_scheduler(self):
|
||||
if self.api_server_pid is None:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user