mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix cache transfer manager updating/clearing (#5930)
* [fix] fix cache transfer manager updating/clearing
* [fix] fix code style
* [fix] fix config
* [fix] fix engine client
* [fix] let worker update kv cache status signal
* [fix] update worker process
* [fix] fix clear/update for the case where the comm group is shut down
* [fix] update dynamic weight manager
* [fix] fix port
* [fix] add num_cpu_blocks arg for async_llm, and remove unnecessary waiting
This commit is contained in:
@@ -178,7 +178,7 @@ class CacheTransferManager:
|
||||
name="cache_ready_signal",
|
||||
array=cache_ready_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_suffix,
|
||||
suffix=args.engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
swap_space_ready_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
@@ -186,7 +186,7 @@ class CacheTransferManager:
|
||||
name="swap_space_ready_signal",
|
||||
array=swap_space_ready_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_suffix,
|
||||
suffix=args.engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
|
||||
@@ -201,7 +201,7 @@ class CacheTransferManager:
|
||||
name="cache_task_broadcast_signal",
|
||||
array=cache_task_broadcast_data,
|
||||
dtype=np.int32,
|
||||
suffix=args.ipc_suffix,
|
||||
suffix=args.engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
|
||||
@@ -230,7 +230,15 @@ class CacheTransferManager:
|
||||
raise ValueError(f"Invalid write policy: {args.write_policy}")
|
||||
self.write_policy = args.write_policy
|
||||
|
||||
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
|
||||
# Initialize update/clear signals for RL
|
||||
self.kv_cache_status_signal = IPCSignal(
|
||||
name="kv_cache_status",
|
||||
array=np.zeros([1], dtype=np.int32),
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()
|
||||
|
||||
def _init_storage_buffer(self, args):
|
||||
"""
|
||||
@@ -951,29 +959,19 @@ class CacheTransferManager:
|
||||
task_cpu_block_id,
|
||||
)
|
||||
|
||||
def clear_or_update_caches(self, args):
|
||||
def check_cache_status(self, args):
|
||||
# TODO XPU support RL
|
||||
if unset_data_ipc is None:
|
||||
return
|
||||
logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
|
||||
logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
|
||||
kv_cache_status = np.zeros([1], dtype=np.int32)
|
||||
kv_cache_status_signal = IPCSignal(
|
||||
name="kv_cache_status",
|
||||
array=kv_cache_status,
|
||||
dtype=np.int32,
|
||||
suffix=self.ipc_suffix,
|
||||
create=False,
|
||||
)
|
||||
while True:
|
||||
if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
|
||||
# handle cache clearing/restoring
|
||||
if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
|
||||
assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache."
|
||||
try:
|
||||
logger.info(
|
||||
f"[rank {self.rank}/{self.n_ranks}] Start clearing caches {self.cache_ready_signal.value}"
|
||||
)
|
||||
logger.info(f"Start clearing caches {self.cache_ready_signal.value}")
|
||||
# clear cpu caches
|
||||
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
|
||||
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
|
||||
paddle.set_device("cpu")
|
||||
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
|
||||
cuda_host_free(ptrs)
|
||||
@@ -996,49 +994,43 @@ class CacheTransferManager:
|
||||
|
||||
# reset cache_ready_signal
|
||||
self.cache_ready_signal.value[self.rank] = 0
|
||||
logger.info(
|
||||
f"[rank {self.rank}/{self.n_ranks}] Finish clearing caches {self.cache_ready_signal.value}"
|
||||
)
|
||||
logger.info(f"Finish clearing caches {self.cache_ready_signal.value}")
|
||||
|
||||
# wait for all ranks caches to be cleared
|
||||
if np.sum(self.cache_ready_signal.value) != 0:
|
||||
time.sleep(0.1)
|
||||
|
||||
# reset kv_cache_status_signal
|
||||
kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
|
||||
logger.info("All ranks finish clearing caches")
|
||||
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
|
||||
logger.info(f"All ranks finish clearing caches {self.cache_ready_signal.value}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to clear caches: {e}")
|
||||
logger.error(f"Failed to clear caches: {e}")
|
||||
|
||||
elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
|
||||
elif self.kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
|
||||
assert args.splitwise_role == "mixed", "Only mixed mode supports updating cache."
|
||||
try:
|
||||
logger.info(
|
||||
f"[rank {self.rank}/{self.n_ranks}] Start restoring caches {self.cache_ready_signal.value}"
|
||||
)
|
||||
logger.info(f"Start restoring caches {self.cache_ready_signal.value}")
|
||||
# restore cpu cache
|
||||
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
|
||||
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
|
||||
self._init_cpu_cache(args)
|
||||
while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
|
||||
time.sleep(0.1)
|
||||
|
||||
# restore gpu cache and set cache_ready_signal
|
||||
self._init_gpu_cache(args)
|
||||
logger.info(
|
||||
f"[rank {self.rank}/{self.n_ranks}] Finish restoring caches {self.cache_ready_signal.value}"
|
||||
)
|
||||
logger.info(f"Finish restoring caches {self.cache_ready_signal.value}")
|
||||
|
||||
# wait for all ranks caches to be ready
|
||||
while np.sum(self.cache_ready_signal.value) != args.mp_num:
|
||||
time.sleep(0.1)
|
||||
|
||||
# set kv_cache_status_signal
|
||||
logger.info("All ranks finish restoring caches")
|
||||
kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
|
||||
logger.info(f"All ranks finish restoring caches {self.cache_ready_signal.value}")
|
||||
self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to restore caches: {e}")
|
||||
logger.error(f"Failed to restore caches: {e}")
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
@@ -1335,6 +1335,7 @@ class CacheConfig:
|
||||
self.disable_chunked_mm_input = False
|
||||
self.kvcache_storage_backend = None
|
||||
self.write_policy = None
|
||||
self.num_cpu_blocks = None
|
||||
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
@@ -1380,10 +1381,12 @@ class CacheConfig:
|
||||
* byte_size
|
||||
)
|
||||
|
||||
if self.swap_space is None:
|
||||
self.num_cpu_blocks = 0
|
||||
else:
|
||||
self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block)
|
||||
if self.num_cpu_blocks is None:
|
||||
if self.swap_space is None:
|
||||
self.num_cpu_blocks = 0
|
||||
else:
|
||||
self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block)
|
||||
|
||||
self._verify_args()
|
||||
|
||||
def metrics_info(self):
|
||||
|
||||
@@ -1707,6 +1707,7 @@ class EngineService:
|
||||
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
|
||||
f" --max_logprobs {self.cfg.model_config.max_logprobs}"
|
||||
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
|
||||
f" --num_cpu_blocks {self.cfg.cache_config.num_cpu_blocks}"
|
||||
)
|
||||
if self.cfg.structured_outputs_config.logits_processors is not None:
|
||||
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
|
||||
|
||||
@@ -574,6 +574,7 @@ class LLMEngine:
|
||||
f" --max_logprobs {self.cfg.model_config.max_logprobs}"
|
||||
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
|
||||
f" --routing_replay_config '{self.cfg.routing_replay_config.to_json_string()}'"
|
||||
f" --num_cpu_blocks {self.cfg.cache_config.num_cpu_blocks}"
|
||||
)
|
||||
if self.cfg.structured_outputs_config.logits_processors is not None:
|
||||
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
|
||||
|
||||
@@ -34,6 +34,7 @@ from fastdeploy.eplb.utils import RedundantExpertWorkload
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import (
|
||||
IPCSignal,
|
||||
KVCacheStatus,
|
||||
ModelWeightsStatus,
|
||||
PrefixTreeStatus,
|
||||
RearrangeExpertStatus,
|
||||
@@ -79,6 +80,9 @@ class EngineClient:
|
||||
)
|
||||
self.max_model_len = self.fd_config.model_config.max_model_len
|
||||
self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching
|
||||
self.enable_cache_transfer = (
|
||||
self.fd_config.cache_config.swap_space or self.fd_config.cache_config.kvcache_storage_backend
|
||||
)
|
||||
self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed"
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
|
||||
@@ -528,17 +532,22 @@ class EngineClient:
|
||||
2 : worker update finish and notify client
|
||||
"""
|
||||
with self.clear_update_lock:
|
||||
if self.fd_config.cache_config.swap_space:
|
||||
return False, "hierarchical cache updating is not supported"
|
||||
|
||||
if self.enable_prefix_caching:
|
||||
# prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL
|
||||
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED:
|
||||
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
|
||||
api_server_logger.info(f"Start to update prefix tree {self.prefix_tree_status_signal.value[0]}")
|
||||
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
api_server_logger.info(f"..updating prefix tree {self.prefix_tree_status_signal.value[0]}")
|
||||
api_server_logger.info(
|
||||
f">>> start updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0 and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
api_server_logger.info(f"... prefix tree status: {self.prefix_tree_status_signal.value[0]}")
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return False, "Update prefix tree timeout"
|
||||
api_server_logger.info(
|
||||
f"<<< finish updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
|
||||
# model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
|
||||
@@ -549,13 +558,30 @@ class EngineClient:
|
||||
return False, "worker is clearing model weight, cannot update now"
|
||||
|
||||
self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
|
||||
api_server_logger.info(f"Start to update model weight {self.model_weights_status_signal.value[0]}")
|
||||
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.NORMAL:
|
||||
api_server_logger.info(f"..updating model weights {self.model_weights_status_signal.value[0]}")
|
||||
api_server_logger.info(
|
||||
f">>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
|
||||
if not self.enable_cache_transfer
|
||||
else f">>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0:
|
||||
api_server_logger.info(
|
||||
f"... weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f"... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
|
||||
)
|
||||
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
|
||||
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
|
||||
if weight_updated and (not self.enable_cache_transfer or cache_updated):
|
||||
break
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return False, "Update model weight timeout"
|
||||
api_server_logger.info(
|
||||
f"<<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f"<<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
return True, ""
|
||||
|
||||
def clear_load_weight(self, timeout=300):
|
||||
@@ -566,17 +592,22 @@ class EngineClient:
|
||||
"""
|
||||
|
||||
with self.clear_update_lock:
|
||||
if self.fd_config.cache_config.swap_space:
|
||||
return False, "hierarchical cache clearing is not supported"
|
||||
|
||||
if self.enable_prefix_caching:
|
||||
# prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED
|
||||
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL:
|
||||
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
|
||||
api_server_logger.info(f"Start to clear prefix tree {self.prefix_tree_status_signal.value[0]}")
|
||||
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
|
||||
api_server_logger.info(f"..clearing prefix tree {self.prefix_tree_status_signal.value[0]}")
|
||||
api_server_logger.info(
|
||||
f">>> start clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0 and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
|
||||
api_server_logger.info(f"... prefix tree status: {self.prefix_tree_status_signal.value[0]}")
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return False, "Clear prefix tree timeout"
|
||||
api_server_logger.info(
|
||||
f"<<< finish clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
|
||||
# model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
|
||||
@@ -587,13 +618,30 @@ class EngineClient:
|
||||
return False, "worker is updating model weight, cannot clear now"
|
||||
|
||||
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
|
||||
api_server_logger.info(f"Start to clear model weight {self.model_weights_status_signal.value[0]}")
|
||||
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.CLEARED:
|
||||
api_server_logger.info(f"..clearing model weights {self.model_weights_status_signal.value[0]}")
|
||||
api_server_logger.info(
|
||||
f">>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f">>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0:
|
||||
api_server_logger.info(
|
||||
f"... weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f"... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
|
||||
)
|
||||
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
|
||||
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
|
||||
if weight_cleared and (not self.enable_cache_transfer or cache_cleared):
|
||||
break
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return False, "Clear model weight timeout"
|
||||
api_server_logger.info(
|
||||
f"<<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f"<<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
return True, ""
|
||||
|
||||
def check_model_weight_status(self):
|
||||
|
||||
@@ -24,7 +24,7 @@ import paddle
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.inter_communicator import ModelWeightsStatus
|
||||
from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus
|
||||
|
||||
|
||||
class DynamicWeightManager:
|
||||
@@ -258,18 +258,25 @@ class DynamicWeightManager:
|
||||
value[self.rank] = status
|
||||
|
||||
@staticmethod
|
||||
def check_model_weights_status(model_weights_status, model_runner, pid, block):
|
||||
def check_model_weights_status(model_weights_status, kv_cache_status, model_runner, pid, block):
|
||||
"""
|
||||
check model weights status
|
||||
A function to handle the state of model weights, check the model weights state,
|
||||
and perform corresponding operations as needed.
|
||||
|
||||
- model_weights_status (`IPCSignal`): The signal indicating the status of model weights.
|
||||
- kv_cache_status (`IPCSignal`): The signal indicating the status of key-value cache.
|
||||
- model_runner (`ModelRunnerBase`): The model runner instance.
|
||||
- block (`bool`): Block mode keeps the worker process blocked in the status-check loop,
|
||||
avoiding communication operations in the worker event loop.
|
||||
"""
|
||||
# logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
|
||||
logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL and (
|
||||
block or model_weights_status.value[0] != ModelWeightsStatus.CLEARED
|
||||
):
|
||||
# 如果为 block 模式,那么循环不会退出,直到权重更新、通信组重建
|
||||
# 如果为非 block 模式,那么循环在权重更新或清理后均会退出
|
||||
if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
|
||||
logger.info("infer engine stopped! start to load new checkpoint...")
|
||||
if kv_cache_status:
|
||||
kv_cache_status.value[0] = KVCacheStatus.UPDATING
|
||||
model_runner.clear_requests()
|
||||
model_runner.update_parameters(pid)
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
|
||||
@@ -277,9 +284,12 @@ class DynamicWeightManager:
|
||||
logger.info("finished loading new checkpoint")
|
||||
elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
|
||||
logger.info("infer engine stopped! start to clear checkpoint...")
|
||||
if kv_cache_status:
|
||||
kv_cache_status.value[0] = KVCacheStatus.CLEARING
|
||||
model_runner.clear_requests()
|
||||
model_runner.clear_parameters(pid)
|
||||
while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
|
||||
time.sleep(0.01)
|
||||
logger.info("finished clearing checkpoint")
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
@@ -230,6 +230,16 @@ class PaddleDisWorkerProc:
|
||||
create=False,
|
||||
)
|
||||
|
||||
# init kv_cache_status
|
||||
kv_cache_status_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
self.kv_cache_status = IPCSignal(
|
||||
name="kv_cache_status",
|
||||
array=kv_cache_status_data,
|
||||
dtype=np.int32,
|
||||
suffix=self.parallel_config.local_engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
|
||||
# init exist_task_signal
|
||||
workers_exist_task = np.zeros([1], dtype=np.int32)
|
||||
self.exist_task_signal = IPCSignal(
|
||||
@@ -421,8 +431,7 @@ class PaddleDisWorkerProc:
|
||||
self._run_eplb(tp_rank)
|
||||
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
|
||||
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
|
||||
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
|
||||
if self.ranks > 1:
|
||||
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
|
||||
|
||||
@@ -455,8 +464,10 @@ class PaddleDisWorkerProc:
|
||||
)
|
||||
|
||||
self.model_weights_status.value[0] = self.model_weights_signal[0]
|
||||
self.kv_cache_status.value[0] = self.model_weights_signal[0]
|
||||
DynamicWeightManager.check_model_weights_status(
|
||||
self.model_weights_status,
|
||||
self.kv_cache_status if self.fd_config.cache_config.num_cpu_blocks > 0 else None,
|
||||
# model_weights_signal
|
||||
self.worker.model_runner,
|
||||
self.parallel_config.local_engine_worker_queue_port,
|
||||
@@ -464,14 +475,31 @@ class PaddleDisWorkerProc:
|
||||
)
|
||||
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
|
||||
self.task_queue.clear_data()
|
||||
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
|
||||
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
|
||||
|
||||
# 只有不关闭通信组时,清理权重后需要额外等待(否则信号量会同步混乱)
|
||||
if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
|
||||
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
if self.model_weights_signal[0] == ModelWeightsStatus.UPDATING:
|
||||
logger.info(
|
||||
f"Rank: {self.local_rank} has updated parameters. {self.model_weights_status.value[0]}"
|
||||
)
|
||||
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
|
||||
elif self.model_weights_signal[0] == ModelWeightsStatus.CLEARING:
|
||||
logger.info(
|
||||
f"Rank: {self.local_rank} has cleared parameters. {self.model_weights_status.value[0]}"
|
||||
)
|
||||
# 如果清理权重后不关闭通信组,那么将推理进程统一阻塞在下面的循环中,否则信号量可能同步混乱;直到下次权重更新时唤醒
|
||||
if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
|
||||
if self.ranks > 1: # 所有 Rank 同时入睡,监听下次的更新信号
|
||||
paddle.distributed.barrier()
|
||||
while self.model_weights_signal[0] != ModelWeightsStatus.UPDATING:
|
||||
self.model_weights_signal[0] = self.model_weights_status.value[0]
|
||||
if self.ranks > 1:
|
||||
self.model_weights_signal[0] = self._broadcast_model_weights_signal(
|
||||
src=0, group=None
|
||||
)
|
||||
time.sleep(1)
|
||||
self.model_weights_status.value[0] = (
|
||||
ModelWeightsStatus.UPDATING
|
||||
) # 所有 Rank 已同步唤醒,启动权重更新流程
|
||||
continue
|
||||
|
||||
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
|
||||
logger.info(f"Rank: {self.local_rank} Detected new requests.")
|
||||
@@ -906,6 +934,13 @@ def parse_args():
|
||||
help="Enable output of token-level entropy.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_cpu_blocks",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of cpu blocks.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--model "$MODEL_PATH" \
|
||||
--no-shutdown-comm-group-if-worker-idle \
|
||||
--swap-space 10 \
|
||||
--load-strategy ipc_snapshot \
|
||||
--dynamic-load-weight &
|
||||
|
||||
|
||||
Reference in New Issue
Block a user