mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] Update ZMQ server (#6735)
* add batch zmq send reaponse
* update
* Revert "update"
This reverts commit 0234a25b47.
* update
* remove lock
* fix unit test
* add unit test
* add unit test
* pre commit
* add unit test
* fix unit test
* add unit test
* fix worker>1
* update zmq_worker_pid
* fix unit test
* fix unit test
* fix unit test
* add unit test
* fix unit test
* fix first token time
* fix logprobs
* add unit test
* op
* remore debug log
---------
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -259,16 +259,14 @@ class OpenAIServingCompletion:
|
||||
"""
|
||||
Process the full completion request with multiple choices.
|
||||
"""
|
||||
dealer = None
|
||||
try:
|
||||
request_ids = [f"{request_id}_{i}" for i in range(num_choices)]
|
||||
# create dealer
|
||||
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
|
||||
request_id, num_choices
|
||||
)
|
||||
|
||||
for rid in request_ids:
|
||||
dealer.write([b"", rid.encode("utf-8")])
|
||||
if not envs.ZMQ_SEND_BATCH_DATA:
|
||||
request_ids = [f"{request_id}_{i}" for i in range(num_choices)]
|
||||
for rid in request_ids:
|
||||
dealer.write([b"", rid.encode("utf-8")])
|
||||
|
||||
valid_results = [dict()] * num_choices
|
||||
output_tokens = [0] * num_choices
|
||||
@@ -291,6 +289,9 @@ class OpenAIServingCompletion:
|
||||
try:
|
||||
response = await asyncio.wait_for(response_queue.get(), timeout=10)
|
||||
current_waiting_time = 0
|
||||
except asyncio.CancelledError:
|
||||
# Client disconnected, propagate to outer handler
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
current_waiting_time += 10
|
||||
if current_waiting_time == 300:
|
||||
@@ -378,8 +379,7 @@ class OpenAIServingCompletion:
|
||||
trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", ""))
|
||||
tracing.trace_req_finish(request_id)
|
||||
self.engine_client.semaphore.release()
|
||||
if dealer is not None:
|
||||
await self.engine_client.connection_manager.cleanup_request(request_id)
|
||||
await self.engine_client.connection_manager.cleanup_request(request_id)
|
||||
|
||||
def _echo_back_prompt(self, request, idx):
|
||||
"""
|
||||
@@ -432,10 +432,11 @@ class OpenAIServingCompletion:
|
||||
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
|
||||
request_id, num_choices
|
||||
)
|
||||
if not envs.ZMQ_SEND_BATCH_DATA:
|
||||
request_ids = [f"{request_id}_{i}" for i in range(num_choices)]
|
||||
for rid in request_ids:
|
||||
dealer.write([b"", rid.encode("utf-8")])
|
||||
|
||||
for i in range(num_choices):
|
||||
req_id = f"{request_id}_{i}"
|
||||
dealer.write([b"", req_id.encode("utf-8")]) # 发送多路请求
|
||||
output_tokens = [0] * num_choices
|
||||
num_cache_tokens = [0] * num_choices
|
||||
num_image_tokens = [0] * num_choices
|
||||
@@ -463,6 +464,9 @@ class OpenAIServingCompletion:
|
||||
try:
|
||||
response = await asyncio.wait_for(response_queue.get(), timeout=10)
|
||||
current_waiting_time = 0
|
||||
except asyncio.CancelledError:
|
||||
# Client disconnected, propagate to outer handler
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
current_waiting_time += 10
|
||||
if current_waiting_time == 300:
|
||||
@@ -655,9 +659,8 @@ class OpenAIServingCompletion:
|
||||
trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", ""))
|
||||
tracing.trace_req_finish(request_id)
|
||||
del request
|
||||
if dealer is not None:
|
||||
await self.engine_client.connection_manager.cleanup_request(request_id)
|
||||
self.engine_client.semaphore.release()
|
||||
await self.engine_client.connection_manager.cleanup_request(request_id)
|
||||
self.engine_client.semaphore.release()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def request_output_to_completion_response(
|
||||
@@ -899,23 +902,30 @@ class OpenAIServingCompletion:
|
||||
|
||||
token_ids, logprobs, ranks = prompt_logprobs_tensors
|
||||
|
||||
# Normalize to plain Python lists (support both Tensor and list inputs)
|
||||
if hasattr(token_ids, "tolist"):
|
||||
token_ids = token_ids.tolist()
|
||||
logprobs = logprobs.tolist()
|
||||
ranks = ranks.tolist()
|
||||
|
||||
# Detokenize non-incrementally.
|
||||
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
|
||||
if include_logprobs_decode_token:
|
||||
decoded_tokens = [
|
||||
self.engine_client.data_processor.process_logprob_response(token_id)
|
||||
for token_id in token_ids.flatten().tolist()
|
||||
for row in token_ids
|
||||
for token_id in row
|
||||
]
|
||||
else:
|
||||
decoded_tokens = None
|
||||
|
||||
# Recover shapes.
|
||||
num_prompt_tokens, num_logprobs = logprobs.shape
|
||||
num_prompt_tokens = len(logprobs)
|
||||
num_logprobs = len(logprobs[0]) if num_prompt_tokens > 0 else 0
|
||||
|
||||
# Pythonize the paddle tensors.
|
||||
prompt_token_ranks = ranks.tolist()
|
||||
prompt_logprobs = logprobs.tolist()
|
||||
token_ids = token_ids.tolist()
|
||||
# Build result.
|
||||
prompt_token_ranks = ranks
|
||||
prompt_logprobs = logprobs
|
||||
result: Optional[PromptLogprobs] = [None]
|
||||
# Make Logprob for each position.
|
||||
for pos in range(num_prompt_tokens):
|
||||
|
||||
Reference in New Issue
Block a user