[Optimization] Optimize ttft for prefill pd (#6680)

* optimize ttft

* fix

* fix

* fix ci

* fix ci

* fix

* fix bug

* fix

* add comments

* fix ci

* fix

* fix ci

* fix format

* update according to review

* add comment

* fix

* fix format
This commit is contained in:
chenjian
2026-03-30 20:36:23 +08:00
committed by GitHub
parent 05f2d95729
commit 6727df8286
11 changed files with 134 additions and 135 deletions
+54 -20
View File
@@ -287,6 +287,19 @@ class PaddleDisWorkerProc:
create=False,
)
# init engine forward signal
# If the engine is running a forward pass, engine_forward_signal_data should be 1.
# If the engine is not running a forward pass, engine_forward_signal_data should be 0.
# In pd disaggregation + EP parallel, the scheduler sends the next batch to the worker only when the engine is out of forward.
# When the engine is out of forward, engine_forward_signal_data must be 0; otherwise the scheduler will not schedule the next batch.
engine_forward_signal_data = np.zeros([1], dtype=np.int32)
self.engine_forward_signal = IPCSignal(
name="engine_forward_signal",
array=engine_forward_signal_data,
dtype=np.int32,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)
# gpu_cache_lock: file-based lock for mutual exclusion between worker
# and CPU transfer when accessing GPU KV cache.
self.gpu_cache_lock = IPCLock(
@@ -481,9 +494,6 @@ class PaddleDisWorkerProc:
# TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
self.model_weights_signal = np.zeros([1], dtype=np.int32)
while True:
# run eplb
self._run_eplb(tp_rank)
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.ranks > 1:
@@ -561,7 +571,7 @@ class PaddleDisWorkerProc:
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
self.engine_forward_signal.value[0] = 1
tasks, read_finish = self.task_queue.get_tasks()
# Only one of all tp_size client will get read_finish == True.
if read_finish:
@@ -570,25 +580,39 @@ class PaddleDisWorkerProc:
self.task_queue.read_finish_flag.set(0)
else:
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
# In EP parallel (corresponding to dp attention), we need to barrier for prefill to prevent data imbalance due to inconsistent data arrival.
# Only EP + DP prefill should barrier for data arrival.
# In mixed mode, and for the decoder in D, we should not barrier, so as not to slow down decoding.
if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill":
paddle.distributed.barrier(self.parallel_config.ep_group)
req_dicts, control_reqs = [], []
for req_dict, bsz in tasks:
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
control_reqs.append(req_dict[0])
else:
max_occupied_batch_index = int(bsz)
req_dicts.extend(req_dict)
# todo: run control request async
if len(control_reqs) > 0:
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
for control_req in control_reqs:
if self.parallel_config.use_ep:
self.cached_control_reqs.append(control_req)
logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
assert (
len(tasks) > 0
), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}"
# In EP + DP prefill, an empty task ([]) is delivered to the worker to barrier. For an empty task, just skip and continue.
# tasks[0] contains two parts: ([req1, ...], real_bsz)
# tasks[0][0] is [req1, ...]
# If an empty batch is delivered, eval(tasks[0][0]) should be False ([]);
# if a batch with requests is delivered, eval(tasks[0][0]) should be True, and it is then processed as below.
if tasks[0][0]:
for req_dict, bsz in tasks:
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
control_reqs.append(req_dict[0])
else:
self.run_control_method(control_req)
self._tp_barrier_wait() if tp_size > 1 else None
max_occupied_batch_index = int(bsz)
req_dicts.extend(req_dict)
# todo: run control request async
if len(control_reqs) > 0:
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
for control_req in control_reqs:
if self.parallel_config.use_ep:
self.cached_control_reqs.append(control_req)
logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
else:
self.run_control_method(control_req)
self._tp_barrier_wait() if tp_size > 1 else None
if len(req_dicts) > 0:
# Count prefill requests in current batch
@@ -604,6 +628,12 @@ class PaddleDisWorkerProc:
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
else:
if self.scheduler_config.splitwise_role == "prefill":
if tp_size > 1:
# Synchronize the signal for other workers
self._tp_barrier_wait()
continue
# Let the ep group run control method synchronically
if envs.FD_ENABLE_V1_UPDATE_WEIGHTS and self.parallel_config.use_ep:
@@ -618,6 +648,7 @@ class PaddleDisWorkerProc:
and not self.worker.model_runner.not_need_stop()
):
self._tp_barrier_wait() if tp_size > 1 else None
self.engine_forward_signal.value[0] = 0
time.sleep(0.001)
continue
@@ -642,6 +673,9 @@ class PaddleDisWorkerProc:
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")
# run eplb
self._run_eplb(tp_rank)
self.engine_forward_signal.value[0] = 0
if (
not self.parallel_config.use_ep