【feature】support n parameter (#4273)

* support n parameter

* pre-commit check

* pre-commit check

* restore format_and_add_data

* update n_param

* bug fix index - str to int

* bug fix del child_task

* bug fix metrics

* add debug info

* add debug info2

* remove debug info

* change connecting symbol to '-'

* bugfix change connecting symbol

* bugfix change connecting symbol2

* unit tests fix

* unit test fix2

* unittest add param n=2

* n param add unit tests and adapt to echo

* pre-commit fix

* resolve review

* adjust stop reason

* add unittest for _create_chat_completion_choice

* modify unittest

* solve conflict

* solve conflict

* resolve conflict

---------

Co-authored-by: LiqinruiG <37392159+LiqinruiG@users.noreply.github.com>
Co-authored-by: gaoziyuan <m13689897706@163.com>
Author: kxz2002
Date: 2025-10-17 20:51:59 +08:00
Committed by: GitHub
Parent: 8ccfd975b5
Commit: b5b993e48e
10 changed files with 459 additions and 110 deletions
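
What the change does: the OpenAI-compatible `n` sampling parameter asks for several independent completions per prompt, so the server fans each prompt out into `n` child requests and folds the results back into one response. Below is a minimal client-side sketch of a request this commit enables; the endpoint URL and model name are placeholders, not taken from this commit:

import requests

# Hypothetical usage: two completions per prompt -> 4 choices in total.
payload = {
    "model": "default",  # placeholder model name
    "prompt": ["Hello,", "Goodbye,"],
    "n": 2,
    "max_tokens": 16,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload, timeout=60)
for choice in resp.json()["choices"]:
    print(choice["index"], repr(choice["text"]))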
@@ -129,7 +129,7 @@ class OpenAIServingCompletion:
         if request_prompt_ids is not None:
             request_prompts = request_prompt_ids
-        num_choices = len(request_prompts)
+        num_choices = len(request_prompts) * (1 if request.n is None else request.n)
         api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}")
         prompt_batched_token_ids = []
         prompt_tokens_list = []
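
The arithmetic in this first hunk drives the rest of the diff: with `n` choices per prompt, `num_choices` grows to `len(request_prompts) * n`, and every later hunk recovers a choice's originating prompt by integer division. A standalone sketch of that mapping, with illustrative values only:

# Choices 0..n-1 belong to prompt 0, choices n..2n-1 to prompt 1, and so on.
prompts = ["p0", "p1"]
n = 3  # stands in for request.n; None means one choice per prompt
num_choices = len(prompts) * (1 if n is None else n)  # 6

for idx in range(num_choices):
    prompt_idx = idx // (1 if n is None else n)
    print(f"choice {idx} -> prompt {prompt_idx} ({prompts[prompt_idx]})")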
@@ -151,7 +151,7 @@
         try:
             try:
                 for idx, prompt in enumerate(request_prompts):
-                    request_id_idx = f"{request_id}-{idx}"
+                    request_id_idx = f"{request_id}_{idx}"
                     current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
                     current_req_dict["arrival_time"] = time.time()
                     prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict)  # tokenize
@@ -163,7 +163,9 @@
             except ParameterError as e:
                 api_server_logger.error(f"OpenAIServingCompletion format error: {e}, {e.message}")
                 self.engine_client.semaphore.release()
-                return ErrorResponse(code=400, message=str(e.message), type="invalid_request", param=e.param)
+                return ErrorResponse(
+                    error=ErrorInfo(code="400", message=str(e.message), type="invalid_request", param=e.param)
+                )
             except Exception as e:
                 error_msg = f"OpenAIServingCompletion format error: {e}, {str(traceback.format_exc())}"
                 api_server_logger.error(error_msg)
@@ -220,7 +222,7 @@
         """
         dealer = None
         try:
-            request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
+            request_ids = [f"{request_id}_{i}" for i in range(num_choices)]
             # create dealer
             dealer, response_queue = await self.engine_client.connection_manager.get_connection(
                 request_id, num_choices
@@ -259,7 +261,7 @@
                     continue
                 for data in response:
-                    rid = int(data["request_id"].split("-")[-1])
+                    rid = int(data["request_id"].split("_")[-1])
                     if data.get("error_code", 200) != 200:
                         raise ValueError("{}".format(data["error_msg"]))
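
The connector between the parent request id and the choice index switches from "-" to "_" here and in the surrounding hunks. The diff doesn't state the motivation, but a plausible one is that request ids routinely contain hyphens already (UUID-style ids), so "_" keeps both directions of the round trip unambiguous, as in this sketch:

# Hypothetical parent id that already contains hyphens.
request_id = "cmpl-6d2f4d5e-9a1b"
child_id = f"{request_id}_4"

idx = int(child_id.split("_")[-1])   # 4, the choice index
parent = child_id.rsplit("_", 1)[0]  # "cmpl-6d2f4d5e-9a1b", recovered intact
assert parent == request_id and idx == 4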
@@ -323,7 +325,7 @@
         Process the echo logic and return the modified text.
         """
         if request.echo and res_outputs.get("send_idx", -1) == 0:
-            prompt_text = self._echo_back_prompt(request, idx)
+            prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n))
             res_outputs["text"] = prompt_text + (res_outputs["text"] or "")
         return res_outputs
@@ -355,7 +357,7 @@
         )
         for i in range(num_choices):
-            req_id = f"{request_id}-{i}"
+            req_id = f"{request_id}_{i}"
             dealer.write([b"", req_id.encode("utf-8")])  # send the multiplexed requests
         output_tokens = [0] * num_choices
         inference_start_time = [0] * num_choices
@@ -393,7 +395,7 @@
                     continue
                 for res in response:
-                    idx = int(res["request_id"].split("-")[-1])
+                    idx = int(res["request_id"].split("_")[-1])
                     if res.get("error_code", 200) != 200:
                         raise ValueError("{}".format(res["error_msg"]))
@@ -407,8 +409,12 @@
                             CompletionResponseStreamChoice(
                                 index=idx,
                                 text="",
-                                prompt_token_ids=list(prompt_batched_token_ids[idx]),
-                                prompt_tokens=prompt_tokens_list[idx],
+                                prompt_token_ids=list(
+                                    prompt_batched_token_ids[idx // (1 if request.n is None else request.n)]
+                                ),
+                                prompt_tokens=prompt_tokens_list[
+                                    idx // (1 if request.n is None else request.n)
+                                ],
                                 completion_token_ids=None,
                             )
                         ],
@@ -493,9 +499,14 @@
                         model=model_name,
                         choices=[],
                         usage=UsageInfo(
-                            prompt_tokens=len(prompt_batched_token_ids[idx]),
+                            prompt_tokens=len(
+                                prompt_batched_token_ids[idx // (1 if request.n is None else request.n)]
+                            ),
                             completion_tokens=output_tokens[idx],
-                            total_tokens=len(prompt_batched_token_ids[idx]) + output_tokens[idx],
+                            total_tokens=len(
+                                prompt_batched_token_ids[idx // (1 if request.n is None else request.n)]
+                            )
+                            + output_tokens[idx],
                         ),
                     )
                     yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
@@ -503,7 +514,7 @@
         except Exception as e:
             api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}")
-            yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n"
+            yield f"data: {ErrorResponse(error=ErrorInfo(message=str(e), code='400', type=ErrorType.INTERNAL_ERROR)).model_dump_json(exclude_unset=True)}\n\n"
         finally:
             del request
             if dealer is not None:
@@ -528,7 +539,7 @@
         for idx in range(len(final_res_batch)):
             final_res = final_res_batch[idx]
-            prompt_token_ids = prompt_batched_token_ids[idx]
+            prompt_token_ids = prompt_batched_token_ids[idx // (1 if request.n is None else request.n)]
             assert prompt_token_ids is not None
             completion_token_ids = completion_batched_token_ids[idx]
@@ -540,7 +551,7 @@
             aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
             if request.echo:
-                prompt_text = self._echo_back_prompt(request, idx)
+                prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n))
                 token_ids = [*prompt_token_ids, *output["token_ids"]]
                 output_text = prompt_text + output["text"]
             else:
@@ -555,7 +566,11 @@
                 prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
                 completion_token_ids=completion_token_ids if request.return_token_ids else None,
                 completion_tokens=output.get("completion_tokens") if request.return_token_ids else None,
-                prompt_tokens=prompt_tokens_list[idx] if request.return_token_ids else None,
+                prompt_tokens=(
+                    prompt_tokens_list[idx // (1 if request.n is None else request.n)]
+                    if request.return_token_ids
+                    else None
+                ),
                 reasoning_content=output.get("reasoning_content"),
                 tool_calls=output.get("tool_call"),
                 logprobs=aggregated_logprobs,
@@ -567,6 +582,7 @@
             num_prompt_tokens += len(prompt_token_ids)
+        num_prompt_tokens = num_prompt_tokens // (1 if request.n is None else request.n)
         usage = UsageInfo(
             prompt_tokens=num_prompt_tokens,
             completion_tokens=num_generated_tokens,
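
This final hunk corrects aggregate usage for the non-streaming path: the loop adds each prompt's length once per choice, so with `n` choices per prompt the sum overcounts by a factor of `n`, and the added line divides it back down. In miniature:

prompt_lens = [3, 5]  # token counts of two prompts
n = 2
num_prompt_tokens = sum(length for length in prompt_lens for _ in range(n))  # 16
num_prompt_tokens = num_prompt_tokens // (1 if n is None else n)             # 8 == 3 + 5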