Revert "[Optimize] Optimize ttft for ep (#6098)" (#6402)

This reverts commit 90db0bdd0d.
This commit is contained in:
chenjian
2026-02-09 19:01:23 +08:00
committed by GitHub
parent d60daca4a8
commit 35c24f3f71
10 changed files with 142 additions and 118 deletions
+3
View File
@@ -147,6 +147,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to enable the decode caches requests for preallocating resource
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
# Batched token timeout in EP
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
# Max pre-fetch requests number in PD
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
+3
View File
@@ -147,6 +147,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# 是否启用 decode 缓存请求以预分配资源
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
# EP 中批处理 token 的超时时间
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
# PD 中最大预取请求数量
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
+18 -31
View File
@@ -313,15 +313,6 @@ class EngineService:
create=True,
)
engine_forward_signal_data = np.zeros([1], dtype=np.int32)
self.engine_forward_signal = IPCSignal(
name="engine_forward_signal",
array=engine_forward_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
# worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间
worker_healthy_live_recorded_time_array = np.zeros(
shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
@@ -984,23 +975,26 @@ class EngineService:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
try:
if not is_fetching:
# Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool.
try:
is_fetching = True
get_request_pool.submit(_fetch_request)
except RuntimeError as e:
if "shutdown" in str(e):
self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
break
else:
raise
# Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished.
# Once the forward pass finishes, these accumulated requests can be scheduled in larger,
# more efficient batches.
if not (self.engine_worker_queue.num_tasks() == 0 and self.engine_forward_signal.value[0] == 0):
if self.engine_worker_queue.exist_tasks():
time.sleep(0.001)
continue
if self.cfg.scheduler_config.splitwise_role != "mixed":
if not is_fetching:
is_fetching = True
get_request_pool.submit(_fetch_request)
else:
if len(self.resource_manager.waiting) == 0 and (not is_fetching):
# Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool.
try:
is_fetching = True
get_request_pool.submit(_fetch_request)
except RuntimeError as e:
if "shutdown" in str(e):
self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
break
else:
raise
# 2. Schedule requests
tasks, error_tasks = self.resource_manager.schedule()
@@ -1050,13 +1044,6 @@ class EngineService:
else:
task.metrics.inference_start_time = time.time()
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
else:
# When there are no actual tasks to schedule, send an empty task batch to EP workers.
# This keeps the EP workers' task-sync barrier from hanging when no real tasks are scheduled.
if self.cfg.parallel_config.enable_expert_parallel:
self.engine_worker_queue.put_tasks(
([], self.resource_manager.real_bsz)
) # Empty (as idle tasks for ep)
# 4. Response error tasks
if error_tasks:
+2
View File
@@ -121,6 +121,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
# Whether to enable the decode caches requests for preallocating resource
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
# Batched token timeout in EP
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
# Max pre-fetch requests number in PD
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
# Enable or disable model caching.
+44 -11
View File
@@ -23,7 +23,7 @@ from typing import Dict, List, Optional
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
from fastdeploy.utils import get_logger
from fastdeploy.utils import envs, get_logger
class DPLocalScheduler(LocalScheduler):
@@ -131,19 +131,52 @@ class DPLocalScheduler(LocalScheduler):
Returns:
List of Request objects ready for processing
"""
# DP scheduler is used in V1, there is no need to manage request fetching in the scheduler, resource_manager_v1 will do that.
if available_blocks <= reserved_output_blocks or batch < 1:
self.scheduler_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}"
)
return []
required_total_blocks = 0
current_prefill_tokens = 0
start_batch_time = time.time()
requests: List[Request] = []
with self.requests_not_empty:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
requests.append(request.raw)
self.ids_read_cursor += 1
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
if batch_ids:
if len(batch_ids) > 0 and len(requests) == 0:
self.scheduler_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
)
if len(requests) > 0:
self.scheduler_logger.info(
@@ -53,9 +53,6 @@ class InternalAdapter:
available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
available_block_num = self.engine.resource_manager.available_block_num()
unhandled_request_num = self.engine.scheduler.get_unhandled_request_num()
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
unhandled_request_num = max(unhandled_request_num, len(self.engine.resource_manager.waiting))
server_info = {
"splitwise_role": self.cfg.scheduler_config.splitwise_role,
"block_size": int(self.cfg.cache_config.block_size),
@@ -65,7 +62,7 @@ class InternalAdapter:
"available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
"max_batch_size": int(available_batch_size),
"max_input_token_num": self.cfg.model_config.max_model_len,
"unhandled_request_num": unhandled_request_num,
"unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
"available_batch": int(self.engine.resource_manager.available_batch()),
}
return server_info
+29 -52
View File
@@ -283,16 +283,6 @@ class PaddleDisWorkerProc:
create=False,
)
# init engine forward signal
engine_forward_signal_data = np.zeros([1], dtype=np.int32)
self.engine_forward_signal = IPCSignal(
name="engine_forward_signal",
array=engine_forward_signal_data,
dtype=np.int32,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)
def update_weights_from_tensor(self, mmap_infos):
"""
update_weights_from_tensor
@@ -450,6 +440,9 @@ class PaddleDisWorkerProc:
# TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
self.model_weights_signal = np.zeros([1], dtype=np.int32)
while True:
# run eplb
self._run_eplb(tp_rank)
if self.fd_config.load_config.dynamic_load_weight:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.ranks > 1:
@@ -523,7 +516,7 @@ class PaddleDisWorkerProc:
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
self.engine_forward_signal.value[0] = 1
tasks, read_finish = self.task_queue.get_tasks()
# Only one of all tp_size client will get read_finish == True.
if read_finish:
@@ -532,48 +525,35 @@ class PaddleDisWorkerProc:
self.task_queue.read_finish_flag.set(0)
else:
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
# In EP parallel (corresponding to dp attention), we need to barrier for prefill to prevent data imbalance due to inconsistent data arrival.
# Only EP + DP prefill should barrier for data arrival.
# In mixed mode, and for the decoder in D, we should not barrier, so as not to slow down decoding.
if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill":
paddle.distributed.barrier(self.parallel_config.ep_group)
req_dicts, control_reqs = [], []
# In EP + DP prefill, an empty task ([]) is delivered to the worker as a barrier. For an empty task, just skip and continue.
if tasks[0][0]:
for req_dict, bsz in tasks:
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
control_reqs.append(req_dict[0])
else:
max_occupied_batch_index = int(bsz)
req_dicts.extend(req_dict)
for req_dict, bsz in tasks:
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
control_reqs.append(req_dict[0])
else:
max_occupied_batch_index = int(bsz)
req_dicts.extend(req_dict)
# todo: run control request async
if len(control_reqs) > 0:
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
for control_req in control_reqs:
self.run_control_method(control_req)
self._tp_barrier_wait() if tp_size > 1 else None
# todo: run control request async
if len(control_reqs) > 0:
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
for control_req in control_reqs:
self.run_control_method(control_req)
self._tp_barrier_wait() if tp_size > 1 else None
# Count prefill requests in current batch
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
num_scheduled_requests = len(req_dicts)
scheduled_request_ids = [req.request_id for req in req_dicts]
logger.info(
f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, "
f"max_occupied_batch_index: {max_occupied_batch_index}, "
f"num_scheduled_requests: {num_scheduled_requests}, "
f"scheduled_request_ids: {scheduled_request_ids}"
)
# Count prefill requests in current batch
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
num_scheduled_requests = len(req_dicts)
scheduled_request_ids = [req.request_id for req in req_dicts]
logger.info(
f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, "
f"max_occupied_batch_index: {max_occupied_batch_index}, "
f"num_scheduled_requests: {num_scheduled_requests}, "
f"scheduled_request_ids: {scheduled_request_ids}"
)
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
else:
if self.scheduler_config.splitwise_role == "prefill":
if tp_size > 1:
# Synchronize the signal for other workers
self._tp_barrier_wait()
continue
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
if (
(not self.parallel_config.use_ep)
@@ -581,7 +561,7 @@ class PaddleDisWorkerProc:
and (not self.enable_overlap_schedule)
):
self._tp_barrier_wait() if tp_size > 1 else None
self.engine_forward_signal.value[0] = 0
time.sleep(0.001)
continue
@@ -593,9 +573,6 @@ class PaddleDisWorkerProc:
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")
# run eplb
self._run_eplb(tp_rank)
self.engine_forward_signal.value[0] = 0
def initialize_kv_cache(self) -> None:
"""Profiles the peak memory usage of the model to determine how many
+16 -17
View File
@@ -214,29 +214,28 @@ def test_metrics_with_clear_and_reset():
"""
Test the metrics monitoring endpoint.
"""
pass # not stable, uncomment after bug fix
# metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
# async_concurrency(n=10)
async_concurrency(n=10)
# time.sleep(0.3)
time.sleep(0.3)
# ===== clear_load_weight =====
# clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight"
# print("Calling clear_load_weight...")
# r = requests.get(clear_url, timeout=30)
# assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}"
clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight"
print("Calling clear_load_weight...")
r = requests.get(clear_url, timeout=30)
assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}"
# metrics = get_metrics_dict(metrics_url)
# running = metrics["fastdeploy:num_requests_running"]
# waiting = metrics["fastdeploy:num_requests_waiting"]
metrics = get_metrics_dict(metrics_url)
running = metrics["fastdeploy:num_requests_running"]
waiting = metrics["fastdeploy:num_requests_waiting"]
# print(
# "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
# running,
# "waiting:",
# waiting,
# )
print(
"ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
running,
"waiting:",
waiting,
)
# assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight"
+26
View File
@@ -411,6 +411,32 @@ class TestDPLocalScheduler(unittest.TestCase):
self.assertEqual(scheduler.ids, ["fresh_req"])
self.assertEqual(scheduler.ids_read_cursor, 1)
def test_get_requests_insufficient_resources(self):
"""Test getting requests when resources are insufficient."""
mock_logger.reset_mock()
# Test with insufficient blocks - mock the condition variable to avoid threading issues
with patch.object(self.scheduler, "requests_not_empty"):
requests = self.scheduler.get_requests(
available_blocks=5, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
)
self.assertEqual(requests, [])
# The logger should have been called for insufficient resources
self.assertTrue(mock_logger.debug.called)
# Check the message contains expected content
call_args = mock_logger.debug.call_args[0][0]
self.assertIn("insufficient", call_args.lower())
def test_get_requests_insufficient_batch(self):
"""Test getting requests when batch size is insufficient."""
with patch.object(self.scheduler, "requests_not_empty"):
requests = self.scheduler.get_requests(
available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=0
)
self.assertEqual(requests, [])
@patch("time.time")
@patch.object(dp_scheduler_module, "envs")
def test_get_requests_no_requests_available(self, mock_envs, mock_time):
@@ -25,9 +25,6 @@ class DummyEngine:
"""Dummy Engine class to simulate the actual Engine for testing."""
class ResourceManager:
def __init__(self):
self.waiting = []
def available_batch(self):
return 4