Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Optimization] Unified data processing for online and offline (#6891)
* remove process_request
* fix chat
* fix unit test
* remove process response
* fix unit test
* fix offline decode
* Potential fix for pull request finding (Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>)
* fix sampling_params

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
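The core of the change: the per-model process_request overrides, which converted a Request object to a dict, ran the dict pipeline, and converted back, are deleted so that online (dict) and offline (Request) inputs share one preprocessing path. A minimal sketch of the wrapper pattern this commit removes, taken from the deleted Ernie4_5_VLProcessor code below (an illustration of the old duplication, not the surviving implementation):

    # Pattern repeated in each multimodal processor before this commit:
    def process_request(self, request, max_model_len=None, **kwargs):
        task = request.to_dict()                        # Request -> plain dict
        task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
        self.process_request_dict(task, max_model_len)  # shared dict pipeline
        request = Request.from_dict(task)               # dict -> Request
        return self._apply_default_parameters(request)

Deleting these copies leaves process_request_dict as the single entry point for both serving modes.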
@@ -81,90 +81,6 @@ class Ernie4_5Processor(BaseDataProcessor):
         if reasoning_parser_obj:
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
 
-    def process_request(self, request, max_model_len=None, **kwargs):
-        """
-        Preprocess the request
-
-        Args:
-            request (Dict): may contain text and messages fields
-
-        Returns:
-            bool: Whether preprocessing is successful
-            str: error message
-        """
-        data_processor_logger.info(f"Start processing request: {request}")
-        request = self._apply_default_parameters(request)
-        if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
-            request.eos_token_ids = self.eos_token_ids
-
-        # processing stop_sequences and stop_token_ids
-        process_stop_token_ids(request, self.update_stop_seq)
-
-        # processing bad_words
-        bad_words = request.get("bad_words")
-        bad_words_token_ids = request.get("bad_words_token_ids")
-        if bad_words:
-            bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
-            request["bad_words_token_ids"] = bad_words_token_ids
-
-        # processing prompt_token_ids
-        if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
-            if request.prompt is not None:
-                prompt = request.prompt
-                tokens = self.tokenizer.tokenize(prompt)
-                token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-                request.prompt_token_ids = token_ids
-                data_processor_logger.debug(
-                    f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
-                )
-            elif request.messages is not None:
-                task = request.to_dict()
-                chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
-                if chat_template_kwargs:
-                    if isinstance(chat_template_kwargs, dict):
-                        for k, v in chat_template_kwargs.items():
-                            if k not in task or task[k] is None:
-                                task[k] = v
-                    else:
-                        raise ValueError("Invalid input: chat_template_kwargs must be a dict")
-                request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
-            else:
-                raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
-
-        if len(request.prompt_token_ids) == 0:
-            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
-
-        # truncate prompts that exceed the length limit
-        if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
-            request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
-        max_tokens = max_model_len - len(request.prompt_token_ids)
-        if request.get("max_tokens") is None:
-            request.set("max_tokens", max(1, max_tokens))
-        else:
-            request.set("max_tokens", min(max_tokens, request.get("max_tokens")))
-        if request.get("temperature") < _SAMPLING_EPS:
-            # zero temperature means greedy decoding: set top_k=1 to force argmax
-            request.set("temperature", 1)
-            request.set("top_k", 1)
-        if request.get("top_p") < _SAMPLING_EPS:
-            request.set("top_p", _SAMPLING_EPS)
-            request.set("top_k", 1)
-        if self.reasoning_parser:
-            model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
-            parts = request.request_id.split("_")
-            if len(parts) > 1:
-                real_req_id = parts[0]
-                index = int(parts[1])
-                n = request.get("n", 1)
-                for idx in range(index * n, (index + 1) * n):
-                    self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
-            else:
-                self.model_status_dict[request.request_id] = model_status
-            request.enable_thinking = model_status == "think_start"
-
-        data_processor_logger.info(f"Processed request: {request}")
-        return request
-
     def process_request_dict(self, request, max_model_len=None):
         """
         Preprocess the request
@@ -257,46 +173,6 @@ class Ernie4_5Processor(BaseDataProcessor):
         data_processor_logger.info(f"Processed request dict: {request}")
         return request
 
-    def process_response(self, response_dict, **kwargs):
-        """
-        Preprocess the response
-
-        Args:
-            response_dict (Dict): response for engine, contain ids fields
-
-        Returns:
-            Dict: response contain text fields
-        """
-        req_id = response_dict.request_id
-        token_ids = response_dict.outputs.token_ids
-
-        response_dict.usage = {"completion_tokens": response_dict.outputs.index + 1}
-        if token_ids[-1] == self.tokenizer.eos_token_id:
-            token_ids = token_ids[:-1]
-        full_text = self.tokenizer.decode(token_ids)
-        if self.reasoning_parser:
-            reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
-                full_text,
-                response_dict,
-                self.model_status_dict[req_id],
-            )
-            response_dict.outputs.text = text
-            response_dict.outputs.reasoning_content = reasoning_content
-        else:
-            response_dict.outputs.text = full_text
-        if self.tool_parser_obj:
-            tool_parser = self.tool_parser_obj(self.tokenizer)
-            tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
-            if tool_call_info.tools_called:
-                response_dict.outputs.tool_calls = tool_call_info.tool_calls
-                response_dict.outputs.text = tool_call_info.content
-        if req_id in self.model_status_dict:
-            del self.model_status_dict[req_id]
-        data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
-        if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
-            return None
-        return response_dict
-
     def process_response_dict(self, response_dict, stream, **kwargs):
         """
         Preprocess the response
@@ -326,11 +202,15 @@ class Ernie4_5Processor(BaseDataProcessor):
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
         request = kwargs.get("request", None)
+        direct_decode = kwargs.get("direct_decode", False)
         if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
-        delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
-        response_dict["outputs"]["enable_parser"] = False
+        if direct_decode:
+            delta_text = self.tokenizer.decode(token_ids)
+            previous_texts = ""
+        else:
+            delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
         if is_end:
             full_text = previous_texts + delta_text
             response_dict["outputs"]["text"] = full_text
@@ -351,8 +231,9 @@ class Ernie4_5Processor(BaseDataProcessor):
                 response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
                 response_dict["outputs"]["text"] = tool_call_info.content
             response_dict["outputs"]["completion_tokens"] = full_text
-            data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
-            del self.decode_status[req_id]
+            if req_id in self.decode_status:
+                data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
+                del self.decode_status[req_id]
             if req_id in self.model_status_dict:
                 del self.model_status_dict[req_id]
         return response_dict
@@ -371,7 +252,6 @@ class Ernie4_5Processor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
         request = kwargs.get("request", None)
-        response_dict["outputs"]["enable_parser"] = False
 
         if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
@@ -402,7 +282,6 @@ class Ernie4_5Processor(BaseDataProcessor):
         if not is_end:
             response_dict["outputs"]["skipped"] = True
         if self.tool_parser_obj:
-            response_dict["outputs"]["enable_parser"] = True
             if req_id not in self.tool_parser_dict:
                 self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
             tool_parser = self.tool_parser_dict[req_id]
@@ -20,7 +20,6 @@ from collections.abc import Mapping
 import numpy as np
 from paddleformers.generation import GenerationConfig
 
-from fastdeploy.engine.request import Request
 from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
 from fastdeploy.input.utils import IDS_TYPE_FLAG, process_stop_token_ids
 from fastdeploy.utils import data_processor_logger
@@ -120,16 +119,6 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
             set_value(request, "presence_penalty", 0.0)
         return request
 
-    def process_request(self, request, max_model_len=None, **kwargs):
-        """process the input data"""
-        task = request.to_dict()
-        task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
-        self.process_request_dict(task, max_model_len)
-        request = Request.from_dict(task)
-        request = self._apply_default_parameters(request)
-
-        return request
-
     def _parse_processor_kwargs(self, kwargs):
         """Parse the multimodal processor argument configuration"""
         if not kwargs:
@@ -16,7 +16,6 @@
 
 import numpy as np
 
-from fastdeploy.engine.request import Request
 from fastdeploy.input.text_processor import DataProcessor as TextProcessor
 from fastdeploy.utils import data_processor_logger
 
@@ -76,25 +75,6 @@ class PaddleOCRVLProcessor(TextProcessor):
         self.image_patch_id = self.processor.image_patch_id
         self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
 
-    def process_request(self, request, max_model_len=None, **kwargs):
-        """
-        Process incoming request and generate model inputs.
-
-        Args:
-            request: Input request object
-            max_model_len (int, optional): Maximum context length
-            **kwargs: Additional processing parameters
-
-        Returns:
-            Request: Processed request with model inputs
-        """
-        task = request.to_dict()
-        task["enable_thinking"] = kwargs.get("enable_thinking", False)
-        self.process_request_dict(task, max_model_len)
-        request = Request.from_dict(task)
-        request = self._apply_default_parameters(request)
-        return request
-
     def _parse_processor_kwargs(self, kwargs):
         """
         Parse and validate multimodal processor arguments.
@@ -16,7 +16,6 @@
 
 import numpy as np
 
-from fastdeploy.engine.request import Request
 from fastdeploy.input.text_processor import DataProcessor as TextProcessor
 from fastdeploy.utils import data_processor_logger
 
@@ -74,25 +73,6 @@ class Qwen3VLProcessor(TextProcessor):
         self.image_patch_id = self.processor.image_token_id
         self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
 
-    def process_request(self, request, max_model_len=None, **kwargs):
-        """
-        Process incoming request and generate model inputs.
-
-        Args:
-            request: Input request object
-            max_model_len (int, optional): Maximum context length
-            **kwargs: Additional processing parameters
-
-        Returns:
-            Request: Processed request with model inputs
-        """
-        task = request.to_dict()
-        task["enable_thinking"] = kwargs.get("enable_thinking", False)
-        self.process_request_dict(task, max_model_len)
-        request = Request.from_dict(task)
-        request = self._apply_default_parameters(request)
-        return request
-
     def _parse_processor_kwargs(self, kwargs):
         """
         Parse and validate multimodal processor arguments.
@@ -16,7 +16,6 @@
 
 import numpy as np
 
-from fastdeploy.engine.request import Request
 from fastdeploy.input.text_processor import DataProcessor as TextProcessor
 from fastdeploy.input.utils import process_stop_token_ids
 from fastdeploy.utils import data_processor_logger
@@ -75,25 +74,6 @@ class QwenVLProcessor(TextProcessor):
         self.image_patch_id = self.processor.image_token_id
         self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
 
-    def process_request(self, request, max_model_len=None, **kwargs):
-        """
-        Process incoming request and generate model inputs.
-
-        Args:
-            request: Input request object
-            max_model_len (int, optional): Maximum context length
-            **kwargs: Additional processing parameters
-
-        Returns:
-            Request: Processed request with model inputs
-        """
-        task = request.to_dict()
-        task["enable_thinking"] = kwargs.get("enable_thinking", False)
-        self.process_request_dict(task, max_model_len)
-        request = Request.from_dict(task)
-        request = self._apply_default_parameters(request)
-        return request
-
     def _parse_processor_kwargs(self, kwargs):
         """
         Parse and validate multimodal processor arguments.
@@ -76,34 +76,6 @@ class BaseDataProcessor(ABC):
             set_value(request, "presence_penalty", 0.0)
         return request
 
-    @abstractmethod
-    def process_request(self, request, **kwargs):
-        """
-        Preprocess the request
-
-        Args:
-            request (Dict): may contain text and messages fields
-            **kwargs: others
-
-        Returns:
-            bool: Whether preprocessing is successful
-            str: error message
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def process_response(self, response_dict):
-        """
-        Preprocess the response
-
-        Args:
-            response_dict (Dict): response for engine, contain ids fields
-
-        Returns:
-            Dict: response contain text fields
-        """
-        raise NotImplementedError
-
     def text2ids(self, text, max_model_len=None):
         """
         text to token ids
@@ -294,113 +266,6 @@ class DataProcessor(BaseDataProcessor):
             logits_processors_args.pop("think_prompt_last_token_id", None)
         return logits_processors_args
 
-    def process_request(self, request, max_model_len=None, **kwargs):
-        """
-        Preprocess the request
-
-        Args:
-            request (Dict): may contain text and messages fields
-
-        Returns:
-            bool: Whether preprocessing is successful
-            str: error message
-        """
-        data_processor_logger.info(f"Start processing request: {request}")
-        request = self._apply_default_parameters(request)
-        if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
-            request.eos_token_ids = self.eos_token_ids
-
-        # processing stop_sequences and stop_token_ids
-        process_stop_token_ids(request, self.update_stop_seq)
-
-        # processing bad_words
-        bad_words = request.get("bad_words")
-        bad_words_token_ids = request.get("bad_words_token_ids")
-        if bad_words:
-            bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
-            request["bad_words_token_ids"] = bad_words_token_ids
-
-        logits_processors_args = request.get("logits_processors_args") or {}
-        think_stop_sentence = logits_processors_args.get("think_stop_sentence")
-        if isinstance(think_stop_sentence, str) and think_stop_sentence:
-            newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
-            sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
-            logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
-        logits_processors_args.pop("think_stop_sentence", None)
-        request["logits_processors_args"] = logits_processors_args
-
-        # processing prompt_token_ids
-        if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
-            if request.prompt is not None:
-                prompt = request.prompt
-                add_special_tokens = request.get("add_special_tokens", False)
-                assert isinstance(prompt, str) or (
-                    isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
-                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
-                if isinstance(prompt, list):  # if prompt is a token id list
-                    request.prompt_token_ids = prompt
-                else:
-                    request.prompt_token_ids = self.text2ids(
-                        request.prompt, max_model_len, add_special_tokens=add_special_tokens
-                    )
-            elif request.messages is not None:
-                if self.tokenizer.chat_template is None:
-                    raise ValueError("This model does not support chat_template.")
-                task = request.to_dict()
-                chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
-                if chat_template_kwargs:
-                    if isinstance(chat_template_kwargs, dict):
-                        for k, v in chat_template_kwargs.items():
-                            if k not in task or task[k] is None:
-                                task[k] = v
-                    else:
-                        raise ValueError("Invalid input: chat_template_kwargs must be a dict")
-                task.setdefault("enable_thinking", True)
-                request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
-            else:
-                raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
-
-        if len(request.prompt_token_ids) == 0:
-            raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
-
-        # truncate prompts that exceed the length limit
-        if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
-            request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
-
-        logits_processors_args = request.get("logits_processors_args") or {}
-        logits_processors_args = self._update_thinking_prompt_state(request.prompt_token_ids, logits_processors_args)
-        request["logits_processors_args"] = logits_processors_args
-
-        max_tokens = max_model_len - len(request.prompt_token_ids)
-        if request.get("max_tokens") is None:
-            request.set("max_tokens", max(1, max_tokens))
-        else:
-            request.set("max_tokens", min(max_tokens, request.get("max_tokens")))
-        if request.get("temperature") < _SAMPLING_EPS:
-            # zero temperature means greedy decoding: set top_k=1 to force argmax
-            request.set("temperature", 1)
-            request.set("top_k", 1)
-        if request.get("top_p") < _SAMPLING_EPS:
-            request.set("top_p", _SAMPLING_EPS)
-            request.set("top_k", 1)
-        if self.reasoning_parser:
-            model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
-            parts = request.request_id.split("_")
-            if len(parts) > 1:
-                real_req_id = parts[0]
-                index = int(parts[1])
-                n = request.get("n", 1)
-                for idx in range(index * n, (index + 1) * n):
-                    self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
-            else:
-                self.model_status_dict[request.request_id] = model_status
-            request.enable_thinking = model_status == "think_start"
-
-        if request.get("response_max_tokens") is not None and request.enable_thinking is False:
-            request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
-        data_processor_logger.info(f"Processed request: {request}")
-        return request
-
     def process_request_dict(self, request, max_model_len=None, **kwargs):
         """
         Preprocess the request
@@ -439,10 +304,17 @@ class DataProcessor(BaseDataProcessor):
         # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
             if request.get("prompt"):
-                add_special_tokens = request.get("add_special_tokens", False)
-                request["prompt_token_ids"] = self.text2ids(
-                    request["prompt"], max_model_len, add_special_tokens=add_special_tokens
-                ).tolist()
+                prompt = request["prompt"]
+                assert isinstance(prompt, str) or (
+                    isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
+                ), f"prompt must be a string or a list of integers, but got {type(prompt)}"
+                if isinstance(prompt, list):
+                    request["prompt_token_ids"] = prompt
+                else:
+                    add_special_tokens = request.get("add_special_tokens", False)
+                    request["prompt_token_ids"] = self.text2ids(
+                        prompt, max_model_len, add_special_tokens=add_special_tokens
+                    ).tolist()
             elif request.get("messages"):
                 if self.tokenizer.chat_template is None:
                     raise ValueError("This model does not support chat_template.")
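After this hunk, the dict pipeline accepts either a raw string or a pre-tokenized ID list under "prompt". A short usage sketch of the two accepted forms (the token values and the processor instance are hypothetical):

    # String prompts are tokenized via text2ids; int lists are taken as-is.
    processor.process_request_dict({"prompt": "Hello, world"}, max_model_len=4096)
    processor.process_request_dict({"prompt": [101, 2023, 2003]}, max_model_len=4096)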
@@ -504,39 +376,6 @@ class DataProcessor(BaseDataProcessor):
         full_text = self.tokenizer.decode(token_ids, **kwargs)
         return full_text
 
-    def process_response(self, response_dict, **kwargs):
-        """
-        Preprocess the response
-
-        Args:
-            response_dict (Dict): response for engine, contain ids fields
-
-        Returns:
-            Dict: response contain text fields
-        """
-        req_id = response_dict.request_id
-        token_ids = response_dict.outputs.token_ids
-        if token_ids[-1] == self.tokenizer.eos_token_id:
-            token_ids = token_ids[:-1]
-        full_text = self.tokenizer.decode(token_ids)
-        response_dict.outputs.text = full_text
-        if self.reasoning_parser:
-            reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
-                full_text, response_dict, self.model_status_dict[req_id]
-            )
-            response_dict.outputs.text = text
-            response_dict.outputs.reasoning_content = reasoning_content
-        if self.tool_parser_obj:
-            tool_parser = self.tool_parser_obj(self.tokenizer)
-            tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
-            if tool_call_info.tools_called:
-                response_dict.outputs.tool_calls = tool_call_info.tool_calls
-        if req_id in self.model_status_dict:
-            del self.model_status_dict[req_id]
-        data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
-
-        return response_dict
-
     def process_response_dict_normal(self, response_dict, **kwargs):
         """
         Preprocess the response
@@ -551,10 +390,15 @@ class DataProcessor(BaseDataProcessor):
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
         request = kwargs.get("request", None)
+        direct_decode = kwargs.get("direct_decode", False)
         if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] in self.eos_token_ids:
                 token_ids = token_ids[:-1]
-        delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
+        if direct_decode:
+            delta_text = self.tokenizer.decode(token_ids)
+            previous_texts = ""
+        else:
+            delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
         if is_end:
             full_text = previous_texts + delta_text
             response_dict["outputs"]["completion_tokens"] = full_text
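The new direct_decode flag (the "fix offline decode" item in the commit message) bypasses the stateful incremental ids2tokens path and decodes the accumulated IDs in one shot. Roughly, the contrast between the two paths, assuming a tokenizer with a standard decode method:

    # Streaming path: incremental decode keeps per-request state in decode_status.
    delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
    # Offline path: stateless one-shot decode of the full output.
    delta_text = self.tokenizer.decode(token_ids)
    previous_texts = ""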
@@ -574,8 +418,9 @@ class DataProcessor(BaseDataProcessor):
             tool_call_info = tool_parser.extract_tool_calls(full_text, request)
             if tool_call_info.tools_called:
                 response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
-            data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
-            del self.decode_status[req_id]
+            if req_id in self.decode_status:
+                data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
+                del self.decode_status[req_id]
             if req_id in self.model_status_dict:
                 del self.model_status_dict[req_id]
         return response_dict
@@ -594,7 +439,6 @@ class DataProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
         request = kwargs.get("request", None)
-        response_dict["outputs"]["enable_parser"] = False
 
         if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] in self.eos_token_ids:
@@ -606,7 +450,6 @@ class DataProcessor(BaseDataProcessor):
         response_dict["outputs"]["tool_calls"] = None
         response_dict["outputs"]["reasoning_content"] = ""
         if self.reasoning_parser:
-            response_dict["outputs"]["enable_parser"] = True
             reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
                 previous_texts,
                 previous_texts + delta_text,
@@ -626,7 +469,6 @@ class DataProcessor(BaseDataProcessor):
         if not is_end:
             response_dict["outputs"]["skipped"] = True
         if self.tool_parser_obj:
-            response_dict["outputs"]["enable_parser"] = True
             if req_id not in self.tool_parser_dict:
                 self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
             tool_parser = self.tool_parser_dict[req_id]
@@ -351,7 +351,6 @@ class Ernie4_5Processor(BaseDataProcessor):
             full_text = previous_texts + delta_text
             response_obj.outputs.text = full_text
             if self.reasoning_parser:
-                response_obj.outputs.enable_parser = True
                 reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
                     full_text,
                     request,
@@ -362,7 +361,6 @@ class Ernie4_5Processor(BaseDataProcessor):
                 reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
                 response_obj.outputs.reasoning_token_num = len(reasoning_tokens)
             if self.tool_parser_obj:
-                response_obj.outputs.enable_parser = True
                 tool_parser = self.tool_parser_obj(self.tokenizer)
                 tool_call_info = tool_parser.extract_tool_calls(full_text, request)
                 if tool_call_info.tools_called:
@@ -396,7 +394,6 @@ class Ernie4_5Processor(BaseDataProcessor):
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
         response_obj.outputs.completion_tokens = delta_text
         if self.reasoning_parser:
-            response_obj.outputs.enable_parser = True
             reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
                 previous_texts,
                 previous_texts + delta_text,
@@ -420,7 +417,6 @@ class Ernie4_5Processor(BaseDataProcessor):
         else:
             response_obj.outputs.text = delta_text
         if self.tool_parser_obj:
-            response_obj.outputs.enable_parser = True
             if req_id not in self.tool_parser_dict:
                 self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
             tool_parser = self.tool_parser_dict[req_id]
@@ -555,7 +555,6 @@ class DataProcessor(BaseDataProcessor):
             response_obj.outputs.completion_tokens = full_text
             response_obj.outputs.text = full_text
             if self.reasoning_parser:
-                response_obj.outputs.enable_parser = True
                 reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
                     full_text,
                     request,
@@ -566,7 +565,6 @@ class DataProcessor(BaseDataProcessor):
                 reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else []
                 response_obj.outputs.reasoning_token_num = len(reasoning_tokens)
             if self.tool_parser_obj:
-                response_obj.outputs.enable_parser = True
                 tool_parser = self.tool_parser_obj(self.tokenizer)
                 tool_call_info = tool_parser.extract_tool_calls(full_text, request)
                 if tool_call_info.tools_called:
@@ -600,7 +598,6 @@ class DataProcessor(BaseDataProcessor):
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
         response_obj.outputs.completion_tokens = delta_text
         if self.reasoning_parser:
-            response_obj.outputs.enable_parser = True
             reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
                 previous_texts,
                 previous_texts + delta_text,
@@ -615,7 +612,6 @@ class DataProcessor(BaseDataProcessor):
             reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else []
             response_obj.outputs.reasoning_token_num = len(reasoning_tokens)
         if self.tool_parser_obj:
-            response_obj.outputs.enable_parser = True
             if req_id not in self.tool_parser_dict:
                 self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
             tool_parser = self.tool_parser_dict[req_id]
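The guarded deletes introduced in the hunks above (the "Potential fix for pull request finding" item, suggested by Copilot Autofix) make per-request cleanup idempotent: decode state is logged and dropped only if the key is still present, so a request that finishes twice, or was never registered, no longer raises KeyError. The pattern in isolation, as it appears in the diff:

    # Guarded cleanup: tolerate missing keys instead of raising KeyError.
    if req_id in self.decode_status:
        data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
        del self.decode_status[req_id]
    if req_id in self.model_status_dict:
        del self.model_status_dict[req_id]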