[BugFix] fix cache transfer manager updating/clearing (#5930)

* [fix] fix cache transfer manager updating/clearing

* [fix] fix code style

* [fix] fix config

* [fix] fix engine client

* [fix] let worker update kv cache status signal

* [fix] update worker process

* [fix] fix clear/update for case if comm group is shutdown

* [fix] update dynamic weight manager

* [fix] fix port

* [fix] add num_cpu_blocks arg for async_llm, and remove unnecessary waiting
This commit is contained in:
Yonghua Li
2026-01-13 21:09:29 +08:00
committed by GitHub
parent 6da06abc17
commit 456637002d
8 changed files with 165 additions and 74 deletions
+44 -9
View File
@@ -230,6 +230,16 @@ class PaddleDisWorkerProc:
create=False,
)
# init kv_cache_status
kv_cache_status_data = np.zeros(shape=[1], dtype=np.int32)
self.kv_cache_status = IPCSignal(
name="kv_cache_status",
array=kv_cache_status_data,
dtype=np.int32,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)
# init exist_task_signal
workers_exist_task = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
@@ -421,8 +431,7 @@ class PaddleDisWorkerProc:
self._run_eplb(tp_rank)
if self.fd_config.load_config.dynamic_load_weight:
if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.ranks > 1:
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
@@ -455,8 +464,10 @@ class PaddleDisWorkerProc:
)
self.model_weights_status.value[0] = self.model_weights_signal[0]
self.kv_cache_status.value[0] = self.model_weights_signal[0]
DynamicWeightManager.check_model_weights_status(
self.model_weights_status,
self.kv_cache_status if self.fd_config.cache_config.num_cpu_blocks > 0 else None,
# model_weights_signal
self.worker.model_runner,
self.parallel_config.local_engine_worker_queue_port,
@@ -464,14 +475,31 @@ class PaddleDisWorkerProc:
)
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
self.task_queue.clear_data()
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
# Only when the communication group is NOT shut down: an extra wait is needed after clearing weights (otherwise the status signals can get out of sync)
if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
time.sleep(0.01)
continue
if self.model_weights_signal[0] == ModelWeightsStatus.UPDATING:
logger.info(
f"Rank: {self.local_rank} has updated parameters. {self.model_weights_status.value[0]}"
)
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
elif self.model_weights_signal[0] == ModelWeightsStatus.CLEARING:
logger.info(
f"Rank: {self.local_rank} has cleared parameters. {self.model_weights_status.value[0]}"
)
# If the communication group is not shut down after clearing weights, block all inference processes in the loop below (otherwise the status signals may get out of sync); they are woken at the next weight update
if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
if self.ranks > 1:  # all ranks go to sleep together and listen for the next update signal
paddle.distributed.barrier()
while self.model_weights_signal[0] != ModelWeightsStatus.UPDATING:
self.model_weights_signal[0] = self.model_weights_status.value[0]
if self.ranks > 1:
self.model_weights_signal[0] = self._broadcast_model_weights_signal(
src=0, group=None
)
time.sleep(1)
self.model_weights_status.value[0] = (
ModelWeightsStatus.UPDATING
) # all ranks have woken up in sync; start the weight-update process
continue
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
@@ -906,6 +934,13 @@ def parse_args():
help="Enable output of token-level entropy.",
)
parser.add_argument(
"--num_cpu_blocks",
type=int,
default=0,
help="Number of cpu blocks.",
)
args = parser.parse_args()
return args