mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature]Optimization of Thinking Pattern Framework (#4302)
* add model status in vl
* add x1 parser
* add model_status
* fix parser
* fix parser
* fix parser
* fix parser
* Revert "fix parser"
This reverts commit 300f446d8a.
* fix parser
* fix
* fix
* fix
* fix
* fix parser
* fix unit test
* fix unit test
* add unit test
* fix
* fix
* add unit test
* fix unit test
* add unit test
* add unit test
* fix unit test
* fix unit test
* fix bug
* fix unit test
* x1 tool parser
* fix unit test
* fix unit test
* fix unit test
* fix n
* fix unit test
* add unit test
* add unit test
* remove pring
This commit is contained in:
@@ -72,13 +72,12 @@ class ChatResponseProcessor:
|
||||
else:
|
||||
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
|
||||
|
||||
async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output):
|
||||
async def process_response_chat(self, request_outputs, stream, include_stop_str_in_output):
|
||||
"""
|
||||
Process a list of responses into a generator that yields each processed response as it's generated.
|
||||
Args:
|
||||
request_outputs: The list of outputs to be processed.
|
||||
stream: Whether or not to stream the output.
|
||||
enable_thinking: Whether or not to show thinking messages.
|
||||
include_stop_str_in_output: Whether or not to include stop strings in the output.
|
||||
"""
|
||||
for request_output in request_outputs:
|
||||
@@ -99,7 +98,6 @@ class ChatResponseProcessor:
|
||||
response = await self.data_processor.process_response_dict(
|
||||
response_dict=request_output,
|
||||
stream=stream,
|
||||
enable_thinking=enable_thinking,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
audio_tokens=all_audio_tokens,
|
||||
tts=tts,
|
||||
@@ -108,7 +106,6 @@ class ChatResponseProcessor:
|
||||
response = self.data_processor.process_response_dict(
|
||||
response_dict=request_output,
|
||||
stream=stream,
|
||||
enable_thinking=enable_thinking,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
audio_tokens=all_audio_tokens,
|
||||
tts=tts,
|
||||
@@ -127,7 +124,6 @@ class ChatResponseProcessor:
|
||||
yield self.data_processor.process_response_dict(
|
||||
response_dict=request_output,
|
||||
stream=stream,
|
||||
enable_thinking=enable_thinking,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
)
|
||||
elif stream:
|
||||
@@ -156,14 +152,12 @@ class ChatResponseProcessor:
|
||||
await self.data_processor.process_response_dict(
|
||||
response_dict=request_output,
|
||||
stream=stream,
|
||||
enable_thinking=enable_thinking,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
)
|
||||
else:
|
||||
self.data_processor.process_response_dict(
|
||||
response_dict=request_output,
|
||||
stream=stream,
|
||||
enable_thinking=enable_thinking,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
)
|
||||
text = {"type": "text", "text": request_output["outputs"]["text"]}
|
||||
@@ -185,14 +179,12 @@ class ChatResponseProcessor:
|
||||
await self.data_processor.process_response_dict(
|
||||
response_dict=part["request_output"],
|
||||
stream=False,
|
||||
enable_thinking=enable_thinking,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
)
|
||||
else:
|
||||
self.data_processor.process_response_dict(
|
||||
response_dict=request_output,
|
||||
stream=stream,
|
||||
enable_thinking=enable_thinking,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
)
|
||||
text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
|
||||
|
||||
@@ -220,8 +220,6 @@ class OpenAIServingChat:
|
||||
|
||||
max_streaming_response_tokens = max(1, max_streaming_response_tokens)
|
||||
|
||||
enable_thinking = self._get_thinking_status(request)
|
||||
|
||||
include_stop_str_in_output = request.include_stop_str_in_output
|
||||
|
||||
stream_options = request.stream_options
|
||||
@@ -277,7 +275,6 @@ class OpenAIServingChat:
|
||||
generator = response_processor.process_response_chat(
|
||||
response,
|
||||
stream=True,
|
||||
enable_thinking=enable_thinking,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
)
|
||||
|
||||
@@ -507,7 +504,6 @@ class OpenAIServingChat:
|
||||
"""
|
||||
created_time = int(time.time())
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
enable_thinking = self._get_thinking_status(request)
|
||||
|
||||
include_stop_str_in_output = request.include_stop_str_in_output
|
||||
try:
|
||||
@@ -562,7 +558,6 @@ class OpenAIServingChat:
|
||||
generator = response_processor.process_response_chat(
|
||||
response,
|
||||
stream=False,
|
||||
enable_thinking=enable_thinking,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
)
|
||||
async for data in generator:
|
||||
@@ -824,23 +819,6 @@ class OpenAIServingChat:
|
||||
api_server_logger.error(error_msg)
|
||||
return None
|
||||
|
||||
def _get_thinking_status(self, request: ChatCompletionRequest) -> bool:
|
||||
"""
|
||||
Get the thinking status from the request.
|
||||
"""
|
||||
enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
|
||||
if enable_thinking is None:
|
||||
enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
|
||||
options = request.chat_template_kwargs.get("options") if request.chat_template_kwargs else None
|
||||
if options:
|
||||
thinking_mode = options.get("thinking_mode")
|
||||
if thinking_mode:
|
||||
if thinking_mode == "close" or thinking_mode == "false":
|
||||
enable_thinking = False
|
||||
else:
|
||||
enable_thinking = True
|
||||
return enable_thinking
|
||||
|
||||
def _build_prompt_logprobs(
|
||||
self,
|
||||
prompt_logprobs_tensors: LogprobsTensors,
|
||||
|
||||
@@ -61,6 +61,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
self.decode_status = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.thinking_parser_dict = dict()
|
||||
self.model_status_dict = dict()
|
||||
self._load_tokenizer()
|
||||
data_processor_logger.info(
|
||||
f"tokenizer information: bos_token is {self.tokenizer.bos_token} \
|
||||
@@ -148,8 +149,18 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
request.set("temperature", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request.enable_thinking = True
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||
parts = request.request_id.split("_")
|
||||
if len(parts) > 1:
|
||||
real_req_id = parts[0]
|
||||
index = int(parts[1])
|
||||
n = request.get("n", 1)
|
||||
for idx in range(index * n, (index + 1) * n):
|
||||
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||
else:
|
||||
self.model_status_dict[request.request_id] = model_status
|
||||
request.enable_thinking = model_status == "think_start"
|
||||
|
||||
data_processor_logger.info(f"Processed request: {request}")
|
||||
return request
|
||||
@@ -222,9 +233,19 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
request["temperature"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request["enable_thinking"] = True
|
||||
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||
parts = request["request_id"].split("_")
|
||||
if len(parts) > 1:
|
||||
real_req_id = parts[0]
|
||||
index = int(parts[1])
|
||||
n = request.get("n", 1)
|
||||
for idx in range(index * n, (index + 1) * n):
|
||||
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||
else:
|
||||
self.model_status_dict[request["request_id"]] = model_status
|
||||
request["enable_thinking"] = model_status == "think_start"
|
||||
data_processor_logger.info(f"Processed request dict: {request}")
|
||||
return request
|
||||
|
||||
@@ -246,7 +267,11 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
token_ids = token_ids[:-1]
|
||||
full_text = self.tokenizer.decode(token_ids)
|
||||
if self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||
full_text,
|
||||
response_dict,
|
||||
self.model_status_dict[req_id],
|
||||
)
|
||||
response_dict.outputs.text = text
|
||||
response_dict.outputs.reasoning_content = reasoning_content
|
||||
else:
|
||||
@@ -257,6 +282,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
if tool_call_info.tools_called:
|
||||
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
||||
response_dict.outputs.text = tool_call_info.content
|
||||
if req_id in self.model_status_dict:
|
||||
del self.model_status_dict[req_id]
|
||||
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
|
||||
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
|
||||
return None
|
||||
@@ -287,7 +314,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
Returns:
|
||||
Dict: response contain text fields
|
||||
"""
|
||||
enable_thinking = kwargs.get("enable_thinking")
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
is_end = response_dict["finished"]
|
||||
req_id = response_dict["request_id"]
|
||||
@@ -297,16 +323,17 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
if is_end:
|
||||
full_text = previous_texts + delta_text
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
||||
response_dict["outputs"]["text"] = full_text
|
||||
if self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||
full_text,
|
||||
response_dict,
|
||||
self.model_status_dict[req_id],
|
||||
)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
|
||||
response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
|
||||
else:
|
||||
response_dict["outputs"]["text"] = full_text
|
||||
if self.tool_parser_obj:
|
||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||
@@ -316,6 +343,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
response_dict["outputs"]["completion_tokens"] = full_text
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.model_status_dict:
|
||||
del self.model_status_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def process_response_dict_streaming(self, response_dict, **kwargs):
|
||||
@@ -328,7 +357,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
Returns:
|
||||
Dict: response contain text fields
|
||||
"""
|
||||
enable_thinking = kwargs.get("enable_thinking")
|
||||
is_end = response_dict["finished"]
|
||||
req_id = response_dict["request_id"]
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
@@ -338,9 +366,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
response_dict["outputs"]["completion_tokens"] = delta_text
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
if self.reasoning_parser:
|
||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
@@ -348,6 +374,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
previous_token_ids,
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
self.model_status_dict[req_id],
|
||||
)
|
||||
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
||||
reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None
|
||||
@@ -374,6 +401,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.tool_parser_dict:
|
||||
del self.tool_parser_dict[req_id]
|
||||
if req_id in self.model_status_dict:
|
||||
del self.model_status_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def messages2ids(self, request_or_messages, **kwargs):
|
||||
|
||||
@@ -58,6 +58,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
|
||||
self.tool_parser_dict = dict()
|
||||
self.decode_status = dict()
|
||||
self.model_status_dict = dict()
|
||||
self._load_tokenizer()
|
||||
|
||||
# Generation config
|
||||
@@ -242,15 +243,6 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
request[k] = v
|
||||
else:
|
||||
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
|
||||
options = chat_template_kwargs.get("options")
|
||||
if options:
|
||||
thinking_mode = options.get("thinking_mode")
|
||||
if thinking_mode:
|
||||
if thinking_mode == "close" or thinking_mode == "false":
|
||||
request["enable_thinking"] = False
|
||||
else:
|
||||
request["enable_thinking"] = True
|
||||
request.setdefault("enable_thinking", True)
|
||||
outputs = self.ernie4_5_processor.request2ids(request)
|
||||
else:
|
||||
raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
|
||||
@@ -279,6 +271,18 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||
parts = request["request_id"].split("_")
|
||||
if len(parts) > 1:
|
||||
real_req_id = parts[0]
|
||||
index = int(parts[1])
|
||||
n = request.get("n", 1)
|
||||
for idx in range(index * n, (index + 1) * n):
|
||||
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||
else:
|
||||
self.model_status_dict[request["request_id"]] = model_status
|
||||
request["enable_thinking"] = model_status == "think_start"
|
||||
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
|
||||
@@ -314,21 +318,3 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
outs["position_ids"] = np.array(outs["position_ids"], dtype=np.int64)
|
||||
|
||||
return outs
|
||||
|
||||
def process_response_dict(self, response_dict, stream, **kwargs):
|
||||
"""
|
||||
Preprocess the response
|
||||
|
||||
Args:
|
||||
response_dict (Dict): response for engine, contain ids fields
|
||||
|
||||
Returns:
|
||||
Dict: response contain text fields
|
||||
"""
|
||||
enable_thinking = kwargs.pop("enable_thinking", True)
|
||||
if enable_thinking is None:
|
||||
enable_thinking = True
|
||||
if stream:
|
||||
return self.process_response_dict_streaming(response_dict, enable_thinking=enable_thinking, **kwargs)
|
||||
else:
|
||||
return self.process_response_dict_normal(response_dict, enable_thinking=enable_thinking, **kwargs)
|
||||
|
||||
@@ -254,6 +254,19 @@ class PaddleOCRVLProcessor(TextProcessor):
|
||||
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||
parts = request["request_id"].split("_")
|
||||
if len(parts) > 1:
|
||||
real_req_id = parts[0]
|
||||
index = int(parts[1])
|
||||
n = request.get("n", 1)
|
||||
for idx in range(index * n, (index + 1) * n):
|
||||
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||
else:
|
||||
self.model_status_dict[request["request_id"]] = model_status
|
||||
request["enable_thinking"] = model_status == "think_start"
|
||||
|
||||
return request
|
||||
|
||||
def append_generated_tokens(self, multimodal_inputs, generated_token_ids):
|
||||
|
||||
@@ -267,6 +267,18 @@ class QwenVLProcessor(TextProcessor):
|
||||
# Set default max_tokens if not specified
|
||||
if request.get("max_tokens") is None:
|
||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||
parts = request["request_id"].split("_")
|
||||
if len(parts) > 1:
|
||||
real_req_id = parts[0]
|
||||
index = int(parts[1])
|
||||
n = request.get("n", 1)
|
||||
for idx in range(index * n, (index + 1) * n):
|
||||
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||
else:
|
||||
self.model_status_dict[request["request_id"]] = model_status
|
||||
request["enable_thinking"] = model_status == "think_start"
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
return request
|
||||
|
||||
@@ -176,6 +176,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
self.generation_config = None
|
||||
|
||||
self.decode_status = dict()
|
||||
self.model_status_dict = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
data_processor_logger.info(
|
||||
@@ -267,6 +268,18 @@ class DataProcessor(BaseDataProcessor):
|
||||
request.set("temperature", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||
parts = request.request_id.split("_")
|
||||
if len(parts) > 1:
|
||||
real_req_id = parts[0]
|
||||
index = int(parts[1])
|
||||
n = request.get("n", 1)
|
||||
for idx in range(index * n, (index + 1) * n):
|
||||
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||
else:
|
||||
self.model_status_dict[request.request_id] = model_status
|
||||
request.enable_thinking = model_status == "think_start"
|
||||
|
||||
data_processor_logger.info(f"Processed request: {request}")
|
||||
return request
|
||||
@@ -340,6 +353,18 @@ class DataProcessor(BaseDataProcessor):
|
||||
request["temperature"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
if self.reasoning_parser:
|
||||
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||
parts = request["request_id"].split("_")
|
||||
if len(parts) > 1:
|
||||
real_req_id = parts[0]
|
||||
index = int(parts[1])
|
||||
n = request.get("n", 1)
|
||||
for idx in range(index * n, (index + 1) * n):
|
||||
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||
else:
|
||||
self.model_status_dict[request["request_id"]] = model_status
|
||||
request["enable_thinking"] = model_status == "think_start"
|
||||
|
||||
data_processor_logger.info(f"Processed request dict: {request}")
|
||||
return request
|
||||
@@ -363,21 +388,21 @@ class DataProcessor(BaseDataProcessor):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
token_ids = token_ids[:-1]
|
||||
full_text = self.tokenizer.decode(token_ids)
|
||||
|
||||
# 模型支持思考,并且支持思考
|
||||
response_dict.outputs.text = full_text
|
||||
if self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||
full_text, response_dict, self.model_status_dict[req_id]
|
||||
)
|
||||
response_dict.outputs.text = text
|
||||
response_dict.outputs.reasoning_content = reasoning_content
|
||||
else:
|
||||
# 模型不支持思考,并且没单独设置enable_thinking为false
|
||||
response_dict.outputs.text = full_text
|
||||
if self.tool_parser_obj:
|
||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||
if tool_call_info.tools_called:
|
||||
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
||||
response_dict.outputs.text = tool_call_info.content
|
||||
if req_id in self.model_status_dict:
|
||||
del self.model_status_dict[req_id]
|
||||
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
|
||||
|
||||
return response_dict
|
||||
@@ -392,7 +417,6 @@ class DataProcessor(BaseDataProcessor):
|
||||
Returns:
|
||||
Dict: response contain text fields
|
||||
"""
|
||||
enable_thinking = kwargs.get("enable_thinking")
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
is_end = response_dict["finished"]
|
||||
req_id = response_dict["request_id"]
|
||||
@@ -403,14 +427,17 @@ class DataProcessor(BaseDataProcessor):
|
||||
if is_end:
|
||||
full_text = previous_texts + delta_text
|
||||
response_dict["outputs"]["completion_tokens"] = full_text
|
||||
if enable_thinking and self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
||||
response_dict["outputs"]["text"] = full_text
|
||||
if self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||
full_text,
|
||||
response_dict,
|
||||
self.model_status_dict[req_id],
|
||||
)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
|
||||
response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
|
||||
else:
|
||||
response_dict["outputs"]["text"] = full_text
|
||||
if self.tool_parser_obj:
|
||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||
@@ -419,6 +446,8 @@ class DataProcessor(BaseDataProcessor):
|
||||
response_dict["outputs"]["text"] = tool_call_info.content
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.model_status_dict:
|
||||
del self.model_status_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def process_response_dict_streaming(self, response_dict, **kwargs):
|
||||
@@ -431,7 +460,6 @@ class DataProcessor(BaseDataProcessor):
|
||||
Returns:
|
||||
Dict: response contain text fields
|
||||
"""
|
||||
enable_thinking = kwargs.get("enable_thinking")
|
||||
is_end = response_dict["finished"]
|
||||
req_id = response_dict["request_id"]
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
@@ -441,9 +469,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
response_dict["outputs"]["completion_tokens"] = delta_text
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
if self.reasoning_parser:
|
||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
@@ -451,6 +477,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
previous_token_ids,
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
self.model_status_dict[req_id],
|
||||
)
|
||||
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
||||
reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None
|
||||
@@ -477,6 +504,8 @@ class DataProcessor(BaseDataProcessor):
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.tool_parser_dict:
|
||||
del self.tool_parser_dict[req_id]
|
||||
if req_id in self.model_status_dict:
|
||||
del self.model_status_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def process_response_dict(self, response_dict, **kwargs):
|
||||
@@ -489,16 +518,12 @@ class DataProcessor(BaseDataProcessor):
|
||||
Returns:
|
||||
Dict: response contain text fields
|
||||
"""
|
||||
enable_thinking = kwargs.pop("enable_thinking", True)
|
||||
if enable_thinking is None:
|
||||
enable_thinking = True
|
||||
stream = kwargs.get("stream", True)
|
||||
if stream:
|
||||
return self.process_response_dict_streaming(response_dict, enable_thinking=enable_thinking, **kwargs)
|
||||
return self.process_response_dict_streaming(response_dict, **kwargs)
|
||||
else:
|
||||
return self.process_response_dict_normal(
|
||||
response_dict=response_dict,
|
||||
enable_thinking=enable_thinking,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,25 +35,53 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.think_end_token = "</think>"
|
||||
self.tool_begin_token = "<tool_call>"
|
||||
token_definitions = {
|
||||
"think_start_token": "<think>",
|
||||
"think_end_token": "</think>",
|
||||
"tool_call_start_token": "<tool_call>",
|
||||
"tool_call_end_token": "</tool_call>",
|
||||
}
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
|
||||
)
|
||||
missing_tokens = []
|
||||
for name, token_value in token_definitions.items():
|
||||
setattr(self, name, token_value)
|
||||
token_id = self.vocab.get(token_value)
|
||||
setattr(self, f"{name}_id", token_id)
|
||||
if token_id is None:
|
||||
missing_tokens.append(f"{name.replace('_', ' ')} token")
|
||||
|
||||
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||
self.tool_begin_token_id = self.vocab.get(self.tool_begin_token)
|
||||
if self.tool_begin_token_id is None:
|
||||
self.tool_begin_token_id = -1
|
||||
|
||||
if self.think_end_token_id is None:
|
||||
raise RuntimeError("Test reasoning parser could not locate think end tokens in the tokenizer!")
|
||||
if missing_tokens:
|
||||
raise RuntimeError(
|
||||
f"ernie vl reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
|
||||
)
|
||||
self.token_status_mapping = {
|
||||
self.think_start_token_id: "think_start",
|
||||
self.think_end_token_id: "think_end",
|
||||
self.tool_call_start_token_id: "tool_call_start",
|
||||
self.tool_call_end_token_id: "tool_call_end",
|
||||
}
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.think_end_token_id in input_ids
|
||||
|
||||
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
|
||||
for i in range(len(prompt_token_ids) - 1, -1, -1):
|
||||
if prompt_token_ids[i] in self.token_status_mapping:
|
||||
return prompt_token_ids[i]
|
||||
return -1
|
||||
|
||||
def get_model_status(self, prompt_token_ids: list[int]):
|
||||
special_token_id = self.find_last_special_token(prompt_token_ids)
|
||||
|
||||
if special_token_id == -1:
|
||||
return "think_start"
|
||||
|
||||
return self.token_status_mapping[special_token_id]
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
@@ -62,6 +90,7 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
model_status: str,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Extract reasoning content from a delta message.
|
||||
@@ -71,6 +100,7 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
||||
- 'abc' goes to reasoning_content
|
||||
- 'xyz' goes to content
|
||||
"""
|
||||
if model_status == "think_start":
|
||||
if self.think_end_token not in current_text:
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
# Skip single special tokens
|
||||
@@ -98,9 +128,18 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
||||
if len(striped_suffix) == 0:
|
||||
return None
|
||||
return DeltaMessage(content=delta_text)
|
||||
elif model_status == "think_end":
|
||||
if current_text.lstrip("\n").startswith(self.tool_call_start_token):
|
||||
return None
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
return None
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
model_status: str,
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Extract reasoning content from the model output.
|
||||
@@ -114,23 +153,30 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
||||
"""
|
||||
|
||||
# Check if the model output contains the </think> tokens.
|
||||
if model_status == "think_start":
|
||||
if self.think_end_token not in model_output:
|
||||
return model_output, ""
|
||||
reasoning_content, _, content = model_output.partition(self.think_end_token)
|
||||
if self.tool_begin_token in content:
|
||||
prefix, _, _ = content.partition(self.tool_begin_token)
|
||||
if self.tool_call_start_token in content:
|
||||
prefix, _, _ = content.partition(self.tool_call_start_token)
|
||||
prefix_strip = prefix.lstrip("\n")
|
||||
if len(prefix_strip) > 0:
|
||||
return reasoning_content, content
|
||||
return reasoning_content, ""
|
||||
return reasoning_content, content
|
||||
elif model_status == "think_end":
|
||||
if model_output.lstrip("\n").startswith(self.tool_call_start_token):
|
||||
return "", ""
|
||||
return "", model_output
|
||||
else:
|
||||
return "", ""
|
||||
|
||||
def _is_with_tool(self, current_text: str, current_token_ids: Sequence[int]) -> bool:
|
||||
think_end_index = current_text.find(self.think_end_token)
|
||||
think_end = think_end_index + len(self.think_end_token)
|
||||
middle_str = current_text[think_end:]
|
||||
if self.tool_begin_token_id in current_token_ids:
|
||||
prefix, _, _ = middle_str.partition(self.tool_begin_token)
|
||||
if self.tool_call_start_token_id in current_token_ids:
|
||||
prefix, _, _ = middle_str.partition(self.tool_call_start_token)
|
||||
striped_prefix = prefix.strip("\n")
|
||||
if len(striped_prefix) > 0:
|
||||
return False
|
||||
|
||||
@@ -35,20 +35,48 @@ class ErnieVLReasoningParser(ReasoningParser):
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.think_end_token = "</think>"
|
||||
token_definitions = {
|
||||
"think_start_token": "<think>",
|
||||
"think_end_token": "</think>",
|
||||
}
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
|
||||
)
|
||||
raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")
|
||||
|
||||
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||
if self.think_end_token_id is None:
|
||||
raise RuntimeError("Ernie VL reasoning parser could not locate think end " "tokens in the tokenizer!")
|
||||
missing_tokens = []
|
||||
for name, token_value in token_definitions.items():
|
||||
setattr(self, name, token_value)
|
||||
token_id = self.vocab.get(token_value)
|
||||
setattr(self, f"{name}_id", token_id)
|
||||
if token_id is None:
|
||||
missing_tokens.append(f"{name.replace('_', ' ')} token")
|
||||
|
||||
if missing_tokens:
|
||||
raise RuntimeError(
|
||||
f"ernie vl reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
|
||||
)
|
||||
self.token_status_mapping = {
|
||||
self.think_start_token_id: "think_start",
|
||||
self.think_end_token_id: "think_end",
|
||||
}
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.think_end_token_id in input_ids
|
||||
|
||||
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
|
||||
for i in range(len(prompt_token_ids) - 1, -1, -1):
|
||||
if prompt_token_ids[i] in self.token_status_mapping:
|
||||
return prompt_token_ids[i]
|
||||
return -1
|
||||
|
||||
def get_model_status(self, prompt_token_ids: list[int]):
|
||||
special_token_id = self.find_last_special_token(prompt_token_ids)
|
||||
|
||||
if special_token_id == -1:
|
||||
return "think_start"
|
||||
|
||||
return self.token_status_mapping[special_token_id]
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
@@ -57,6 +85,7 @@ class ErnieVLReasoningParser(ReasoningParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
model_status: str,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Extract reasoning content from a delta message.
|
||||
@@ -69,18 +98,23 @@ class ErnieVLReasoningParser(ReasoningParser):
|
||||
# Skip single special tokens
|
||||
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
||||
return None
|
||||
if model_status == "think_start":
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
end_index = delta_text.find(self.think_end_token)
|
||||
reasoning_content = delta_text[:end_index]
|
||||
content = delta_text[end_index + len(self.think_end_token) :]
|
||||
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||
elif self.think_end_token_id in previous_token_ids:
|
||||
if self.think_end_token_id in previous_token_ids:
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
else:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
model_status: str,
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Extract reasoning content from the model output.
|
||||
@@ -92,11 +126,12 @@ class ErnieVLReasoningParser(ReasoningParser):
|
||||
Returns:
|
||||
tuple[Optional[str], Optional[str]]: reasoning content and content
|
||||
"""
|
||||
|
||||
# Check if the model output contains the </think> tokens.
|
||||
if model_status == "think_start":
|
||||
if self.think_end_token not in model_output:
|
||||
return "", model_output
|
||||
reasoning_content, _, content = model_output.partition(self.think_end_token)
|
||||
|
||||
final_content = content or ""
|
||||
return reasoning_content, final_content
|
||||
else:
|
||||
return "", model_output
|
||||
|
||||
@@ -18,19 +18,55 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.think_end_token = "</think>"
|
||||
self.response_start_token = "<response>"
|
||||
self.response_end_token = "</response>"
|
||||
self.tool_call_start_token = "<tool_call>"
|
||||
self.tool_call_end_token = "</tool_call>"
|
||||
|
||||
# 定义所有需要检查的token
|
||||
token_definitions = {
|
||||
"think_start_token": "<think>",
|
||||
"think_end_token": "</think>",
|
||||
"response_start_token": "<response>",
|
||||
"response_end_token": "</response>",
|
||||
"tool_call_start_token": "<tool_call>",
|
||||
"tool_call_end_token": "</tool_call>",
|
||||
}
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")
|
||||
|
||||
self.think_end_token_id = self.vocab.get("</think>")
|
||||
if self.think_end_token_id is None:
|
||||
raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
|
||||
self.tool_call_start_token_id = self.vocab.get("<tool_call>")
|
||||
missing_tokens = []
|
||||
for name, token_value in token_definitions.items():
|
||||
setattr(self, name, token_value)
|
||||
token_id = self.vocab.get(token_value)
|
||||
setattr(self, f"{name}_id", token_id)
|
||||
if token_id is None:
|
||||
missing_tokens.append(token_value)
|
||||
|
||||
if missing_tokens:
|
||||
raise RuntimeError(
|
||||
f"ernie x1 reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
|
||||
)
|
||||
|
||||
self.token_status_mapping = {
|
||||
self.think_start_token_id: "think_start",
|
||||
self.think_end_token_id: "think_end",
|
||||
self.response_start_token_id: "response_start",
|
||||
self.response_end_token_id: "response_end",
|
||||
self.tool_call_start_token_id: "tool_call_start",
|
||||
self.tool_call_end_token_id: "tool_call_end",
|
||||
}
|
||||
|
||||
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
|
||||
for i in range(len(prompt_token_ids) - 1, -1, -1):
|
||||
if prompt_token_ids[i] in self.token_status_mapping:
|
||||
return prompt_token_ids[i]
|
||||
return -1
|
||||
|
||||
def get_model_status(self, prompt_token_ids: list[int]):
|
||||
special_token_id = self.find_last_special_token(prompt_token_ids)
|
||||
|
||||
if special_token_id == -1:
|
||||
return "think_start"
|
||||
|
||||
return self.token_status_mapping[special_token_id]
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
@@ -40,64 +76,82 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
model_status: str,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
# Ignore the single </think> token
|
||||
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
||||
|
||||
if len(delta_token_ids) == 1 and delta_token_ids[0] in [
|
||||
self.think_end_token_id,
|
||||
self.response_start_token_id,
|
||||
self.response_end_token_id,
|
||||
self.tool_call_start_token_id,
|
||||
self.tool_call_end_token_id,
|
||||
]:
|
||||
return None
|
||||
|
||||
# --- Thinking stage handling ---
|
||||
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
|
||||
# If delta is </think>, stop thinking, do not return
|
||||
if delta_text.startswith(self.think_end_token):
|
||||
return None
|
||||
# Otherwise, return thinking content (keep \n as-is)
|
||||
if model_status == "think_start":
|
||||
if self.think_end_token in delta_text:
|
||||
response_content = ""
|
||||
end_index = delta_text.find(self.think_end_token)
|
||||
reasoning_content = delta_text[:end_index]
|
||||
response_start_pos = delta_text.find(self.response_start_token)
|
||||
if response_start_pos != -1:
|
||||
response_content = self._extract_response_content(
|
||||
delta_text[response_start_pos + len(self.response_start_token) :]
|
||||
)
|
||||
return DeltaMessage(reasoning_content=reasoning_content, content=response_content)
|
||||
elif self.think_end_token in previous_text:
|
||||
if self.response_start_token in previous_text and self.response_end_token not in previous_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
# --- After thinking ends, check tool_call or response ---
|
||||
remaining_text = previous_text + delta_text
|
||||
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
|
||||
after_think = after_think.lstrip("\n")
|
||||
|
||||
# Handle tool_call case: skip it
|
||||
if after_think.startswith(self.tool_call_start_token):
|
||||
return None
|
||||
|
||||
# Handle response case
|
||||
if after_think.startswith(self.response_start_token) and self.response_end_token not in after_think:
|
||||
# Do not return when <response> tag itself appears
|
||||
if delta_text == self.response_start_token or delta_text == self.response_end_token:
|
||||
return None
|
||||
elif model_status == "think_end":
|
||||
if self.response_start_token in previous_text and self.response_end_token not in previous_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
elif model_status == "response_start":
|
||||
if self.response_end_token not in previous_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Default case: return nothing
|
||||
return None
|
||||
|
||||
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest, model_status: str
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
优化版解析器。保留推理和响应内容中的换行符,
|
||||
仅删除闭合标签前的单个换行符。
|
||||
"""
|
||||
reasoning_content = ""
|
||||
response_content = ""
|
||||
|
||||
if model_status in ["think_start", "think_end"]:
|
||||
if model_status == "think_start":
|
||||
think_end_pos = model_output.find(self.think_end_token)
|
||||
if think_end_pos != -1:
|
||||
reasoning_content = model_output[:think_end_pos]
|
||||
|
||||
remaining = model_output[think_end_pos + len(self.think_end_token) :]
|
||||
|
||||
# find <response> or <tool>
|
||||
response_pos = remaining.find(self.response_start_token)
|
||||
tool_pos = remaining.find(self.tool_call_start_token)
|
||||
|
||||
# <response> first
|
||||
if response_pos != -1 and (tool_pos == -1 or response_pos < tool_pos):
|
||||
# The content after the response_start position
|
||||
remaining_response = remaining[response_pos + len(self.response_start_token) :]
|
||||
response_end_pos = remaining_response.find(self.response_end_token)
|
||||
if response_end_pos != -1:
|
||||
response_content = remaining_response[:response_end_pos]
|
||||
else:
|
||||
response_content = remaining_response
|
||||
# The content after the response_start position is tool_call
|
||||
remaining = model_output[think_end_pos + len(self.think_end_token) :].lstrip("\n")
|
||||
else:
|
||||
reasoning_content = model_output
|
||||
response_content = ""
|
||||
remaining = ""
|
||||
else:
|
||||
remaining = model_output.lstrip("\n")
|
||||
|
||||
response_start_pos = remaining.find(self.response_start_token)
|
||||
if response_start_pos != -1:
|
||||
response_content = self._extract_response_content(
|
||||
remaining[response_start_pos + len(self.response_start_token) :]
|
||||
)
|
||||
|
||||
elif model_status == "response_start":
|
||||
response_content = self._extract_response_content(model_output)
|
||||
|
||||
return reasoning_content, response_content
|
||||
|
||||
def _extract_response_content(self, remaining: str) -> str:
|
||||
"""
|
||||
Extracts response content, ensuring that the last newline before
|
||||
the </response> tag is removed.
|
||||
"""
|
||||
response_end_pos = remaining.find(self.response_end_token)
|
||||
if response_end_pos != -1:
|
||||
return remaining[:response_end_pos]
|
||||
return remaining
|
||||
|
||||
@@ -35,22 +35,50 @@ class Qwen3ReasoningParser(ReasoningParser):
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.think_start_token = "<think>"
|
||||
self.think_end_token = "</think>"
|
||||
|
||||
# 定义所有需要检查的token
|
||||
token_definitions = {
|
||||
"think_start_token": "<think>",
|
||||
"think_end_token": "</think>",
|
||||
}
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
|
||||
)
|
||||
raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")
|
||||
|
||||
self.think_start_token_id = self.vocab.get(self.think_start_token)
|
||||
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||
if self.think_end_token_id is None:
|
||||
raise RuntimeError("Qwen3 reasoning parser could not locate think end " "tokens in the tokenizer!")
|
||||
missing_tokens = []
|
||||
for name, token_value in token_definitions.items():
|
||||
setattr(self, name, token_value)
|
||||
token_id = self.vocab.get(token_value)
|
||||
setattr(self, f"{name}_id", token_id)
|
||||
if token_id is None:
|
||||
missing_tokens.append(token_value)
|
||||
|
||||
if missing_tokens:
|
||||
raise RuntimeError(
|
||||
f"Qwen3 reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
|
||||
)
|
||||
self.token_status_mapping = {
|
||||
self.think_start_token_id: "think_start",
|
||||
self.think_end_token_id: "think_end",
|
||||
}
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.think_end_token_id in input_ids
|
||||
|
||||
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
|
||||
for i in range(len(prompt_token_ids) - 1, -1, -1):
|
||||
if prompt_token_ids[i] in self.token_status_mapping:
|
||||
return prompt_token_ids[i]
|
||||
return -1
|
||||
|
||||
def get_model_status(self, prompt_token_ids: list[int]):
|
||||
special_token_id = self.find_last_special_token(prompt_token_ids)
|
||||
|
||||
if special_token_id == -1:
|
||||
return "think_start"
|
||||
|
||||
return self.token_status_mapping[special_token_id]
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
@@ -59,6 +87,7 @@ class Qwen3ReasoningParser(ReasoningParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
model_status: str,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Extract reasoning content from a delta message.
|
||||
@@ -71,6 +100,7 @@ class Qwen3ReasoningParser(ReasoningParser):
|
||||
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]):
|
||||
return None
|
||||
|
||||
if model_status == "think_start":
|
||||
# </think> in delta
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
# <think> in delta, </think> in delta, extract reasoning content
|
||||
@@ -101,9 +131,11 @@ class Qwen3ReasoningParser(ReasoningParser):
|
||||
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||
else:
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
else:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
self, model_output: str, request: ChatCompletionRequest, model_status: str
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Extract reasoning content from the model output.
|
||||
@@ -116,6 +148,7 @@ class Qwen3ReasoningParser(ReasoningParser):
|
||||
tuple[Optional[str], Optional[str]]: reasoning content and content
|
||||
"""
|
||||
|
||||
if model_status == "think_start":
|
||||
# 检查是否包含结束标签
|
||||
if self.think_end_token not in model_output:
|
||||
return None, model_output
|
||||
@@ -149,3 +182,5 @@ class Qwen3ReasoningParser(ReasoningParser):
|
||||
return reasoning_content, final_content
|
||||
|
||||
return None, model_output
|
||||
else:
|
||||
return None, model_output
|
||||
|
||||
@@ -482,7 +482,7 @@ def test_chat_with_thinking(openai_client, capsys):
|
||||
max_tokens=10,
|
||||
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
|
||||
)
|
||||
assert response.choices[0].message.reasoning_content is None
|
||||
assert response.choices[0].message.reasoning_content == ""
|
||||
assert "</think>" not in response.choices[0].message.content
|
||||
|
||||
# test logic
|
||||
@@ -957,4 +957,4 @@ def test_thinking_logic_flag(openai_client, capsys):
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
)
|
||||
assert response_case_3.choices[0].message.reasoning_content is None
|
||||
assert response_case_3.choices[0].message.reasoning_content == ""
|
||||
|
||||
@@ -545,7 +545,7 @@ def test_chat_with_thinking(openai_client, capsys):
|
||||
max_tokens=10,
|
||||
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
|
||||
)
|
||||
assert response.choices[0].message.reasoning_content is None
|
||||
assert response.choices[0].message.reasoning_content == ""
|
||||
assert "</think>" not in response.choices[0].message.content
|
||||
|
||||
# test logic
|
||||
@@ -716,4 +716,4 @@ def test_thinking_logic_flag(openai_client, capsys):
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
)
|
||||
assert response_case_3.choices[0].message.reasoning_content is None
|
||||
assert response_case_3.choices[0].message.reasoning_content == ""
|
||||
|
||||
@@ -312,7 +312,7 @@ def test_chat_with_thinking(openai_client, capsys):
|
||||
max_tokens=10,
|
||||
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
|
||||
)
|
||||
assert response.choices[0].message.reasoning_content is None
|
||||
assert response.choices[0].message.reasoning_content == ""
|
||||
assert "</think>" not in response.choices[0].message.content
|
||||
|
||||
# test logic
|
||||
@@ -404,4 +404,4 @@ def test_thinking_logic_flag(openai_client, capsys):
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
)
|
||||
assert response_case_3.choices[0].message.reasoning_content is None
|
||||
assert response_case_3.choices[0].message.reasoning_content == ""
|
||||
|
||||
@@ -59,6 +59,8 @@ class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
|
||||
self.multi_modal_processor._check_mm_limits = Mock()
|
||||
self.multi_modal_processor.append_completion_tokens = Mock()
|
||||
self.multi_modal_processor.pack_outputs = lambda x: x
|
||||
self.multi_modal_processor.reasoning_parser = None
|
||||
self.multi_modal_processor.model_status_dict = {}
|
||||
|
||||
self.engine_client = Mock()
|
||||
self.engine_client.connection_initialized = False
|
||||
@@ -258,7 +260,7 @@ class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
|
||||
mock_processor_instance = Mock()
|
||||
mock_processor_instance.enable_multimodal_content.return_value = True
|
||||
|
||||
async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
|
||||
async def mock_process_response_chat_async(response, stream, include_stop_str_in_output):
|
||||
yield response
|
||||
|
||||
mock_processor_instance.process_response_chat = mock_process_response_chat_async
|
||||
@@ -439,7 +441,7 @@ class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
|
||||
mock_processor_instance = Mock()
|
||||
mock_processor_instance.enable_multimodal_content.return_value = False
|
||||
|
||||
async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
|
||||
async def mock_process_response_chat_async(response, stream, include_stop_str_in_output):
|
||||
if isinstance(response, list):
|
||||
for res in response:
|
||||
yield res
|
||||
|
||||
@@ -165,7 +165,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
|
||||
|
||||
mock_processor_instance = Mock()
|
||||
|
||||
async def mock_process_response_chat_single(response, stream, enable_thinking, include_stop_str_in_output):
|
||||
async def mock_process_response_chat_single(response, stream, include_stop_str_in_output):
|
||||
yield response
|
||||
|
||||
mock_processor_instance.process_response_chat = mock_process_response_chat_single
|
||||
@@ -539,7 +539,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
|
||||
|
||||
mock_processor_instance = Mock()
|
||||
|
||||
async def mock_process_response_chat(response, stream, enable_thinking, include_stop_str_in_output):
|
||||
async def mock_process_response_chat(response, stream, include_stop_str_in_output):
|
||||
delta_msg_mock = Mock()
|
||||
delta_msg_mock.content = response["outputs"]["text"]
|
||||
if response["outputs"]["text"] == "a":
|
||||
|
||||
@@ -48,7 +48,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
results = [
|
||||
r
|
||||
async for r in processor.process_response_chat(
|
||||
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
|
||||
request_outputs, stream=False, include_stop_str_in_output=False
|
||||
)
|
||||
]
|
||||
|
||||
@@ -69,7 +69,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
results = [
|
||||
r
|
||||
async for r in processor.process_response_chat(
|
||||
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
|
||||
request_outputs, stream=True, include_stop_str_in_output=False
|
||||
)
|
||||
]
|
||||
|
||||
@@ -89,7 +89,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
results = [
|
||||
r
|
||||
async for r in self.processor_mm.process_response_chat(
|
||||
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
|
||||
request_outputs, stream=True, include_stop_str_in_output=False
|
||||
)
|
||||
]
|
||||
|
||||
@@ -116,7 +116,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
results = [
|
||||
r
|
||||
async for r in self.processor_mm.process_response_chat(
|
||||
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
|
||||
request_outputs, stream=True, include_stop_str_in_output=False
|
||||
)
|
||||
]
|
||||
|
||||
@@ -134,7 +134,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
||||
results = [
|
||||
r
|
||||
async for r in self.processor_mm.process_response_chat(
|
||||
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
|
||||
request_outputs, stream=False, include_stop_str_in_output=False
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -41,35 +41,6 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase):
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
def test_enable_thinking(self):
|
||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={})
|
||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
||||
self.assertEqual(enable_thinking, None)
|
||||
|
||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"enable_thinking": True})
|
||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
||||
self.assertEqual(enable_thinking, True)
|
||||
|
||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"enable_thinking": False})
|
||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
||||
self.assertEqual(enable_thinking, False)
|
||||
|
||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "close"}})
|
||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
||||
self.assertEqual(enable_thinking, False)
|
||||
|
||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "false"}})
|
||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
||||
self.assertEqual(enable_thinking, False)
|
||||
|
||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "open"}})
|
||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
||||
self.assertEqual(enable_thinking, True)
|
||||
|
||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "123"}})
|
||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
||||
self.assertEqual(enable_thinking, True)
|
||||
|
||||
def test_build_prompt_logprobs_basic(self):
|
||||
"""Test basic functionality of _build_prompt_logprobs"""
|
||||
# Create mock data
|
||||
|
||||
@@ -52,33 +52,12 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
self.assertTrue(result.tools_called)
|
||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_extract_tool_calls_partial_arguments(self):
|
||||
"""Test partial extraction when arguments incomplete"""
|
||||
output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北"</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_extract_tool_calls_invalid_response_before_toolcall(self):
|
||||
"""Test case where <response> before <tool_call> is invalid"""
|
||||
output = '<response>hello</response><tool_call>{"name": "get_weather", "arguments": {}}</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
self.assertIn("<response>", result.content)
|
||||
|
||||
def test_extract_tool_calls_no_toolcall(self):
|
||||
"""Test when no tool_call tags are present"""
|
||||
output = "no tool call here"
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
|
||||
def test_extract_tool_calls_invalid_json(self):
|
||||
"""Test tool_call with badly formatted JSON triggers fallback parser"""
|
||||
output = '<tool_call>"name": "get_weather", "arguments": {</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_extract_tool_calls_exception(self):
|
||||
"""Force exception to cover error branch"""
|
||||
with patch(
|
||||
|
||||
@@ -89,6 +89,7 @@ class ErnieX1ReasoningParser:
|
||||
previous_token_ids,
|
||||
all_token_ids,
|
||||
delta_token_ids,
|
||||
model_status,
|
||||
):
|
||||
"""Return a simple object with reasoning_content to cover reasoning branch."""
|
||||
|
||||
@@ -161,6 +162,7 @@ class TestErnie4_5Processor(unittest.TestCase):
|
||||
tool_cls = MockToolParser if tool else None
|
||||
proc = Ernie4_5Processor("dummy-model", reasoning_parser_obj=reasoning_cls, tool_parser_obj=tool_cls)
|
||||
proc._apply_default_parameters = lambda req: req
|
||||
proc.model_status_dict = {"req-1": "think_start"}
|
||||
return proc
|
||||
|
||||
def test_update_bad_words(self):
|
||||
|
||||
@@ -20,6 +20,11 @@ from unittest.mock import MagicMock, patch
|
||||
from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
|
||||
|
||||
|
||||
class MockReasoningParser:
|
||||
def get_model_status(self, prompt_token_ids):
|
||||
return "think_start"
|
||||
|
||||
|
||||
class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# 创建 Ernie4_5Processor 实例的模拟对象
|
||||
@@ -30,12 +35,13 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
# 设置必要的属性
|
||||
self.processor.tokenizer = MagicMock()
|
||||
self.processor.tokenizer.eos_token_id = 1
|
||||
self.processor.decode_status = {}
|
||||
self.processor.decode_status = {"test": []}
|
||||
self.processor.reasoning_end_dict = {}
|
||||
self.processor.tool_parser_dict = {}
|
||||
self.processor.generation_config = MagicMock()
|
||||
self.processor.eos_token_ids = [1]
|
||||
self.processor.reasoning_parser = MagicMock()
|
||||
self.processor.reasoning_parser = MockReasoningParser()
|
||||
self.processor.model_status_dict = {"request-id_0": "think_start", "test": "think_start"}
|
||||
|
||||
# 模拟 ids2tokens 方法
|
||||
def mock_ids2tokens(token_ids, task_id):
|
||||
@@ -72,7 +78,7 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
def test_process_response_dict_streaming_normal_case(self):
|
||||
"""测试正常情况下的流式响应处理"""
|
||||
# 准备输入
|
||||
response_dict = {"finished": False, "request_id": "req1", "outputs": {"token_ids": [4, 5]}}
|
||||
response_dict = {"finished": False, "request_id": "test", "outputs": {"token_ids": [4, 5]}}
|
||||
kwargs = {"enable_thinking": True}
|
||||
|
||||
# 调用方法
|
||||
@@ -83,6 +89,7 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
|
||||
def test_process_request_dict(self):
|
||||
request_dict = {
|
||||
"request_id": "123",
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"chat_template_kwargs": {"chat_template": "Hello!"},
|
||||
"eos_token_ids": [1],
|
||||
@@ -118,6 +125,31 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
self.assertEqual(result["outputs"]["text"], "Mock final text")
|
||||
self.assertIn("completion_tokens", result["outputs"])
|
||||
|
||||
def test_think_status(self):
|
||||
"""测试 思考机制"""
|
||||
request = {
|
||||
"prompt": "hello",
|
||||
"request_id": "test_1",
|
||||
"prompt_token_ids": [1, 2, 3],
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
self.processor.reasoning_parser = MagicMock()
|
||||
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
|
||||
self.processor.model_status_dict = {}
|
||||
self.processor.process_request_dict(request, max_model_len=512)
|
||||
self.assertEqual(request["enable_thinking"], True)
|
||||
|
||||
request = {
|
||||
"prompt": "hello",
|
||||
"request_id": "test",
|
||||
"prompt_token_ids": [1, 2, 3],
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
self.processor.process_request_dict(request, max_model_len=512)
|
||||
self.assertEqual(request["enable_thinking"], True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -29,7 +29,12 @@ from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
|
||||
from fastdeploy.input.utils import IDS_TYPE_FLAG
|
||||
|
||||
|
||||
class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
class MockReasoningParser:
|
||||
def get_model_status(self, prompt_token_ids):
|
||||
return "think_start"
|
||||
|
||||
|
||||
class TestErnie4_5VLProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Create mock object for Ernie4_5Processor instance
|
||||
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None) as mock_init:
|
||||
@@ -39,38 +44,41 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
# Set necessary attributes
|
||||
self.processor.tokenizer = MagicMock()
|
||||
self.processor.tokenizer.eos_token_id = 1
|
||||
self.processor.decode_status = {}
|
||||
self.processor.decode_status = {"test": []}
|
||||
self.processor.reasoning_end_dict = {}
|
||||
self.processor.tool_parser_dict = {}
|
||||
self.processor.generation_config = MagicMock()
|
||||
self.processor.eos_token_ids = [1]
|
||||
self.processor.reasoning_parser = MagicMock()
|
||||
self.processor._check_mm_limits = MagicMock()
|
||||
self.processor.reasoning_parser = MockReasoningParser()
|
||||
self.processor.model_status_dict = {"test": "think_start"}
|
||||
self.processor.ernie4_5_processor = MagicMock()
|
||||
self.processor.pack_outputs = MagicMock()
|
||||
|
||||
# Mock ids2tokens method
|
||||
def mock_ids2tokens(token_ids, task_id):
|
||||
self.processor.decode_status[task_id] = "mock_decode_status"
|
||||
return "delta_text", [2, 3], "previous_texts"
|
||||
|
||||
self.processor.ids2tokens = mock_ids2tokens
|
||||
|
||||
def mock_messages2ids(request, **kwargs):
|
||||
if "chat_template" in kwargs:
|
||||
return [1]
|
||||
else:
|
||||
return [0]
|
||||
def mock_request2ids(request, **kwargs):
|
||||
return {"input_ids": np.array([1, 2, 3]), "prompt_token_ids": [0]}
|
||||
|
||||
def mock_check_mm_limits(item):
|
||||
pass
|
||||
|
||||
def mock_apply_default_parameters(request):
|
||||
return request
|
||||
|
||||
def mock_pack_outputs(outputs):
|
||||
return outputs
|
||||
|
||||
self.processor._apply_default_parameters = mock_apply_default_parameters
|
||||
self.processor._check_mm_limits = mock_check_mm_limits
|
||||
self.processor.ernie4_5_processor.request2ids = mock_request2ids
|
||||
self.processor.pack_outputs = mock_pack_outputs
|
||||
|
||||
# Mock reasoning parser
|
||||
self.mock_reasoning_parser = MagicMock()
|
||||
self.mock_reasoning_parser.__class__.__name__ = "ErnieX1ReasoningParser"
|
||||
# self.mock_reasoning_parser.extract_reasoning_content_streaming.return_value = ("reasoning", "text")
|
||||
self.mock_reasoning_parser.extract_reasoning_content_streaming.return_value = None
|
||||
self.processor.reasoning_parser = self.mock_reasoning_parser
|
||||
|
||||
# Mock tool parser
|
||||
@@ -80,82 +88,26 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
self.mock_tool_parser_obj.return_value = self.mock_tool_parser
|
||||
self.processor.tool_parser_obj = self.mock_tool_parser_obj
|
||||
|
||||
def test_process_request_dict_with_options(self):
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"prompt_token_ids": [1, 1, 1],
|
||||
def test_think_status(self):
|
||||
"""测试 思考机制"""
|
||||
request = {
|
||||
"prompt": "hello",
|
||||
"request_id": "test_1",
|
||||
"prompt_token_ids": [1, 2, 3],
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], True)
|
||||
self.processor.reasoning_parser = MagicMock()
|
||||
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
|
||||
self.processor.model_status_dict = {}
|
||||
self.processor.process_request_dict(request, max_model_len=512)
|
||||
self.assertEqual(request["enable_thinking"], True)
|
||||
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
"prompt_token_ids": [1, 1, 1],
|
||||
request = {
|
||||
"prompt": "hello",
|
||||
"request_id": "test",
|
||||
"prompt_token_ids": [1, 2, 3],
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], True)
|
||||
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
"prompt_token_ids": [1, 1, 1],
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], True)
|
||||
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"chat_template_kwargs": {"options": {"thinking_mode": "open"}},
|
||||
"prompt_token_ids": [1, 1, 1],
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], True)
|
||||
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"chat_template_kwargs": {"options": {"thinking_mode": "close"}},
|
||||
"prompt_token_ids": [1, 1, 1],
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], True)
|
||||
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"chat_template_kwargs": {"options": {"thinking_mode": "false"}},
|
||||
"prompt_token_ids": [1, 1, 1],
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], True)
|
||||
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"chat_template_kwargs": {"options": {"thinking_mode": "123"}},
|
||||
"prompt_token_ids": [1, 1, 1],
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], True)
|
||||
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"chat_template_kwargs": {"options": {"thinking_mode": "close"}},
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], False)
|
||||
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"chat_template_kwargs": {"options": {"thinking_mode": "false"}},
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], False)
|
||||
|
||||
request_dict = {
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
}
|
||||
self.processor.process_request_dict(request_dict, 100)
|
||||
self.assertEqual(request_dict["enable_thinking"], False)
|
||||
self.processor.process_request_dict(request, max_model_len=512)
|
||||
self.assertEqual(request["enable_thinking"], True)
|
||||
|
||||
|
||||
class TestDataProcessorTargetMethods(unittest.TestCase):
|
||||
|
||||
@@ -793,6 +793,8 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
|
||||
self.processor.processor = MagicMock()
|
||||
self.processor.limit_mm_per_prompt = {"image": 1, "video": 1, "audio": 1}
|
||||
self.processor.eos_token_ids = [1]
|
||||
self.processor.reasoning_parser = None
|
||||
self.processor.model_status_dict = {}
|
||||
|
||||
# 模拟 _apply_default_parameters
|
||||
def mock_apply_default_parameters(request_or_dict):
|
||||
@@ -971,6 +973,7 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
|
||||
"prompt": "test prompt",
|
||||
"multimodal_data": {"image": ["image1"]},
|
||||
"metadata": {"generated_token_ids": []},
|
||||
"request_id": "test-request",
|
||||
}
|
||||
request_obj.to_dict.return_value = request_dict
|
||||
|
||||
@@ -1157,6 +1160,27 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
|
||||
self.assertTrue(np.array_equal(result["image_type_ids"], np.array([0])))
|
||||
self.assertTrue(np.array_equal(result["position_ids"], np.array([[0], [1], [2]], dtype=np.int64)))
|
||||
|
||||
def test_think_status(self):
|
||||
"""测试 思考机制"""
|
||||
request = {
|
||||
"prompt": "hello",
|
||||
"request_id": "test_1",
|
||||
"prompt_token_ids": [1, 2, 3],
|
||||
}
|
||||
self.processor.reasoning_parser = MagicMock()
|
||||
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
|
||||
self.processor.model_status_dict = {}
|
||||
self.processor.process_request_dict(request, max_model_len=512)
|
||||
self.assertEqual(request["enable_thinking"], True)
|
||||
|
||||
request = {
|
||||
"prompt": "hello",
|
||||
"request_id": "test",
|
||||
"prompt_token_ids": [1, 2, 3],
|
||||
}
|
||||
self.processor.process_request_dict(request, max_model_len=512)
|
||||
self.assertEqual(request["enable_thinking"], True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -365,6 +365,31 @@ class TestQwenVLProcessor(unittest.TestCase):
|
||||
# Verify both methods produce identical prompt strings
|
||||
self.assertEqual(prompt, prompt2)
|
||||
|
||||
def test_think_status(self):
|
||||
"""测试 思考机制"""
|
||||
request = {
|
||||
"prompt": "hello",
|
||||
"request_id": "test_1",
|
||||
"prompt_token_ids": [1, 2, 3],
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
self.processor.reasoning_parser = MagicMock()
|
||||
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
|
||||
self.processor.model_status_dict = {}
|
||||
self.processor.process_request_dict(request, max_model_len=512)
|
||||
self.assertEqual(request["enable_thinking"], True)
|
||||
|
||||
request = {
|
||||
"prompt": "hello",
|
||||
"request_id": "test",
|
||||
"prompt_token_ids": [1, 2, 3],
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
self.processor.process_request_dict(request, max_model_len=512)
|
||||
self.assertEqual(request["enable_thinking"], True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -266,7 +266,7 @@ class DataProcessorTestCase(unittest.TestCase):
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def extract_reasoning_content(self, full_text, response_dict):
|
||||
def extract_reasoning_content(self, full_text, response_dict, model_status):
|
||||
return reasoning_content, f"{full_text}!"
|
||||
|
||||
return DummyReasoning(tokenizer)
|
||||
@@ -409,6 +409,7 @@ class DataProcessorTestCase(unittest.TestCase):
|
||||
|
||||
def test_process_response_with_reasoning_and_tools(self):
|
||||
processor = self.processor
|
||||
processor.model_status_dict = {"resp": "normal"}
|
||||
|
||||
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer)
|
||||
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-only")
|
||||
@@ -435,7 +436,7 @@ class DataProcessorTestCase(unittest.TestCase):
|
||||
|
||||
def test_process_response_dict_normal_with_reasoning(self):
|
||||
processor = self.processor
|
||||
|
||||
processor.model_status_dict = {"normal": "normal"}
|
||||
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer, reasoning_content="because")
|
||||
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-text")
|
||||
|
||||
@@ -471,10 +472,10 @@ class DataProcessorTestCase(unittest.TestCase):
|
||||
self.addCleanup(lambda: setattr(processor, "process_response_dict_normal", original_normal))
|
||||
|
||||
response = {"outputs": {}, "finished": False, "request_id": "req"}
|
||||
self.assertEqual(processor.process_response_dict(response), "stream")
|
||||
self.assertEqual(processor.process_response_dict(response, stream=True, enable_thinking=True), "stream")
|
||||
self.assertTrue(calls["stream"]["enable_thinking"])
|
||||
self.assertEqual(
|
||||
processor.process_response_dict(response, stream=False, enable_thinking=None),
|
||||
processor.process_response_dict(response, stream=False, enable_thinking=True),
|
||||
"normal",
|
||||
)
|
||||
self.assertTrue(calls["normal"]["enable_thinking"])
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
import unittest
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from fastdeploy.reasoning.qwen3_reasoning_parsers import Qwen3ReasoningParser
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
"""Minimal tokenizer with vocab for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.vocab = {
|
||||
"<think>": 100,
|
||||
"</think>": 101,
|
||||
}
|
||||
|
||||
def get_vocab(self):
|
||||
"""Return vocab dict for testing."""
|
||||
return self.vocab
|
||||
|
||||
|
||||
class MissingTokenTokenizer:
|
||||
def __init__(self):
|
||||
self.vocab = {
|
||||
"</think>": 100,
|
||||
}
|
||||
|
||||
def get_vocab(self):
|
||||
"""Return vocab dict for testing."""
|
||||
return self.vocab
|
||||
|
||||
|
||||
class TestQwen3ReasoningParser(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.parser = Qwen3ReasoningParser(MockTokenizer())
|
||||
self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
|
||||
self.tokenizer = MockTokenizer()
|
||||
|
||||
def test_missing_token(self):
|
||||
with self.assertRaises(RuntimeError) as context:
|
||||
Qwen3ReasoningParser(MissingTokenTokenizer())
|
||||
exception_message = str(context.exception)
|
||||
expected_message_part = "Qwen3 reasoning parser could not find the following token ids"
|
||||
self.assertIn(expected_message_part, exception_message)
|
||||
|
||||
def test_get_model_status(self):
|
||||
status = self.parser.get_model_status([1, 2, 100])
|
||||
self.assertEqual(status, "think_start")
|
||||
status = self.parser.get_model_status([1, 2, 101])
|
||||
self.assertEqual(status, "think_end")
|
||||
status = self.parser.get_model_status([1])
|
||||
self.assertEqual(status, "think_start")
|
||||
|
||||
def test_streaming_thinking_content(self):
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="a",
|
||||
delta_text="a",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[200],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.reasoning_content, "a")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="a</think>b",
|
||||
delta_text="a</think>b",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[99, 101, 102],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.reasoning_content, "a")
|
||||
self.assertEqual(msg.content, "b")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="a</think>",
|
||||
current_text="a</think>b",
|
||||
delta_text="b",
|
||||
previous_token_ids=[1, 101],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[102],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "b")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="a",
|
||||
delta_text="a",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.reasoning_content, "a")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="a",
|
||||
delta_text="a",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[200],
|
||||
model_status="think_end",
|
||||
)
|
||||
self.assertEqual(msg.content, "a")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello",
|
||||
current_text="hello</think>hi",
|
||||
delta_text="</think>hi",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[101, 200],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "hi")
|
||||
self.assertEqual(msg.reasoning_content, "")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello",
|
||||
current_text="hello</think>hi",
|
||||
delta_text="hi",
|
||||
previous_token_ids=[100],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, None)
|
||||
self.assertEqual(msg.reasoning_content, "hi")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello",
|
||||
current_text="hello</think>hi",
|
||||
delta_text="<think>hi",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[100, 200],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "")
|
||||
self.assertEqual(msg.reasoning_content, "hi")
|
||||
|
||||
def test_none_streaming_thinking_content(self):
|
||||
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||
model_output="a",
|
||||
request={},
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning_content, None)
|
||||
self.assertEqual(content, "a")
|
||||
|
||||
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||
model_output="a</think>b",
|
||||
request={},
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning_content, "a")
|
||||
self.assertEqual(content, "b")
|
||||
|
||||
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||
model_output="a",
|
||||
request={},
|
||||
model_status="think_end",
|
||||
)
|
||||
self.assertEqual(reasoning_content, None)
|
||||
self.assertEqual(content, "a")
|
||||
|
||||
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||
model_output="<think>a",
|
||||
request={},
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning_content, None)
|
||||
self.assertEqual(content, "<think>a")
|
||||
|
||||
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||
model_output="<think>a</think>b",
|
||||
request={},
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning_content, "a")
|
||||
self.assertEqual(content, "b")
|
||||
|
||||
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||
model_output="</think>b",
|
||||
request={},
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning_content, "")
|
||||
self.assertEqual(content, "b")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -31,10 +31,25 @@ class DummyTokenizer:
|
||||
def __init__(self):
|
||||
self.vocab = {
|
||||
"</think>": 100,
|
||||
"<tool_call>": 101,
|
||||
"</tool_call>": 102,
|
||||
"<response>": 103,
|
||||
"</response>": 104,
|
||||
"<think>": 101,
|
||||
"<tool_call>": 102,
|
||||
"</tool_call>": 103,
|
||||
"<response>": 104,
|
||||
"</response>": 105,
|
||||
}
|
||||
|
||||
def get_vocab(self):
|
||||
"""Return vocab dict for testing."""
|
||||
return self.vocab
|
||||
|
||||
|
||||
class MissingTokenTokenizer:
|
||||
def __init__(self):
|
||||
self.vocab = {
|
||||
"</think>": 100,
|
||||
"<think>": 101,
|
||||
"<tool_call>": 102,
|
||||
"</tool_call>": 103,
|
||||
}
|
||||
|
||||
def get_vocab(self):
|
||||
@@ -132,6 +147,17 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
|
||||
self.tokenizer = DummyTokenizer()
|
||||
|
||||
def test_missing_token(self):
|
||||
with self.assertRaises(RuntimeError) as context:
|
||||
ErnieX1ReasoningParser(MissingTokenTokenizer())
|
||||
exception_message = str(context.exception)
|
||||
expected_message_part = "ernie x1 reasoning parser could not find the following token ids"
|
||||
self.assertIn(expected_message_part, exception_message)
|
||||
|
||||
def test_get_model_status(self):
|
||||
model_status = self.parser.get_model_status([88, 99, 104])
|
||||
self.assertEqual(model_status, "response_start")
|
||||
|
||||
# ---- Streaming parsing ----
|
||||
def test_streaming_thinking_content(self):
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
@@ -141,6 +167,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[200],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.reasoning_content, "a")
|
||||
|
||||
@@ -152,6 +179,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[201],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.reasoning_content, "\n")
|
||||
|
||||
@@ -163,6 +191,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[self.parser.think_end_token_id],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsNone(msg)
|
||||
|
||||
@@ -174,6 +203,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[202],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "h")
|
||||
|
||||
@@ -185,6 +215,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[203],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "\n")
|
||||
|
||||
@@ -197,6 +228,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[self.parser.vocab["<response>"]],
|
||||
model_status="think_start",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -207,6 +239,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[204],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(msg, DeltaMessage)
|
||||
self.assertEqual(msg.content, "\n")
|
||||
@@ -219,9 +252,82 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[self.parser.vocab["</response>"]],
|
||||
model_status="think_start",
|
||||
)
|
||||
)
|
||||
|
||||
def test_extract_reasoning_content_streaming(self):
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello</think>",
|
||||
current_text="hello</think><response>",
|
||||
delta_text="</think><response>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[100, 200],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "")
|
||||
self.assertEqual(msg.reasoning_content, "")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello</think>",
|
||||
current_text="hello</think><response>hi",
|
||||
delta_text="</think><response>hi",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[100, 200],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "hi")
|
||||
self.assertEqual(msg.reasoning_content, "")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="hello</think><response>hi",
|
||||
delta_text="hello</think><response>hi",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[100, 200],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "hi")
|
||||
self.assertEqual(msg.reasoning_content, "hello")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello</think><response>",
|
||||
current_text="hello</think><response>hi",
|
||||
delta_text="hi",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[100, 200],
|
||||
model_status="think_end",
|
||||
)
|
||||
self.assertEqual(msg.content, "hi")
|
||||
self.assertEqual(msg.reasoning_content, None)
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello</think><response>",
|
||||
current_text="hello</think><response>hi",
|
||||
delta_text="hi",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[100, 200],
|
||||
model_status="response_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "hi")
|
||||
self.assertEqual(msg.reasoning_content, None)
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello</think><response>hi</response>",
|
||||
current_text="hello</think><response>hi</response>end",
|
||||
delta_text="end",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[100, 200],
|
||||
model_status="response_start",
|
||||
)
|
||||
self.assertEqual(msg, None)
|
||||
|
||||
def test_streaming_tool_call(self):
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="</think>",
|
||||
@@ -230,40 +336,48 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[self.parser.vocab["<tool_call>"]],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsNone(msg)
|
||||
|
||||
# ---- Batch parsing ----
|
||||
def test_batch_reasoning_and_response(self):
|
||||
text = "abc\n</think>\n<response>hello\nworld</response>"
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||
self.assertEqual(reasoning, "abc\n")
|
||||
self.assertEqual(response, "hello\nworld")
|
||||
|
||||
def test_batch_reasoning_and_tool_call(self):
|
||||
text = "abc</think><tool_call>call_here"
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||
self.assertEqual(reasoning, "abc")
|
||||
self.assertEqual(response, "")
|
||||
|
||||
def test_batch_no_thinking_tag(self):
|
||||
text = "no_thinking_here"
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||
self.assertEqual(reasoning, "no_thinking_here")
|
||||
self.assertEqual(response, "")
|
||||
|
||||
def test_batch_response_without_end_tag(self):
|
||||
text = "abc</think><response>partial response"
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||
self.assertEqual(reasoning, "abc")
|
||||
self.assertEqual(response, "partial response")
|
||||
|
||||
def test_batch_preserve_all_newlines(self):
|
||||
text = "abc\n</think>\n<response>line1\nline2\n</response>"
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||
self.assertEqual(reasoning, "abc\n")
|
||||
self.assertEqual(response, "line1\nline2\n")
|
||||
|
||||
def test_extract_reasoning_content(self):
|
||||
reasoning_content, response_content = self.parser.extract_reasoning_content(
|
||||
model_output="hello", request=self.request, model_status="response_start"
|
||||
)
|
||||
self.assertEqual(reasoning_content, "")
|
||||
self.assertEqual(response_content, "hello")
|
||||
|
||||
|
||||
class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -272,6 +386,9 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
self.test_request = ChatCompletionRequest(
|
||||
model="ernie-test", messages=[{"role": "user", "content": "test prompt"}]
|
||||
)
|
||||
self.parser.token_status_mapping = {
|
||||
100: "think_start",
|
||||
}
|
||||
|
||||
def test_streaming_non_reasoning(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
@@ -281,6 +398,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[200],
|
||||
delta_token_ids=[200],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.reasoning_content, "a")
|
||||
@@ -294,6 +412,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[200, 201],
|
||||
current_token_ids=[200, 201, 100],
|
||||
delta_token_ids=[100],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
|
||||
@@ -305,6 +424,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[200, 201],
|
||||
current_token_ids=[200, 201, 100, 300, 400],
|
||||
delta_token_ids=[100, 300, 400],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertIsNone(result.reasoning_content)
|
||||
@@ -318,6 +438,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 100],
|
||||
delta_token_ids=[100],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
|
||||
@@ -329,9 +450,10 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 100, 200, 101],
|
||||
delta_token_ids=[100, 200, 101],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.reasoning_content, "")
|
||||
self.assertEqual(result.reasoning_content, None)
|
||||
|
||||
def test_streaming_with_reasoning_and_illegal_tool(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
@@ -341,6 +463,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 100, 200, 101],
|
||||
delta_token_ids=[109, 200, 101],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.content, "\n\nhello<tool_call>")
|
||||
@@ -353,6 +476,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 100, 200, 110],
|
||||
delta_token_ids=[100, 200, 110],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.reasoning_content, "hello")
|
||||
@@ -366,6 +490,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[100],
|
||||
current_token_ids=[100, 110, 111],
|
||||
delta_token_ids=[110, 111],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertIsNone(result.reasoning_content)
|
||||
@@ -379,52 +504,140 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[101],
|
||||
current_token_ids=[101, 110],
|
||||
delta_token_ids=[110],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.reasoning_content, "hello")
|
||||
|
||||
def test_think_end_status_streaming(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="<tool_call>",
|
||||
current_text="<tool_call>hello",
|
||||
delta_text="hello",
|
||||
previous_token_ids=[101],
|
||||
current_token_ids=[101, 110],
|
||||
delta_token_ids=[110],
|
||||
model_status="think_end",
|
||||
)
|
||||
self.assertIs(result, None)
|
||||
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello, ",
|
||||
current_text="hello, hi",
|
||||
delta_text="hi",
|
||||
previous_token_ids=[101],
|
||||
current_token_ids=[101, 110],
|
||||
delta_token_ids=[110],
|
||||
model_status="think_end",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.content, "hi")
|
||||
|
||||
def test_other_status_streaming(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="hello, ",
|
||||
current_text="hello, hi",
|
||||
delta_text="hi",
|
||||
previous_token_ids=[101],
|
||||
current_token_ids=[101, 110],
|
||||
delta_token_ids=[110],
|
||||
model_status="tool_call_start",
|
||||
)
|
||||
self.assertIs(result, None)
|
||||
|
||||
def test_batch_no_think_end(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="direct response", request=self.test_request
|
||||
model_output="direct response", request=self.test_request, model_status="think_start"
|
||||
)
|
||||
self.assertEqual(reasoning, "direct response")
|
||||
self.assertEqual(content, "")
|
||||
|
||||
def test_batch_no_think_end_with_tool(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="direct response<tool_call>abc", request=self.test_request
|
||||
model_output="direct response<tool_call>abc", request=self.test_request, model_status="think_start"
|
||||
)
|
||||
self.assertEqual(reasoning, "direct response<tool_call>abc")
|
||||
self.assertEqual(content, "")
|
||||
|
||||
def test_batch_think_end_normal_content(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="reasoning</think>\nresponse", request=self.test_request
|
||||
model_output="reasoning</think>\nresponse", request=self.test_request, model_status="think_start"
|
||||
)
|
||||
self.assertEqual(reasoning, "reasoning")
|
||||
self.assertEqual(content, "\nresponse")
|
||||
|
||||
def test_batch_think_end_with_tool(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="reasoning</think>\n<tool_call>tool params</tool_call>", request=self.test_request
|
||||
model_output="reasoning</think>\n<tool_call>tool params</tool_call>",
|
||||
request=self.test_request,
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning, "reasoning")
|
||||
self.assertEqual(content, "")
|
||||
|
||||
def test_batch_think_end_with_illegal_tool(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="reasoning</think>\nABC\n<tool_call>tool params</tool_call>", request=self.test_request
|
||||
model_output="reasoning</think>\nABC\n<tool_call>tool params</tool_call>",
|
||||
request=self.test_request,
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning, "reasoning")
|
||||
self.assertEqual(content, "\nABC\n<tool_call>tool params</tool_call>")
|
||||
|
||||
def test_batch_think_end_content_with_newline(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="reasoning</think>\n\n actual response", request=self.test_request
|
||||
model_output="reasoning</think>\n\n actual response",
|
||||
request=self.test_request,
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning, "reasoning")
|
||||
self.assertEqual(content, "\n\n actual response")
|
||||
|
||||
def test_think_end_status_non_streaming(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="response", request=self.test_request, model_status="think_end"
|
||||
)
|
||||
self.assertEqual(reasoning, "")
|
||||
self.assertEqual(content, "response")
|
||||
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="<tool_call>response", request=self.test_request, model_status="think_end"
|
||||
)
|
||||
self.assertEqual(reasoning, "")
|
||||
self.assertEqual(content, "")
|
||||
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="\n 1<tool_call>response", request=self.test_request, model_status="think_end"
|
||||
)
|
||||
self.assertEqual(reasoning, "")
|
||||
self.assertEqual(content, "\n 1<tool_call>response")
|
||||
|
||||
def test_other_status_non_streaming(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="response", request=self.test_request, model_status="tool_call_start"
|
||||
)
|
||||
self.assertEqual(reasoning, "")
|
||||
self.assertEqual(content, "")
|
||||
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="response", request=self.test_request, model_status="tool_call_end"
|
||||
)
|
||||
self.assertEqual(reasoning, "")
|
||||
self.assertEqual(content, "")
|
||||
|
||||
def test_find_last_special_token(self):
|
||||
result = self.parser.find_last_special_token([100, 110, 120, 130])
|
||||
self.assertEqual(result, 100)
|
||||
result = self.parser.find_last_special_token([0])
|
||||
self.assertEqual(result, -1)
|
||||
|
||||
def test_get_model_status(self):
|
||||
result = self.parser.get_model_status([100, 110, 120, 130])
|
||||
self.assertEqual(result, "think_start")
|
||||
|
||||
result = self.parser.get_model_status([0])
|
||||
self.assertEqual(result, "think_start")
|
||||
|
||||
|
||||
class TestErnieVLReasoningParser(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -442,6 +655,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 100, 110, 120, 130],
|
||||
delta_token_ids=[100, 110, 120, 130],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.reasoning_content, "")
|
||||
@@ -455,6 +669,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[200, 201, 202, 100],
|
||||
current_token_ids=[200, 201, 202, 100, 110, 120, 130],
|
||||
delta_token_ids=[110, 120, 130],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertIsNone(result.reasoning_content)
|
||||
@@ -468,6 +683,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 110, 120, 130],
|
||||
delta_token_ids=[110, 120, 130],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertIsNone(result.content)
|
||||
@@ -475,7 +691,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
|
||||
|
||||
def test_extract_reasoning_content(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="reasoning</think>\nactual response", request=self.test_request
|
||||
model_output="reasoning</think>\nactual response", request=self.test_request, model_status="think_start"
|
||||
)
|
||||
self.assertEqual(reasoning, "reasoning")
|
||||
self.assertEqual(content, "\nactual response")
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from fastdeploy.reasoning.ernie_vl_reasoning_parsers import ErnieVLReasoningParser
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
"""Minimal tokenizer with vocab for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.vocab = {
|
||||
"<think>": 100,
|
||||
"</think>": 101,
|
||||
}
|
||||
|
||||
def get_vocab(self):
|
||||
"""Return vocab dict for testing."""
|
||||
return self.vocab
|
||||
|
||||
|
||||
class TestErnieVLReasoningParser(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.parser = ErnieVLReasoningParser(MockTokenizer())
|
||||
self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
|
||||
self.tokenizer = MockTokenizer()
|
||||
|
||||
def test_get_model_status(self):
|
||||
status = self.parser.get_model_status([1, 2, 100])
|
||||
self.assertEqual(status, "think_start")
|
||||
status = self.parser.get_model_status([1, 2, 101])
|
||||
self.assertEqual(status, "think_end")
|
||||
status = self.parser.get_model_status([1])
|
||||
self.assertEqual(status, "think_start")
|
||||
|
||||
def test_streaming_thinking_content(self):
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="a",
|
||||
delta_text="a",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[200],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.reasoning_content, "a")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="a</think>b",
|
||||
delta_text="a</think>b",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[100, 101, 102],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.reasoning_content, "a")
|
||||
self.assertEqual(msg.content, "b")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="a</think>",
|
||||
current_text="a</think>b",
|
||||
delta_text="b",
|
||||
previous_token_ids=[1, 101],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[102],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.content, "b")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="a",
|
||||
delta_text="a",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(msg.reasoning_content, "a")
|
||||
|
||||
msg = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="a",
|
||||
delta_text="a",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[200],
|
||||
model_status="think_end",
|
||||
)
|
||||
self.assertEqual(msg.content, "a")
|
||||
|
||||
def test_none_streaming_thinking_content(self):
|
||||
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||
model_output="a",
|
||||
request={},
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning_content, "")
|
||||
self.assertEqual(content, "a")
|
||||
|
||||
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||
model_output="a</think>b",
|
||||
request={},
|
||||
model_status="think_start",
|
||||
)
|
||||
self.assertEqual(reasoning_content, "a")
|
||||
self.assertEqual(content, "b")
|
||||
|
||||
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||
model_output="a",
|
||||
request={},
|
||||
model_status="think_end",
|
||||
)
|
||||
self.assertEqual(reasoning_content, "")
|
||||
self.assertEqual(content, "a")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user