[Optimization] Update ZMQ server (#6735)

* add batch zmq send response

* update

* Revert "update"

This reverts commit 0234a25b47.

* update

* remove lock

* fix unit test

* add unit test

* add unit test

* pre commit

* add unit test

* fix unit test

* add unit test

* fix worker>1

* update zmq_worker_pid

* fix unit test

* fix unit test

* fix unit test

* add unit test

* fix unit test

* fix first token time

* fix logprobs

* add unit test

* op

* remove debug log

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
luukunn
2026-03-19 21:53:16 +08:00
committed by GitHub
parent 9148562ed0
commit c3d8db85c4
18 changed files with 2739 additions and 133 deletions
@@ -259,16 +259,14 @@ class OpenAIServingCompletion:
"""
Process the full completion request with multiple choices.
"""
dealer = None
try:
request_ids = [f"{request_id}_{i}" for i in range(num_choices)]
# create dealer
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
request_id, num_choices
)
for rid in request_ids:
dealer.write([b"", rid.encode("utf-8")])
if not envs.ZMQ_SEND_BATCH_DATA:
request_ids = [f"{request_id}_{i}" for i in range(num_choices)]
for rid in request_ids:
dealer.write([b"", rid.encode("utf-8")])
valid_results = [dict()] * num_choices
output_tokens = [0] * num_choices
@@ -291,6 +289,9 @@ class OpenAIServingCompletion:
try:
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.CancelledError:
# Client disconnected, propagate to outer handler
raise
except asyncio.TimeoutError:
current_waiting_time += 10
if current_waiting_time == 300:
@@ -378,8 +379,7 @@ class OpenAIServingCompletion:
trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", ""))
tracing.trace_req_finish(request_id)
self.engine_client.semaphore.release()
if dealer is not None:
await self.engine_client.connection_manager.cleanup_request(request_id)
await self.engine_client.connection_manager.cleanup_request(request_id)
def _echo_back_prompt(self, request, idx):
"""
@@ -432,10 +432,11 @@ class OpenAIServingCompletion:
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
request_id, num_choices
)
if not envs.ZMQ_SEND_BATCH_DATA:
request_ids = [f"{request_id}_{i}" for i in range(num_choices)]
for rid in request_ids:
dealer.write([b"", rid.encode("utf-8")])
for i in range(num_choices):
req_id = f"{request_id}_{i}"
dealer.write([b"", req_id.encode("utf-8")])  # send requests for all choices
output_tokens = [0] * num_choices
num_cache_tokens = [0] * num_choices
num_image_tokens = [0] * num_choices
@@ -463,6 +464,9 @@ class OpenAIServingCompletion:
try:
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.CancelledError:
# Client disconnected, propagate to outer handler
raise
except asyncio.TimeoutError:
current_waiting_time += 10
if current_waiting_time == 300:
@@ -655,9 +659,8 @@ class OpenAIServingCompletion:
trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", ""))
tracing.trace_req_finish(request_id)
del request
if dealer is not None:
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
@@ -899,23 +902,30 @@ class OpenAIServingCompletion:
token_ids, logprobs, ranks = prompt_logprobs_tensors
# Normalize to plain Python lists (support both Tensor and list inputs)
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
logprobs = logprobs.tolist()
ranks = ranks.tolist()
# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
if include_logprobs_decode_token:
decoded_tokens = [
self.engine_client.data_processor.process_logprob_response(token_id)
for token_id in token_ids.flatten().tolist()
for row in token_ids
for token_id in row
]
else:
decoded_tokens = None
# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
num_prompt_tokens = len(logprobs)
num_logprobs = len(logprobs[0]) if num_prompt_tokens > 0 else 0
# Pythonize the paddle tensors.
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
# Build result.
prompt_token_ranks = ranks
prompt_logprobs = logprobs
result: Optional[PromptLogprobs] = [None]
# Make Logprob for each position.
for pos in range(num_prompt_tokens):