[Optimization]Unified data processing for online and offline (#6891)

* remove process_request

* fix chat

* fix unit test

* remove process response

* fix unit test

* fix offline decode

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix sampling_params

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
luukunn
2026-03-19 21:56:09 +08:00
committed by GitHub
parent c3d8db85c4
commit f4a79d4c00
19 changed files with 160 additions and 619 deletions
+9 -130
View File
@@ -81,90 +81,6 @@ class Ernie4_5Processor(BaseDataProcessor):
if reasoning_parser_obj:
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
def process_request(self, request, max_model_len=None, **kwargs):
"""
Preprocess the request
Args:
request (Dict): may contain text and messages fields
Returns:
bool: Whether preprocessing is successful
str: error message
"""
data_processor_logger.info(f"Start processing request: {request}")
request = self._apply_default_parameters(request)
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
request.eos_token_ids = self.eos_token_ids
# processing stop_sequences and stop_token_ids
process_stop_token_ids(request, self.update_stop_seq)
# processing bad_words
bad_words = request.get("bad_words")
bad_words_token_ids = request.get("bad_words_token_ids")
if bad_words:
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
prompt = request.prompt
tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request.prompt_token_ids = token_ids
data_processor_logger.debug(
f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
)
elif request.messages is not None:
task = request.to_dict()
chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
if chat_template_kwargs:
if isinstance(chat_template_kwargs, dict):
for k, v in chat_template_kwargs.items():
if k not in task or task[k] is None:
task[k] = v
else:
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
else:
raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
if len(request.prompt_token_ids) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
max_tokens = max_model_len - len(request.prompt_token_ids)
if request.get("max_tokens") is None:
request.set("max_tokens", max(1, max_tokens))
else:
request.set("max_tokens", min(max_tokens, request.get("max_tokens")))
if request.get("temperature") < _SAMPLING_EPS:
# zero temperature means greedy decoding: set top_k=1 to force argmax
request.set("temperature", 1)
request.set("top_k", 1)
if request.get("top_p") < _SAMPLING_EPS:
request.set("top_p", _SAMPLING_EPS)
request.set("top_k", 1)
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
parts = request.request_id.split("_")
if len(parts) > 1:
real_req_id = parts[0]
index = int(parts[1])
n = request.get("n", 1)
for idx in range(index * n, (index + 1) * n):
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
else:
self.model_status_dict[request.request_id] = model_status
request.enable_thinking = model_status == "think_start"
data_processor_logger.info(f"Processed request: {request}")
return request
def process_request_dict(self, request, max_model_len=None):
"""
Preprocess the request
@@ -257,46 +173,6 @@ class Ernie4_5Processor(BaseDataProcessor):
data_processor_logger.info(f"Processed request dict: {request}")
return request
def process_response(self, response_dict, **kwargs):
"""
Preprocess the response
Args:
response_dict (Dict): response for engine, contain ids fields
Returns:
Dict: response contain text fields
"""
req_id = response_dict.request_id
token_ids = response_dict.outputs.token_ids
response_dict.usage = {"completion_tokens": response_dict.outputs.index + 1}
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
full_text = self.tokenizer.decode(token_ids)
if self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
full_text,
response_dict,
self.model_status_dict[req_id],
)
response_dict.outputs.text = text
response_dict.outputs.reasoning_content = reasoning_content
else:
response_dict.outputs.text = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict.outputs.tool_calls = tool_call_info.tool_calls
response_dict.outputs.text = tool_call_info.content
if req_id in self.model_status_dict:
del self.model_status_dict[req_id]
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
return None
return response_dict
def process_response_dict(self, response_dict, stream, **kwargs):
"""
Preprocess the response
@@ -326,11 +202,15 @@ class Ernie4_5Processor(BaseDataProcessor):
is_end = response_dict["finished"]
req_id = response_dict["request_id"]
request = kwargs.get("request", None)
direct_decode = kwargs.get("direct_decode", False)
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
response_dict["outputs"]["enable_parser"] = False
if direct_decode:
delta_text = self.tokenizer.decode(token_ids)
previous_texts = ""
else:
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
if is_end:
full_text = previous_texts + delta_text
response_dict["outputs"]["text"] = full_text
@@ -351,8 +231,9 @@ class Ernie4_5Processor(BaseDataProcessor):
response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
response_dict["outputs"]["text"] = tool_call_info.content
response_dict["outputs"]["completion_tokens"] = full_text
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.decode_status:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.model_status_dict:
del self.model_status_dict[req_id]
return response_dict
@@ -371,7 +252,6 @@ class Ernie4_5Processor(BaseDataProcessor):
req_id = response_dict["request_id"]
token_ids = response_dict["outputs"]["token_ids"]
request = kwargs.get("request", None)
response_dict["outputs"]["enable_parser"] = False
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id:
@@ -402,7 +282,6 @@ class Ernie4_5Processor(BaseDataProcessor):
if not is_end:
response_dict["outputs"]["skipped"] = True
if self.tool_parser_obj:
response_dict["outputs"]["enable_parser"] = True
if req_id not in self.tool_parser_dict:
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parser_dict[req_id]