mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] Update ZMQ server (#6735)
* add batch zmq send response
* update
* Revert "update"
This reverts commit 0234a25b47.
* update
* remove lock
* fix unit test
* add unit test
* add unit test
* pre commit
* add unit test
* fix unit test
* add unit test
* fix worker>1
* update zmq_worker_pid
* fix unit test
* fix unit test
* fix unit test
* add unit test
* fix unit test
* fix first token time
* fix logprobs
* add unit test
* op
* remove debug log
---------
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import multiprocessing
|
||||
@@ -1117,13 +1118,25 @@ class EngineService:
|
||||
self.internal_adapter = InternalAdapter(
|
||||
cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id
|
||||
)
|
||||
# ROUTER mode: need to receive client handles
|
||||
self.recv_result_handle_thread = threading.Thread(
|
||||
target=self.send_response_server.recv_result_handle, daemon=True
|
||||
)
|
||||
self.recv_result_handle_thread.start()
|
||||
else:
|
||||
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
|
||||
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
|
||||
self.recv_result_handle_thread = threading.Thread(
|
||||
target=self.send_response_server.recv_result_handle, daemon=True
|
||||
)
|
||||
self.recv_result_handle_thread.start()
|
||||
if envs.ZMQ_SEND_BATCH_DATA:
|
||||
# PUSH mode: batch send, no need to receive client handles
|
||||
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PUSH)
|
||||
# Mapping from request_id to worker_pid for routing batch responses
|
||||
self.request_worker_map = {}
|
||||
else:
|
||||
# ROUTER mode: per-query send, need to receive client handles
|
||||
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
|
||||
self.recv_result_handle_thread = threading.Thread(
|
||||
target=self.send_response_server.recv_result_handle, daemon=True
|
||||
)
|
||||
self.recv_result_handle_thread.start()
|
||||
time.sleep(3)
|
||||
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
|
||||
self.insert_task_to_scheduler_thread.start()
|
||||
@@ -1161,8 +1174,17 @@ class EngineService:
|
||||
self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
|
||||
continue
|
||||
|
||||
# Extract zmq_worker_pid for per-worker PUSH routing.
|
||||
# Only needed when ZMQ_SEND_BATCH_DATA=True AND not using internal adapter,
|
||||
# because FD_ENABLE_INTERNAL_ADAPTER uses ROUTER (worker_pid is irrelevant).
|
||||
worker_pid = None
|
||||
if envs.ZMQ_SEND_BATCH_DATA and not envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
worker_pid = data["zmq_worker_pid"]
|
||||
|
||||
if ControlRequest.is_control_request(data):
|
||||
try: # todo: run control request async, do not block request generation
|
||||
if worker_pid is not None:
|
||||
self.request_worker_map[data.get("request_id")] = worker_pid
|
||||
control_req = ControlRequest.from_dict(data)
|
||||
self.run_control_method(control_req)
|
||||
except Exception as e:
|
||||
@@ -1175,6 +1197,11 @@ class EngineService:
|
||||
request, insert_task = data, []
|
||||
results: List[Tuple[str, Optional[str]]] = list()
|
||||
if data:
|
||||
# Store worker_pid mapping for normal/abort requests
|
||||
if worker_pid is not None:
|
||||
req_id_for_map = data.get("request_id")
|
||||
if req_id_for_map:
|
||||
self.request_worker_map[req_id_for_map] = worker_pid
|
||||
status_value = data.get("status", None)
|
||||
if status_value is not None and status_value == RequestStatus.ABORT.value:
|
||||
req_id = data["request_id"]
|
||||
@@ -1200,7 +1227,9 @@ class EngineService:
|
||||
if self.is_paused:
|
||||
self.llm_logger.warning(f"Engine is paused, drop request: {request}")
|
||||
self._send_error_response(
|
||||
request.request_id, "Request is aborted since LLM Engine is paused."
|
||||
request.request_id,
|
||||
"Request is aborted since LLM Engine is paused.",
|
||||
worker_pid=worker_pid,
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
@@ -1261,6 +1290,11 @@ class EngineService:
|
||||
method = control_req.get_method()
|
||||
request_id = control_req.request_id
|
||||
|
||||
# Look up worker_pid for routing control response
|
||||
worker_pid = None
|
||||
if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
|
||||
worker_pid = self.request_worker_map.pop(request_id, None)
|
||||
|
||||
try:
|
||||
self.llm_logger.info(f"START run control method {request_id}: {method}")
|
||||
|
||||
@@ -1269,19 +1303,22 @@ class EngineService:
|
||||
if handler is None or not callable(handler):
|
||||
error_result = ControlResponse(request_id, 400, f"unknown control method:{method}")
|
||||
self.llm_logger.error(str(error_result))
|
||||
self.send_response_server.send_response(request_id, [error_result])
|
||||
data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
|
||||
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
|
||||
return
|
||||
|
||||
result = handler(control_req)
|
||||
self.llm_logger.info(f"SUCCESS run control method {method}.")
|
||||
succ_result = ControlResponse(request_id, 200, "Success", result)
|
||||
self.send_response_server.send_response(request_id, [succ_result])
|
||||
data = [[succ_result]] if envs.ZMQ_SEND_BATCH_DATA else [succ_result]
|
||||
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed run control method {method}: {str(e)}"
|
||||
self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
|
||||
error_result = ControlResponse(request_id, 500, error_msg)
|
||||
self.send_response_server.send_response(request_id, [error_result])
|
||||
data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
|
||||
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
|
||||
|
||||
def _control_pause(self, control_request: ControlRequest):
|
||||
"""Pauses the LLM engine and aborts all running/inflight requests.
|
||||
@@ -1454,7 +1491,7 @@ class EngineService:
|
||||
# Use a single asyncio.run() to concurrently wait for all worker responses.
|
||||
return asyncio.run(self._wait_all_control_responses(request_id, timeout))
|
||||
|
||||
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
|
||||
def _send_error_response(self, request_id, error_msg, error_code: int = 500, worker_pid=None):
|
||||
self.llm_logger.error(
|
||||
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
|
||||
)
|
||||
@@ -1464,10 +1501,15 @@ class EngineService:
|
||||
error_code=error_code,
|
||||
error_msg=error_msg,
|
||||
)
|
||||
# Look up worker_pid from mapping if not provided
|
||||
if worker_pid is None and envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
|
||||
worker_pid = self.request_worker_map.pop(request_id, None)
|
||||
# Since the request is not in scheduler
|
||||
# Send result by zmq directly
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
self.send_response_server.send_response(None, [[error_result]])
|
||||
elif envs.ZMQ_SEND_BATCH_DATA:
|
||||
self.send_response_server.send_response(None, [[error_result]], worker_pid=worker_pid)
|
||||
else:
|
||||
self.send_response_server.send_response(request_id, [error_result])
|
||||
|
||||
@@ -1529,6 +1571,7 @@ class EngineService:
|
||||
self.send_response_server.send_response(None, new_contents)
|
||||
|
||||
else:
|
||||
worker_batches = collections.defaultdict(list)
|
||||
for request_id, contents in results.items():
|
||||
new_contents = []
|
||||
for content in contents:
|
||||
@@ -1537,7 +1580,9 @@ class EngineService:
|
||||
delta_text = ""
|
||||
if decode_type == 0:
|
||||
delta_text, token_ids = self._decode_token(
|
||||
token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
|
||||
token_ids=content.outputs.token_ids,
|
||||
req_id=request_id,
|
||||
is_end=content.finished,
|
||||
)
|
||||
else:
|
||||
token_ids = content.outputs.token_ids
|
||||
@@ -1553,9 +1598,19 @@ class EngineService:
|
||||
)
|
||||
else:
|
||||
new_contents.append(content)
|
||||
if len(new_contents):
|
||||
self.llm_logger.debug(f"Send response for request id: {request_id}")
|
||||
self.send_response_server.send_response(request_id, new_contents)
|
||||
if new_contents:
|
||||
if envs.ZMQ_SEND_BATCH_DATA:
|
||||
wpid = self.request_worker_map.get(request_id)
|
||||
worker_batches[wpid].append(new_contents)
|
||||
is_finished = any(getattr(c, "finished", False) for c in new_contents)
|
||||
if is_finished:
|
||||
self.request_worker_map.pop(request_id, None)
|
||||
else:
|
||||
self.send_response_server.send_response(request_id, new_contents)
|
||||
if envs.ZMQ_SEND_BATCH_DATA:
|
||||
for wpid, batch_data in worker_batches.items():
|
||||
if batch_data:
|
||||
self.send_response_server.send_response(None, batch_data, worker_pid=wpid)
|
||||
except Exception as e:
|
||||
self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
|
||||
|
||||
@@ -1714,6 +1769,9 @@ class EngineService:
|
||||
self.cache_task_queue.clear_transfer_task()
|
||||
self.send_response_server.req_dict.clear()
|
||||
self.recv_request_server.req_dict.clear()
|
||||
# Clean up worker_pid mapping (batch mode)
|
||||
if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
|
||||
self.request_worker_map.clear()
|
||||
self.llm_logger.info("Clear Data: Successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user