[Feature] [PD Disaggregation] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports (#5415)

* [feat] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports

* [fix] fix some bugs

* [fix] fix rdma port for cache manager/messager

* [fix] temporarily cancel port availability check to see if it can pass ci test

* [feat] simplify args for multi api server

* [fix] fix dp

* [fix] fix port for xpu

* [fix] add tests for ports post processing & fix ci

* [test] fix test_multi_api_server

* [fix] fix rdma_comm_ports args for multi_api_server

* [fix] fix test_common_engine

* [fix] fix test_cache_transfer_manager

* [chore] automatically setting FD_ENABLE_MULTI_API_SERVER

* [fix] avoid api server from creating engine_args twice

* [fix] fix test_run_batch

* [fix] fix test_metrics

* [fix] fix splitwise connector init

* [test] add test_rdma_transfer and test_expert_service

* [fix] fix code syntax

* [fix] fix test_rdma_transfer and build wheel with rdma script
This commit is contained in:
Yonghua Li
2025-12-17 15:50:42 +08:00
committed by GitHub
parent cdc0004894
commit 0c8c6369ed
34 changed files with 1323 additions and 409 deletions
+63 -86
View File
@@ -81,13 +81,6 @@ class EngineService:
"""
self.cfg = cfg
self.use_async_llm = use_async_llm
if cfg.scheduler_config.splitwise_role != "mixed" or cfg.cache_config.enable_prefix_caching:
if isinstance(self.cfg.cache_config.cache_queue_port, str):
self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port.split(",")
if isinstance(self.cfg.cache_config.cache_queue_port, list):
self.cfg.cache_config.cache_queue_port = int(
self.cfg.cache_config.cache_queue_port[self.cfg.parallel_config.local_data_parallel_id]
)
if self.cfg.parallel_config.data_parallel_size > 1:
self.llm_logger = get_logger(
@@ -120,9 +113,8 @@ class EngineService:
self.start_worker_queue_service(start_queue)
os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.parallel_config.engine_worker_queue_port[
self.cfg.parallel_config.local_data_parallel_id
]
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.parallel_config.local_engine_worker_queue_port)
self.llm_logger.info(f"INFERENCE_MSG_QUEUE_ID: {str(self.cfg.parallel_config.local_engine_worker_queue_port)}")
self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager)
self.token_processor = TokenProcessor(
@@ -151,9 +143,7 @@ class EngineService:
self._init_worker_monitor_signals()
if self.cfg.eplb_config.enable_eplb:
current_suffix = int(
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
)
current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
init_eplb_signals(cfg, current_suffix)
if self.use_async_llm:
@@ -211,7 +201,7 @@ class EngineService:
def check_worker_initialize_status_func(res: dict):
res["worker_is_alive"] = True
if not self.check_worker_initialize_status():
llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
self.llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
res["worker_is_alive"] = False
self.check_worker_initialize_status_func_thread = threading.Thread(
@@ -241,7 +231,7 @@ class EngineService:
# Worker launched
self.check_worker_initialize_status_func_thread.join()
if not result_container["worker_is_alive"]:
llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
self.llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
return False
# Start ZMQ service for communication with AsyncLLM
@@ -259,9 +249,7 @@ class EngineService:
self.data_processor = self.input_processor.create_processor()
def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进程感知是否有新Task需要处理
current_suffix = int(
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
)
current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
self.llm_logger.info(f"current_suffix: {current_suffix}")
exist_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
@@ -353,52 +341,43 @@ class EngineService:
"""
start queue service for engine worker communication
"""
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = (
self.cfg.master_ip,
int(
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
),
)
address = (self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port)
else:
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]}.sock"
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.local_engine_worker_queue_port}.sock"
if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
self.llm_logger.info(f"Starting engine worker queue server service at {address}")
self.engine_worker_queue_server = EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.parallel_config.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
# Dynamically updates the port value if an anonymous port is used
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = (
str(self.engine_worker_queue_server.get_server_port())
)
address = (
self.cfg.master_ip,
int(
self.cfg.parallel_config.engine_worker_queue_port[
self.cfg.parallel_config.local_data_parallel_id
]
),
if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
if start_queue:
self.llm_logger.info(f"Starting engine worker queue server service at {address}")
self.engine_worker_queue_server = EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.parallel_config.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
# Dynamically updates the port value if an anonymous port is used
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
self.cfg.parallel_config.local_engine_worker_queue_port = (
self.engine_worker_queue_server.get_server_port()
)
address = (
self.cfg.master_ip,
self.cfg.parallel_config.local_engine_worker_queue_port,
)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
self.llm_logger.info(
f"Starting engine cache queue server service at {self.cfg.cache_config.local_cache_queue_port}"
)
self.cache_task_queue = EngineCacheQueue(
address=(
self.cfg.master_ip,
self.cfg.cache_config.cache_queue_port,
),
address=(self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port),
authkey=b"cache_queue_service",
is_server=True,
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=-1,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
self.cfg.cache_config.cache_queue_port = self.cache_task_queue.get_server_port()
self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port()
self.engine_worker_queue = EngineWorkerQueue(
address=address,
@@ -753,7 +732,7 @@ class EngineService:
# so the same request sent by the decode api server will be ignored
continue
llm_logger.debug(f"get tasks from scheduler: {tasks}")
self.llm_logger.debug(f"get tasks from scheduler: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
for task in tasks:
task.metrics.ask_decode_resource_start_time = time.time()
@@ -965,7 +944,7 @@ class EngineService:
get_request_pool.submit(_fetch_request)
except RuntimeError as e:
if "shutdown" in str(e):
llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
break
else:
raise
@@ -1023,7 +1002,7 @@ class EngineService:
if error_tasks:
for request_id, failed in error_tasks:
if failed is None:
llm_logger.warning(f"Request {request_id} has no error, skip sending error response.")
self.llm_logger.warning(f"Request {request_id} has no error, skip sending error response.")
continue
self._send_error_response(request_id, failed)
@@ -1137,7 +1116,7 @@ class EngineService:
)
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
llm_logger.error(
self.llm_logger.error(
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
)
error_result = RequestOutput(
@@ -1200,7 +1179,7 @@ class EngineService:
elif content.finished:
new_step_contents.append(content)
else:
llm_logger.warning(
self.llm_logger.warning(
f"current tokens need to accumulate, req_id: {content.request_id} {content.outputs.token_ids}"
)
else:
@@ -1230,16 +1209,16 @@ class EngineService:
elif content.finished:
new_contents.append(content)
else:
llm_logger.warning(
self.llm_logger.warning(
f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
)
else:
new_contents.append(content)
if len(new_contents):
llm_logger.debug(f"Send response for request id: {request_id}")
self.llm_logger.debug(f"Send response for request id: {request_id}")
self.send_response_server.send_response(request_id, new_contents)
except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
def _decode_process_splitwise_requests(self):
"""
@@ -1378,10 +1357,8 @@ class EngineService:
tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
device_ids=device_ids,
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=int(
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
),
pid_suffix=ipc_signal_suffix,
engine_worker_queue_port=self.cfg.parallel_config.local_engine_worker_queue_port,
ipc_suffix=ipc_signal_suffix,
create_cache_tensor=False,
)
@@ -1390,15 +1367,15 @@ class EngineService:
def clear_data(self):
try:
llm_logger.info("Clear Data: Start")
self.llm_logger.info("Clear Data: Start")
self.token_processor.clear_data()
self.engine_worker_queue.clear_data()
self.send_response_server.req_dict.clear()
self.recv_request_server.req_dict.clear()
llm_logger.info("Clear Data: Successfully")
self.llm_logger.info("Clear Data: Successfully")
return True
except Exception as e:
llm_logger.error(f"Clear data error: {e}")
self.llm_logger.error(f"Clear data error: {e}")
return False
def _register_to_router(self):
@@ -1424,18 +1401,18 @@ class EngineService:
)
if resp.ok:
llm_logger.info("Successfully registered to the router!")
self.llm_logger.info("Successfully registered to the router!")
break
else:
llm_logger.error(
self.llm_logger.error(
f"Router registration failed: {resp.status_code}, "
f"{resp.text}, {self.cfg.register_info}"
)
time.sleep(sleep_seconds)
except requests.exceptions.RequestException as e:
llm_logger.error(f"Register to router request error: {e}")
self.llm_logger.error(f"Register to router request error: {e}")
except Exception as e:
llm_logger.exception(f"Unexpected error during router registration: {e}")
self.llm_logger.exception(f"Unexpected error during router registration: {e}")
if self.cfg.router_config.router is not None:
register_thread = threading.Thread(target=_register, daemon=True)
@@ -1445,45 +1422,45 @@ class EngineService:
"""
exit sub services
"""
llm_logger.info("Exit sub services.....")
self.llm_logger.info("Exit sub services.....")
self.running = False
if self.use_async_llm:
# Clean up worker processes first (before closing multiprocessing services)
if hasattr(self, "worker_proc") and self.worker_proc is not None:
llm_logger.info("Cleaning up worker processes...")
self.llm_logger.info("Cleaning up worker processes...")
try:
pgid = os.getpgid(self.worker_proc.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception as e:
llm_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}")
self.llm_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}")
# Clean up cache manager processes
if hasattr(self, "cache_manager_processes"):
llm_logger.info("Cleaning up cache manager processes...")
self.llm_logger.info("Cleaning up cache manager processes...")
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
self.llm_logger.info(f"Killing cache manager process {p.pid}")
try:
pgid = os.getpgid(p.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception as e:
llm_logger.error(
self.llm_logger.error(
f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
)
if hasattr(self, "cache_task_queue") and self.cache_task_queue is not None:
llm_logger.info("Cleaning up cache_task_queue...")
self.llm_logger.info("Cleaning up cache_task_queue...")
# Check if cleanup method exists
if hasattr(self.cache_task_queue, "cleanup"):
self.cache_task_queue.cleanup()
elif hasattr(self.cache_task_queue, "manager"):
try:
llm_logger.info("Shutting down cache_task_queue manager...")
self.llm_logger.info("Shutting down cache_task_queue manager...")
self.cache_task_queue.manager.shutdown()
except Exception as e:
llm_logger.warning(f"Error shutting down cache_task_queue manager: {e}")
self.llm_logger.warning(f"Error shutting down cache_task_queue manager: {e}")
if hasattr(self, "get_profile_block_num_signal"):
self.get_profile_block_num_signal.clear()
@@ -1494,7 +1471,7 @@ class EngineService:
# Clean up other services
if hasattr(self, "dp_processed"):
for p in self.dp_processed:
llm_logger.info(f"Waiting for worker {p.pid} to exit")
self.llm_logger.info(f"Waiting for worker {p.pid} to exit")
p.join()
for p in self.dp_engine_worker_queue_server:
p.cleanup()
@@ -1662,13 +1639,13 @@ class EngineService:
think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
if think_end_id > 0:
llm_logger.info(f"Get think_end_id {think_end_id} from vocab.")
self.llm_logger.info(f"Get think_end_id {think_end_id} from vocab.")
else:
llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
self.llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1)
line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1)
ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
ports = ",".join(map(str, self.cfg.parallel_config.engine_worker_queue_port))
ips = None
if self.cfg.ips is not None:
ips = ",".join(self.cfg.ips)
@@ -1744,7 +1721,7 @@ class EngineService:
if self.cfg.nnode > 1:
pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}"
pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
llm_logger.info(f"Launch worker service command: {pd_cmd}")
self.llm_logger.info(f"Launch worker service command: {pd_cmd}")
p = subprocess.Popen(
pd_cmd,
stdout=subprocess.PIPE,
@@ -1820,7 +1797,7 @@ class EngineService:
else:
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock"
llm_logger.info(f"dp start queue service {address}")
self.llm_logger.info(f"dp start queue service {address}")
self.dp_engine_worker_queue_server.append(
EngineWorkerQueue(
address=address,
@@ -1842,7 +1819,7 @@ class EngineService:
),
)
)
llm_logger.info(
self.llm_logger.info(
f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}"
+ f" data parallel id {i}"
)