mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] Optimize ttft for prefill pd (#6680)
* optimize ttft * fix * fix * fix ci * fix ci * fix * fix bug * fix * add comments * fix ci * fix * fix ci * fix format * update according to review * add comment * fix * fix format
This commit is contained in:
@@ -287,6 +287,19 @@ class PaddleDisWorkerProc:
|
||||
create=False,
|
||||
)
|
||||
|
||||
# init engine forward signal
|
||||
# If engine is being forward, engine_forward_signal_data should be 1.
|
||||
# If engine is out of forward, engine_forward_signal_data should be 0.
|
||||
# In pd disaggregation + EP parallel, only when engine is out of forward, scheduler send next batch to worker.
|
||||
# When engine is out of forward, engine_forward_signal_data must be 0, otherwise scheduler will not schedule next batch.
|
||||
engine_forward_signal_data = np.zeros([1], dtype=np.int32)
|
||||
self.engine_forward_signal = IPCSignal(
|
||||
name="engine_forward_signal",
|
||||
array=engine_forward_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.parallel_config.local_engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
# gpu_cache_lock: file-based lock for mutual exclusion between worker
|
||||
# and CPU transfer when accessing GPU KV cache.
|
||||
self.gpu_cache_lock = IPCLock(
|
||||
@@ -481,9 +494,6 @@ class PaddleDisWorkerProc:
|
||||
# TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
|
||||
self.model_weights_signal = np.zeros([1], dtype=np.int32)
|
||||
while True:
|
||||
# run eplb
|
||||
self._run_eplb(tp_rank)
|
||||
|
||||
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
|
||||
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
|
||||
if self.ranks > 1:
|
||||
@@ -561,7 +571,7 @@ class PaddleDisWorkerProc:
|
||||
|
||||
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
|
||||
logger.info(f"Rank: {self.local_rank} Detected new requests.")
|
||||
|
||||
self.engine_forward_signal.value[0] = 1
|
||||
tasks, read_finish = self.task_queue.get_tasks()
|
||||
# Only one of all tp_size client will get read_finish == True.
|
||||
if read_finish:
|
||||
@@ -570,25 +580,39 @@ class PaddleDisWorkerProc:
|
||||
self.task_queue.read_finish_flag.set(0)
|
||||
else:
|
||||
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
|
||||
# In EP parallel(corresponing to dp attention), we need to barrier for prefill to prevent data imbalance due to inconsistent data arrival.
|
||||
# Only EP + DP prefill should barrier for data arrival.
|
||||
# In mixed mode and decoder in D, we should not barrier to influence decoding.
|
||||
if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill":
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
|
||||
req_dicts, control_reqs = [], []
|
||||
for req_dict, bsz in tasks:
|
||||
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
|
||||
control_reqs.append(req_dict[0])
|
||||
else:
|
||||
max_occupied_batch_index = int(bsz)
|
||||
req_dicts.extend(req_dict)
|
||||
|
||||
# todo: run control request async
|
||||
if len(control_reqs) > 0:
|
||||
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
|
||||
for control_req in control_reqs:
|
||||
if self.parallel_config.use_ep:
|
||||
self.cached_control_reqs.append(control_req)
|
||||
logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
|
||||
assert (
|
||||
len(tasks) > 0
|
||||
), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}"
|
||||
# In EP + DP prefill, empty task ([]) is delived in worker to barrier. For empty task, just skip and continue.
|
||||
# tasks[0] contains two part, ([req1, ...] ,real_bsz)
|
||||
# tasks[0][0] is [req1, ...]
|
||||
# if empty batch is delived, eval(tasks[0][0]) should be False ([]),
|
||||
# if batch with requests is delived, eval(tasks[0][0]) should be True, then to be processed as below.
|
||||
if tasks[0][0]:
|
||||
for req_dict, bsz in tasks:
|
||||
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
|
||||
control_reqs.append(req_dict[0])
|
||||
else:
|
||||
self.run_control_method(control_req)
|
||||
self._tp_barrier_wait() if tp_size > 1 else None
|
||||
max_occupied_batch_index = int(bsz)
|
||||
req_dicts.extend(req_dict)
|
||||
|
||||
# todo: run control request async
|
||||
if len(control_reqs) > 0:
|
||||
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
|
||||
for control_req in control_reqs:
|
||||
if self.parallel_config.use_ep:
|
||||
self.cached_control_reqs.append(control_req)
|
||||
logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
|
||||
else:
|
||||
self.run_control_method(control_req)
|
||||
self._tp_barrier_wait() if tp_size > 1 else None
|
||||
|
||||
if len(req_dicts) > 0:
|
||||
# Count prefill requests in current batch
|
||||
@@ -604,6 +628,12 @@ class PaddleDisWorkerProc:
|
||||
|
||||
# Process prefill inputs
|
||||
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
|
||||
else:
|
||||
if self.scheduler_config.splitwise_role == "prefill":
|
||||
if tp_size > 1:
|
||||
# Synchronize the signal for other workers
|
||||
self._tp_barrier_wait()
|
||||
continue
|
||||
|
||||
# Let the ep group run control method synchronically
|
||||
if envs.FD_ENABLE_V1_UPDATE_WEIGHTS and self.parallel_config.use_ep:
|
||||
@@ -618,6 +648,7 @@ class PaddleDisWorkerProc:
|
||||
and not self.worker.model_runner.not_need_stop()
|
||||
):
|
||||
self._tp_barrier_wait() if tp_size > 1 else None
|
||||
self.engine_forward_signal.value[0] = 0
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
@@ -642,6 +673,9 @@ class PaddleDisWorkerProc:
|
||||
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
|
||||
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")
|
||||
# run eplb
|
||||
self._run_eplb(tp_rank)
|
||||
self.engine_forward_signal.value[0] = 0
|
||||
|
||||
if (
|
||||
not self.parallel_config.use_ep
|
||||
|
||||
Reference in New Issue
Block a user