mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature]Optimization of Thinking Pattern Framework (#4302)
* add model status in vl
* add x1 parser
* add model_status
* fix parser
* fix parser
* fix parser
* fix parser
* Revert "fix parser"
This reverts commit 300f446d8a.
* fix parser
* fix
* fix
* fix
* fix
* fix parser
* fix unit test
* fix unit test
* add unit test
* fix
* fix
* add unit test
* fix unit test
* add unit test
* add unit test
* fix unit test
* fix unit test
* fix bug
* fix unit test
* x1 tool parser
* fix unit test
* fix unit test
* fix unit test
* fix n
* fix unit test
* add unit test
* add unit test
* remove print
This commit is contained in:
@@ -72,13 +72,12 @@ class ChatResponseProcessor:
|
|||||||
else:
|
else:
|
||||||
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
|
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
|
||||||
|
|
||||||
async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output):
|
async def process_response_chat(self, request_outputs, stream, include_stop_str_in_output):
|
||||||
"""
|
"""
|
||||||
Process a list of responses into a generator that yields each processed response as it's generated.
|
Process a list of responses into a generator that yields each processed response as it's generated.
|
||||||
Args:
|
Args:
|
||||||
request_outputs: The list of outputs to be processed.
|
request_outputs: The list of outputs to be processed.
|
||||||
stream: Whether or not to stream the output.
|
stream: Whether or not to stream the output.
|
||||||
enable_thinking: Whether or not to show thinking messages.
|
|
||||||
include_stop_str_in_output: Whether or not to include stop strings in the output.
|
include_stop_str_in_output: Whether or not to include stop strings in the output.
|
||||||
"""
|
"""
|
||||||
for request_output in request_outputs:
|
for request_output in request_outputs:
|
||||||
@@ -99,7 +98,6 @@ class ChatResponseProcessor:
|
|||||||
response = await self.data_processor.process_response_dict(
|
response = await self.data_processor.process_response_dict(
|
||||||
response_dict=request_output,
|
response_dict=request_output,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
audio_tokens=all_audio_tokens,
|
audio_tokens=all_audio_tokens,
|
||||||
tts=tts,
|
tts=tts,
|
||||||
@@ -108,7 +106,6 @@ class ChatResponseProcessor:
|
|||||||
response = self.data_processor.process_response_dict(
|
response = self.data_processor.process_response_dict(
|
||||||
response_dict=request_output,
|
response_dict=request_output,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
audio_tokens=all_audio_tokens,
|
audio_tokens=all_audio_tokens,
|
||||||
tts=tts,
|
tts=tts,
|
||||||
@@ -127,7 +124,6 @@ class ChatResponseProcessor:
|
|||||||
yield self.data_processor.process_response_dict(
|
yield self.data_processor.process_response_dict(
|
||||||
response_dict=request_output,
|
response_dict=request_output,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
)
|
)
|
||||||
elif stream:
|
elif stream:
|
||||||
@@ -156,14 +152,12 @@ class ChatResponseProcessor:
|
|||||||
await self.data_processor.process_response_dict(
|
await self.data_processor.process_response_dict(
|
||||||
response_dict=request_output,
|
response_dict=request_output,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.data_processor.process_response_dict(
|
self.data_processor.process_response_dict(
|
||||||
response_dict=request_output,
|
response_dict=request_output,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
)
|
)
|
||||||
text = {"type": "text", "text": request_output["outputs"]["text"]}
|
text = {"type": "text", "text": request_output["outputs"]["text"]}
|
||||||
@@ -185,14 +179,12 @@ class ChatResponseProcessor:
|
|||||||
await self.data_processor.process_response_dict(
|
await self.data_processor.process_response_dict(
|
||||||
response_dict=part["request_output"],
|
response_dict=part["request_output"],
|
||||||
stream=False,
|
stream=False,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.data_processor.process_response_dict(
|
self.data_processor.process_response_dict(
|
||||||
response_dict=request_output,
|
response_dict=request_output,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
)
|
)
|
||||||
text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
|
text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
|
||||||
|
|||||||
@@ -220,8 +220,6 @@ class OpenAIServingChat:
|
|||||||
|
|
||||||
max_streaming_response_tokens = max(1, max_streaming_response_tokens)
|
max_streaming_response_tokens = max(1, max_streaming_response_tokens)
|
||||||
|
|
||||||
enable_thinking = self._get_thinking_status(request)
|
|
||||||
|
|
||||||
include_stop_str_in_output = request.include_stop_str_in_output
|
include_stop_str_in_output = request.include_stop_str_in_output
|
||||||
|
|
||||||
stream_options = request.stream_options
|
stream_options = request.stream_options
|
||||||
@@ -277,7 +275,6 @@ class OpenAIServingChat:
|
|||||||
generator = response_processor.process_response_chat(
|
generator = response_processor.process_response_chat(
|
||||||
response,
|
response,
|
||||||
stream=True,
|
stream=True,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -507,7 +504,6 @@ class OpenAIServingChat:
|
|||||||
"""
|
"""
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
num_choices = 1 if request.n is None else request.n
|
num_choices = 1 if request.n is None else request.n
|
||||||
enable_thinking = self._get_thinking_status(request)
|
|
||||||
|
|
||||||
include_stop_str_in_output = request.include_stop_str_in_output
|
include_stop_str_in_output = request.include_stop_str_in_output
|
||||||
try:
|
try:
|
||||||
@@ -562,7 +558,6 @@ class OpenAIServingChat:
|
|||||||
generator = response_processor.process_response_chat(
|
generator = response_processor.process_response_chat(
|
||||||
response,
|
response,
|
||||||
stream=False,
|
stream=False,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
)
|
)
|
||||||
async for data in generator:
|
async for data in generator:
|
||||||
@@ -824,23 +819,6 @@ class OpenAIServingChat:
|
|||||||
api_server_logger.error(error_msg)
|
api_server_logger.error(error_msg)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_thinking_status(self, request: ChatCompletionRequest) -> bool:
|
|
||||||
"""
|
|
||||||
Get the thinking status from the request.
|
|
||||||
"""
|
|
||||||
enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
|
|
||||||
if enable_thinking is None:
|
|
||||||
enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
|
|
||||||
options = request.chat_template_kwargs.get("options") if request.chat_template_kwargs else None
|
|
||||||
if options:
|
|
||||||
thinking_mode = options.get("thinking_mode")
|
|
||||||
if thinking_mode:
|
|
||||||
if thinking_mode == "close" or thinking_mode == "false":
|
|
||||||
enable_thinking = False
|
|
||||||
else:
|
|
||||||
enable_thinking = True
|
|
||||||
return enable_thinking
|
|
||||||
|
|
||||||
def _build_prompt_logprobs(
|
def _build_prompt_logprobs(
|
||||||
self,
|
self,
|
||||||
prompt_logprobs_tensors: LogprobsTensors,
|
prompt_logprobs_tensors: LogprobsTensors,
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
self.decode_status = dict()
|
self.decode_status = dict()
|
||||||
self.tool_parser_dict = dict()
|
self.tool_parser_dict = dict()
|
||||||
self.thinking_parser_dict = dict()
|
self.thinking_parser_dict = dict()
|
||||||
|
self.model_status_dict = dict()
|
||||||
self._load_tokenizer()
|
self._load_tokenizer()
|
||||||
data_processor_logger.info(
|
data_processor_logger.info(
|
||||||
f"tokenizer information: bos_token is {self.tokenizer.bos_token} \
|
f"tokenizer information: bos_token is {self.tokenizer.bos_token} \
|
||||||
@@ -148,8 +149,18 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
request.set("temperature", 1)
|
request.set("temperature", 1)
|
||||||
if request.get("top_p") < _SAMPLING_EPS:
|
if request.get("top_p") < _SAMPLING_EPS:
|
||||||
request.set("top_p", _SAMPLING_EPS)
|
request.set("top_p", _SAMPLING_EPS)
|
||||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
if self.reasoning_parser:
|
||||||
request.enable_thinking = True
|
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||||
|
parts = request.request_id.split("_")
|
||||||
|
if len(parts) > 1:
|
||||||
|
real_req_id = parts[0]
|
||||||
|
index = int(parts[1])
|
||||||
|
n = request.get("n", 1)
|
||||||
|
for idx in range(index * n, (index + 1) * n):
|
||||||
|
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||||
|
else:
|
||||||
|
self.model_status_dict[request.request_id] = model_status
|
||||||
|
request.enable_thinking = model_status == "think_start"
|
||||||
|
|
||||||
data_processor_logger.info(f"Processed request: {request}")
|
data_processor_logger.info(f"Processed request: {request}")
|
||||||
return request
|
return request
|
||||||
@@ -222,9 +233,19 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
request["temperature"] = 1
|
request["temperature"] = 1
|
||||||
if request.get("top_p") < _SAMPLING_EPS:
|
if request.get("top_p") < _SAMPLING_EPS:
|
||||||
request["top_p"] = _SAMPLING_EPS
|
request["top_p"] = _SAMPLING_EPS
|
||||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
|
||||||
request["enable_thinking"] = True
|
|
||||||
|
|
||||||
|
if self.reasoning_parser:
|
||||||
|
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||||
|
parts = request["request_id"].split("_")
|
||||||
|
if len(parts) > 1:
|
||||||
|
real_req_id = parts[0]
|
||||||
|
index = int(parts[1])
|
||||||
|
n = request.get("n", 1)
|
||||||
|
for idx in range(index * n, (index + 1) * n):
|
||||||
|
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||||
|
else:
|
||||||
|
self.model_status_dict[request["request_id"]] = model_status
|
||||||
|
request["enable_thinking"] = model_status == "think_start"
|
||||||
data_processor_logger.info(f"Processed request dict: {request}")
|
data_processor_logger.info(f"Processed request dict: {request}")
|
||||||
return request
|
return request
|
||||||
|
|
||||||
@@ -246,7 +267,11 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
token_ids = token_ids[:-1]
|
token_ids = token_ids[:-1]
|
||||||
full_text = self.tokenizer.decode(token_ids)
|
full_text = self.tokenizer.decode(token_ids)
|
||||||
if self.reasoning_parser:
|
if self.reasoning_parser:
|
||||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||||
|
full_text,
|
||||||
|
response_dict,
|
||||||
|
self.model_status_dict[req_id],
|
||||||
|
)
|
||||||
response_dict.outputs.text = text
|
response_dict.outputs.text = text
|
||||||
response_dict.outputs.reasoning_content = reasoning_content
|
response_dict.outputs.reasoning_content = reasoning_content
|
||||||
else:
|
else:
|
||||||
@@ -257,6 +282,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
if tool_call_info.tools_called:
|
if tool_call_info.tools_called:
|
||||||
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
||||||
response_dict.outputs.text = tool_call_info.content
|
response_dict.outputs.text = tool_call_info.content
|
||||||
|
if req_id in self.model_status_dict:
|
||||||
|
del self.model_status_dict[req_id]
|
||||||
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
|
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
|
||||||
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
|
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
|
||||||
return None
|
return None
|
||||||
@@ -287,7 +314,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict: response contain text fields
|
Dict: response contain text fields
|
||||||
"""
|
"""
|
||||||
enable_thinking = kwargs.get("enable_thinking")
|
|
||||||
token_ids = response_dict["outputs"]["token_ids"]
|
token_ids = response_dict["outputs"]["token_ids"]
|
||||||
is_end = response_dict["finished"]
|
is_end = response_dict["finished"]
|
||||||
req_id = response_dict["request_id"]
|
req_id = response_dict["request_id"]
|
||||||
@@ -297,16 +323,17 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||||
if is_end:
|
if is_end:
|
||||||
full_text = previous_texts + delta_text
|
full_text = previous_texts + delta_text
|
||||||
if self.reasoning_parser and (
|
response_dict["outputs"]["text"] = full_text
|
||||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
if self.reasoning_parser:
|
||||||
):
|
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
full_text,
|
||||||
|
response_dict,
|
||||||
|
self.model_status_dict[req_id],
|
||||||
|
)
|
||||||
response_dict["outputs"]["text"] = text
|
response_dict["outputs"]["text"] = text
|
||||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||||
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
|
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
|
||||||
response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
|
response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
|
||||||
else:
|
|
||||||
response_dict["outputs"]["text"] = full_text
|
|
||||||
if self.tool_parser_obj:
|
if self.tool_parser_obj:
|
||||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||||
@@ -316,6 +343,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
response_dict["outputs"]["completion_tokens"] = full_text
|
response_dict["outputs"]["completion_tokens"] = full_text
|
||||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||||
del self.decode_status[req_id]
|
del self.decode_status[req_id]
|
||||||
|
if req_id in self.model_status_dict:
|
||||||
|
del self.model_status_dict[req_id]
|
||||||
return response_dict
|
return response_dict
|
||||||
|
|
||||||
def process_response_dict_streaming(self, response_dict, **kwargs):
|
def process_response_dict_streaming(self, response_dict, **kwargs):
|
||||||
@@ -328,7 +357,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict: response contain text fields
|
Dict: response contain text fields
|
||||||
"""
|
"""
|
||||||
enable_thinking = kwargs.get("enable_thinking")
|
|
||||||
is_end = response_dict["finished"]
|
is_end = response_dict["finished"]
|
||||||
req_id = response_dict["request_id"]
|
req_id = response_dict["request_id"]
|
||||||
token_ids = response_dict["outputs"]["token_ids"]
|
token_ids = response_dict["outputs"]["token_ids"]
|
||||||
@@ -338,9 +366,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
token_ids = token_ids[:-1]
|
token_ids = token_ids[:-1]
|
||||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||||
response_dict["outputs"]["completion_tokens"] = delta_text
|
response_dict["outputs"]["completion_tokens"] = delta_text
|
||||||
if self.reasoning_parser and (
|
if self.reasoning_parser:
|
||||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
|
||||||
):
|
|
||||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||||
previous_texts,
|
previous_texts,
|
||||||
previous_texts + delta_text,
|
previous_texts + delta_text,
|
||||||
@@ -348,6 +374,7 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
previous_token_ids,
|
previous_token_ids,
|
||||||
previous_token_ids + token_ids,
|
previous_token_ids + token_ids,
|
||||||
token_ids,
|
token_ids,
|
||||||
|
self.model_status_dict[req_id],
|
||||||
)
|
)
|
||||||
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
||||||
reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None
|
reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None
|
||||||
@@ -374,6 +401,8 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
del self.decode_status[req_id]
|
del self.decode_status[req_id]
|
||||||
if req_id in self.tool_parser_dict:
|
if req_id in self.tool_parser_dict:
|
||||||
del self.tool_parser_dict[req_id]
|
del self.tool_parser_dict[req_id]
|
||||||
|
if req_id in self.model_status_dict:
|
||||||
|
del self.model_status_dict[req_id]
|
||||||
return response_dict
|
return response_dict
|
||||||
|
|
||||||
def messages2ids(self, request_or_messages, **kwargs):
|
def messages2ids(self, request_or_messages, **kwargs):
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
|||||||
|
|
||||||
self.tool_parser_dict = dict()
|
self.tool_parser_dict = dict()
|
||||||
self.decode_status = dict()
|
self.decode_status = dict()
|
||||||
|
self.model_status_dict = dict()
|
||||||
self._load_tokenizer()
|
self._load_tokenizer()
|
||||||
|
|
||||||
# Generation config
|
# Generation config
|
||||||
@@ -242,15 +243,6 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
|||||||
request[k] = v
|
request[k] = v
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
|
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
|
||||||
options = chat_template_kwargs.get("options")
|
|
||||||
if options:
|
|
||||||
thinking_mode = options.get("thinking_mode")
|
|
||||||
if thinking_mode:
|
|
||||||
if thinking_mode == "close" or thinking_mode == "false":
|
|
||||||
request["enable_thinking"] = False
|
|
||||||
else:
|
|
||||||
request["enable_thinking"] = True
|
|
||||||
request.setdefault("enable_thinking", True)
|
|
||||||
outputs = self.ernie4_5_processor.request2ids(request)
|
outputs = self.ernie4_5_processor.request2ids(request)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
|
raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
|
||||||
@@ -279,6 +271,18 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
|||||||
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
|
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
|
||||||
data_processor_logger.info(f"Processed request {request}")
|
data_processor_logger.info(f"Processed request {request}")
|
||||||
|
|
||||||
|
if self.reasoning_parser:
|
||||||
|
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||||
|
parts = request["request_id"].split("_")
|
||||||
|
if len(parts) > 1:
|
||||||
|
real_req_id = parts[0]
|
||||||
|
index = int(parts[1])
|
||||||
|
n = request.get("n", 1)
|
||||||
|
for idx in range(index * n, (index + 1) * n):
|
||||||
|
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||||
|
else:
|
||||||
|
self.model_status_dict[request["request_id"]] = model_status
|
||||||
|
request["enable_thinking"] = model_status == "think_start"
|
||||||
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
||||||
request["top_p"] = _SAMPLING_EPS
|
request["top_p"] = _SAMPLING_EPS
|
||||||
|
|
||||||
@@ -314,21 +318,3 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
|||||||
outs["position_ids"] = np.array(outs["position_ids"], dtype=np.int64)
|
outs["position_ids"] = np.array(outs["position_ids"], dtype=np.int64)
|
||||||
|
|
||||||
return outs
|
return outs
|
||||||
|
|
||||||
def process_response_dict(self, response_dict, stream, **kwargs):
|
|
||||||
"""
|
|
||||||
Preprocess the response
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_dict (Dict): response for engine, contain ids fields
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: response contain text fields
|
|
||||||
"""
|
|
||||||
enable_thinking = kwargs.pop("enable_thinking", True)
|
|
||||||
if enable_thinking is None:
|
|
||||||
enable_thinking = True
|
|
||||||
if stream:
|
|
||||||
return self.process_response_dict_streaming(response_dict, enable_thinking=enable_thinking, **kwargs)
|
|
||||||
else:
|
|
||||||
return self.process_response_dict_normal(response_dict, enable_thinking=enable_thinking, **kwargs)
|
|
||||||
|
|||||||
@@ -254,6 +254,19 @@ class PaddleOCRVLProcessor(TextProcessor):
|
|||||||
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
|
||||||
request["top_p"] = _SAMPLING_EPS
|
request["top_p"] = _SAMPLING_EPS
|
||||||
|
|
||||||
|
if self.reasoning_parser:
|
||||||
|
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||||
|
parts = request["request_id"].split("_")
|
||||||
|
if len(parts) > 1:
|
||||||
|
real_req_id = parts[0]
|
||||||
|
index = int(parts[1])
|
||||||
|
n = request.get("n", 1)
|
||||||
|
for idx in range(index * n, (index + 1) * n):
|
||||||
|
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||||
|
else:
|
||||||
|
self.model_status_dict[request["request_id"]] = model_status
|
||||||
|
request["enable_thinking"] = model_status == "think_start"
|
||||||
|
|
||||||
return request
|
return request
|
||||||
|
|
||||||
def append_generated_tokens(self, multimodal_inputs, generated_token_ids):
|
def append_generated_tokens(self, multimodal_inputs, generated_token_ids):
|
||||||
|
|||||||
@@ -267,6 +267,18 @@ class QwenVLProcessor(TextProcessor):
|
|||||||
# Set default max_tokens if not specified
|
# Set default max_tokens if not specified
|
||||||
if request.get("max_tokens") is None:
|
if request.get("max_tokens") is None:
|
||||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token
|
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token
|
||||||
|
if self.reasoning_parser:
|
||||||
|
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||||
|
parts = request["request_id"].split("_")
|
||||||
|
if len(parts) > 1:
|
||||||
|
real_req_id = parts[0]
|
||||||
|
index = int(parts[1])
|
||||||
|
n = request.get("n", 1)
|
||||||
|
for idx in range(index * n, (index + 1) * n):
|
||||||
|
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||||
|
else:
|
||||||
|
self.model_status_dict[request["request_id"]] = model_status
|
||||||
|
request["enable_thinking"] = model_status == "think_start"
|
||||||
data_processor_logger.info(f"Processed request {request}")
|
data_processor_logger.info(f"Processed request {request}")
|
||||||
|
|
||||||
return request
|
return request
|
||||||
|
|||||||
@@ -176,6 +176,7 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
self.generation_config = None
|
self.generation_config = None
|
||||||
|
|
||||||
self.decode_status = dict()
|
self.decode_status = dict()
|
||||||
|
self.model_status_dict = dict()
|
||||||
self.tool_parser_dict = dict()
|
self.tool_parser_dict = dict()
|
||||||
self.tokenizer = self._load_tokenizer()
|
self.tokenizer = self._load_tokenizer()
|
||||||
data_processor_logger.info(
|
data_processor_logger.info(
|
||||||
@@ -267,6 +268,18 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
request.set("temperature", 1)
|
request.set("temperature", 1)
|
||||||
if request.get("top_p") < _SAMPLING_EPS:
|
if request.get("top_p") < _SAMPLING_EPS:
|
||||||
request.set("top_p", _SAMPLING_EPS)
|
request.set("top_p", _SAMPLING_EPS)
|
||||||
|
if self.reasoning_parser:
|
||||||
|
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
||||||
|
parts = request.request_id.split("_")
|
||||||
|
if len(parts) > 1:
|
||||||
|
real_req_id = parts[0]
|
||||||
|
index = int(parts[1])
|
||||||
|
n = request.get("n", 1)
|
||||||
|
for idx in range(index * n, (index + 1) * n):
|
||||||
|
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||||
|
else:
|
||||||
|
self.model_status_dict[request.request_id] = model_status
|
||||||
|
request.enable_thinking = model_status == "think_start"
|
||||||
|
|
||||||
data_processor_logger.info(f"Processed request: {request}")
|
data_processor_logger.info(f"Processed request: {request}")
|
||||||
return request
|
return request
|
||||||
@@ -340,6 +353,18 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
request["temperature"] = 1
|
request["temperature"] = 1
|
||||||
if request.get("top_p") < _SAMPLING_EPS:
|
if request.get("top_p") < _SAMPLING_EPS:
|
||||||
request["top_p"] = _SAMPLING_EPS
|
request["top_p"] = _SAMPLING_EPS
|
||||||
|
if self.reasoning_parser:
|
||||||
|
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
|
||||||
|
parts = request["request_id"].split("_")
|
||||||
|
if len(parts) > 1:
|
||||||
|
real_req_id = parts[0]
|
||||||
|
index = int(parts[1])
|
||||||
|
n = request.get("n", 1)
|
||||||
|
for idx in range(index * n, (index + 1) * n):
|
||||||
|
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
||||||
|
else:
|
||||||
|
self.model_status_dict[request["request_id"]] = model_status
|
||||||
|
request["enable_thinking"] = model_status == "think_start"
|
||||||
|
|
||||||
data_processor_logger.info(f"Processed request dict: {request}")
|
data_processor_logger.info(f"Processed request dict: {request}")
|
||||||
return request
|
return request
|
||||||
@@ -363,21 +388,21 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||||
token_ids = token_ids[:-1]
|
token_ids = token_ids[:-1]
|
||||||
full_text = self.tokenizer.decode(token_ids)
|
full_text = self.tokenizer.decode(token_ids)
|
||||||
|
response_dict.outputs.text = full_text
|
||||||
# 模型支持思考,并且支持思考
|
|
||||||
if self.reasoning_parser:
|
if self.reasoning_parser:
|
||||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||||
|
full_text, response_dict, self.model_status_dict[req_id]
|
||||||
|
)
|
||||||
response_dict.outputs.text = text
|
response_dict.outputs.text = text
|
||||||
response_dict.outputs.reasoning_content = reasoning_content
|
response_dict.outputs.reasoning_content = reasoning_content
|
||||||
else:
|
|
||||||
# 模型不支持思考,并且没单独设置enable_thinking为false
|
|
||||||
response_dict.outputs.text = full_text
|
|
||||||
if self.tool_parser_obj:
|
if self.tool_parser_obj:
|
||||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||||
if tool_call_info.tools_called:
|
if tool_call_info.tools_called:
|
||||||
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
||||||
response_dict.outputs.text = tool_call_info.content
|
response_dict.outputs.text = tool_call_info.content
|
||||||
|
if req_id in self.model_status_dict:
|
||||||
|
del self.model_status_dict[req_id]
|
||||||
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
|
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
|
||||||
|
|
||||||
return response_dict
|
return response_dict
|
||||||
@@ -392,7 +417,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict: response contain text fields
|
Dict: response contain text fields
|
||||||
"""
|
"""
|
||||||
enable_thinking = kwargs.get("enable_thinking")
|
|
||||||
token_ids = response_dict["outputs"]["token_ids"]
|
token_ids = response_dict["outputs"]["token_ids"]
|
||||||
is_end = response_dict["finished"]
|
is_end = response_dict["finished"]
|
||||||
req_id = response_dict["request_id"]
|
req_id = response_dict["request_id"]
|
||||||
@@ -403,14 +427,17 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
if is_end:
|
if is_end:
|
||||||
full_text = previous_texts + delta_text
|
full_text = previous_texts + delta_text
|
||||||
response_dict["outputs"]["completion_tokens"] = full_text
|
response_dict["outputs"]["completion_tokens"] = full_text
|
||||||
if enable_thinking and self.reasoning_parser:
|
response_dict["outputs"]["text"] = full_text
|
||||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
if self.reasoning_parser:
|
||||||
|
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||||
|
full_text,
|
||||||
|
response_dict,
|
||||||
|
self.model_status_dict[req_id],
|
||||||
|
)
|
||||||
response_dict["outputs"]["text"] = text
|
response_dict["outputs"]["text"] = text
|
||||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||||
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
|
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
|
||||||
response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
|
response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
|
||||||
else:
|
|
||||||
response_dict["outputs"]["text"] = full_text
|
|
||||||
if self.tool_parser_obj:
|
if self.tool_parser_obj:
|
||||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||||
@@ -419,6 +446,8 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
response_dict["outputs"]["text"] = tool_call_info.content
|
response_dict["outputs"]["text"] = tool_call_info.content
|
||||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||||
del self.decode_status[req_id]
|
del self.decode_status[req_id]
|
||||||
|
if req_id in self.model_status_dict:
|
||||||
|
del self.model_status_dict[req_id]
|
||||||
return response_dict
|
return response_dict
|
||||||
|
|
||||||
def process_response_dict_streaming(self, response_dict, **kwargs):
|
def process_response_dict_streaming(self, response_dict, **kwargs):
|
||||||
@@ -431,7 +460,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict: response contain text fields
|
Dict: response contain text fields
|
||||||
"""
|
"""
|
||||||
enable_thinking = kwargs.get("enable_thinking")
|
|
||||||
is_end = response_dict["finished"]
|
is_end = response_dict["finished"]
|
||||||
req_id = response_dict["request_id"]
|
req_id = response_dict["request_id"]
|
||||||
token_ids = response_dict["outputs"]["token_ids"]
|
token_ids = response_dict["outputs"]["token_ids"]
|
||||||
@@ -441,9 +469,7 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
token_ids = token_ids[:-1]
|
token_ids = token_ids[:-1]
|
||||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||||
response_dict["outputs"]["completion_tokens"] = delta_text
|
response_dict["outputs"]["completion_tokens"] = delta_text
|
||||||
if self.reasoning_parser and (
|
if self.reasoning_parser:
|
||||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
|
||||||
):
|
|
||||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||||
previous_texts,
|
previous_texts,
|
||||||
previous_texts + delta_text,
|
previous_texts + delta_text,
|
||||||
@@ -451,6 +477,7 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
previous_token_ids,
|
previous_token_ids,
|
||||||
previous_token_ids + token_ids,
|
previous_token_ids + token_ids,
|
||||||
token_ids,
|
token_ids,
|
||||||
|
self.model_status_dict[req_id],
|
||||||
)
|
)
|
||||||
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
||||||
reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None
|
reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None
|
||||||
@@ -477,6 +504,8 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
del self.decode_status[req_id]
|
del self.decode_status[req_id]
|
||||||
if req_id in self.tool_parser_dict:
|
if req_id in self.tool_parser_dict:
|
||||||
del self.tool_parser_dict[req_id]
|
del self.tool_parser_dict[req_id]
|
||||||
|
if req_id in self.model_status_dict:
|
||||||
|
del self.model_status_dict[req_id]
|
||||||
return response_dict
|
return response_dict
|
||||||
|
|
||||||
def process_response_dict(self, response_dict, **kwargs):
|
def process_response_dict(self, response_dict, **kwargs):
|
||||||
@@ -489,16 +518,12 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict: response contain text fields
|
Dict: response contain text fields
|
||||||
"""
|
"""
|
||||||
enable_thinking = kwargs.pop("enable_thinking", True)
|
|
||||||
if enable_thinking is None:
|
|
||||||
enable_thinking = True
|
|
||||||
stream = kwargs.get("stream", True)
|
stream = kwargs.get("stream", True)
|
||||||
if stream:
|
if stream:
|
||||||
return self.process_response_dict_streaming(response_dict, enable_thinking=enable_thinking, **kwargs)
|
return self.process_response_dict_streaming(response_dict, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self.process_response_dict_normal(
|
return self.process_response_dict_normal(
|
||||||
response_dict=response_dict,
|
response_dict=response_dict,
|
||||||
enable_thinking=enable_thinking,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -35,25 +35,53 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
|||||||
|
|
||||||
def __init__(self, tokenizer):
    """Register the special tokens and resolve their vocabulary ids.

    Each entry of the token table is stored as ``self.<name>`` (the literal)
    and ``self.<name>_id`` (the id looked up in ``self.vocab``).
    ``self.token_status_mapping`` maps every id to the status name that
    ``get_model_status`` returns.

    Args:
        tokenizer: Tokenizer forwarded to the base-class constructor.

    Raises:
        ValueError: If ``self.model_tokenizer`` is falsy.
        RuntimeError: If any special token is absent from ``self.vocab``.
    """
    super().__init__(tokenizer)
    token_definitions = {
        "think_start_token": "<think>",
        "think_end_token": "</think>",
        "tool_call_start_token": "<tool_call>",
        "tool_call_end_token": "</tool_call>",
    }

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
        )

    missing_tokens = []
    for name, token_value in token_definitions.items():
        setattr(self, name, token_value)
        token_id = self.vocab.get(token_value)
        setattr(self, f"{name}_id", token_id)
        if token_id is None:
            # ``name`` already ends in "_token"; adding another " token"
            # suffix rendered the label as "think start token token".
            missing_tokens.append(name.replace("_", " "))

    if missing_tokens:
        raise RuntimeError(
            f"ernie vl reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
        )

    self.token_status_mapping = {
        self.think_start_token_id: "think_start",
        self.think_end_token_id: "think_end",
        self.tool_call_start_token_id: "tool_call_start",
        self.tool_call_end_token_id: "tool_call_end",
    }
|
||||||
|
|
||||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Report whether the ``</think>`` token id occurs in ``input_ids``."""
    end_id = self.think_end_token_id
    return end_id in input_ids
|
||||||
|
|
||||||
|
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """Return the last id in ``prompt_token_ids`` that is a key of ``self.token_status_mapping``.

    Returns:
        The matching token id, or ``-1`` when no id in the prompt is a key
        of the mapping (including an empty prompt).
    """
    status_by_id = self.token_status_mapping
    # Walk backwards so the first hit is the most recent special token.
    return next((tid for tid in reversed(prompt_token_ids) if tid in status_by_id), -1)
|
||||||
|
|
||||||
|
def get_model_status(self, prompt_token_ids: list[int]):
    """Map the last special token of the prompt to its status name.

    Returns ``"think_start"`` when ``find_last_special_token`` reports ``-1``;
    otherwise the ``token_status_mapping`` entry for the returned id.
    """
    last_marker = self.find_last_special_token(prompt_token_ids)
    return "think_start" if last_marker == -1 else self.token_status_mapping[last_marker]
|
||||||
|
|
||||||
def extract_reasoning_content_streaming(
|
def extract_reasoning_content_streaming(
|
||||||
self,
|
self,
|
||||||
previous_text: str,
|
previous_text: str,
|
||||||
@@ -62,6 +90,7 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
|||||||
previous_token_ids: Sequence[int],
|
previous_token_ids: Sequence[int],
|
||||||
current_token_ids: Sequence[int],
|
current_token_ids: Sequence[int],
|
||||||
delta_token_ids: Sequence[int],
|
delta_token_ids: Sequence[int],
|
||||||
|
model_status: str,
|
||||||
) -> Union[DeltaMessage, None]:
|
) -> Union[DeltaMessage, None]:
|
||||||
"""
|
"""
|
||||||
Extract reasoning content from a delta message.
|
Extract reasoning content from a delta message.
|
||||||
@@ -71,6 +100,7 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
|||||||
- 'abc' goes to reasoning_content
|
- 'abc' goes to reasoning_content
|
||||||
- 'xyz' goes to content
|
- 'xyz' goes to content
|
||||||
"""
|
"""
|
||||||
|
if model_status == "think_start":
|
||||||
if self.think_end_token not in current_text:
|
if self.think_end_token not in current_text:
|
||||||
return DeltaMessage(reasoning_content=delta_text)
|
return DeltaMessage(reasoning_content=delta_text)
|
||||||
# Skip single special tokens
|
# Skip single special tokens
|
||||||
@@ -98,9 +128,18 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
|||||||
if len(striped_suffix) == 0:
|
if len(striped_suffix) == 0:
|
||||||
return None
|
return None
|
||||||
return DeltaMessage(content=delta_text)
|
return DeltaMessage(content=delta_text)
|
||||||
|
elif model_status == "think_end":
|
||||||
|
if current_text.lstrip("\n").startswith(self.tool_call_start_token):
|
||||||
|
return None
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def extract_reasoning_content(
|
def extract_reasoning_content(
|
||||||
self, model_output: str, request: ChatCompletionRequest
|
self,
|
||||||
|
model_output: str,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
model_status: str,
|
||||||
) -> tuple[Optional[str], Optional[str]]:
|
) -> tuple[Optional[str], Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Extract reasoning content from the model output.
|
Extract reasoning content from the model output.
|
||||||
@@ -114,23 +153,30 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Check if the model output contains the </think> tokens.
|
# Check if the model output contains the </think> tokens.
|
||||||
|
if model_status == "think_start":
|
||||||
if self.think_end_token not in model_output:
|
if self.think_end_token not in model_output:
|
||||||
return model_output, ""
|
return model_output, ""
|
||||||
reasoning_content, _, content = model_output.partition(self.think_end_token)
|
reasoning_content, _, content = model_output.partition(self.think_end_token)
|
||||||
if self.tool_begin_token in content:
|
if self.tool_call_start_token in content:
|
||||||
prefix, _, _ = content.partition(self.tool_begin_token)
|
prefix, _, _ = content.partition(self.tool_call_start_token)
|
||||||
prefix_strip = prefix.lstrip("\n")
|
prefix_strip = prefix.lstrip("\n")
|
||||||
if len(prefix_strip) > 0:
|
if len(prefix_strip) > 0:
|
||||||
return reasoning_content, content
|
return reasoning_content, content
|
||||||
return reasoning_content, ""
|
return reasoning_content, ""
|
||||||
return reasoning_content, content
|
return reasoning_content, content
|
||||||
|
elif model_status == "think_end":
|
||||||
|
if model_output.lstrip("\n").startswith(self.tool_call_start_token):
|
||||||
|
return "", ""
|
||||||
|
return "", model_output
|
||||||
|
else:
|
||||||
|
return "", ""
|
||||||
|
|
||||||
def _is_with_tool(self, current_text: str, current_token_ids: Sequence[int]) -> bool:
|
def _is_with_tool(self, current_text: str, current_token_ids: Sequence[int]) -> bool:
|
||||||
think_end_index = current_text.find(self.think_end_token)
|
think_end_index = current_text.find(self.think_end_token)
|
||||||
think_end = think_end_index + len(self.think_end_token)
|
think_end = think_end_index + len(self.think_end_token)
|
||||||
middle_str = current_text[think_end:]
|
middle_str = current_text[think_end:]
|
||||||
if self.tool_begin_token_id in current_token_ids:
|
if self.tool_call_start_token_id in current_token_ids:
|
||||||
prefix, _, _ = middle_str.partition(self.tool_begin_token)
|
prefix, _, _ = middle_str.partition(self.tool_call_start_token)
|
||||||
striped_prefix = prefix.strip("\n")
|
striped_prefix = prefix.strip("\n")
|
||||||
if len(striped_prefix) > 0:
|
if len(striped_prefix) > 0:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -35,20 +35,48 @@ class ErnieVLReasoningParser(ReasoningParser):
|
|||||||
|
|
||||||
def __init__(self, tokenizer):
    """Register ``<think>`` / ``</think>`` and resolve their vocabulary ids.

    Each entry of the token table is stored as ``self.<name>`` (the literal)
    and ``self.<name>_id`` (the id looked up in ``self.vocab``).
    ``self.token_status_mapping`` maps both ids to the status name that
    ``get_model_status`` returns.

    Args:
        tokenizer: Tokenizer forwarded to the base-class constructor.

    Raises:
        ValueError: If ``self.model_tokenizer`` is falsy.
        RuntimeError: If either token is absent from ``self.vocab``.
    """
    super().__init__(tokenizer)
    token_definitions = {
        "think_start_token": "<think>",
        "think_end_token": "</think>",
    }

    if not self.model_tokenizer:
        raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")

    missing_tokens = []
    for name, token_value in token_definitions.items():
        setattr(self, name, token_value)
        token_id = self.vocab.get(token_value)
        setattr(self, f"{name}_id", token_id)
        if token_id is None:
            # ``name`` already ends in "_token"; adding another " token"
            # suffix rendered the label as "think start token token".
            missing_tokens.append(name.replace("_", " "))

    if missing_tokens:
        raise RuntimeError(
            f"ernie vl reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
        )

    self.token_status_mapping = {
        self.think_start_token_id: "think_start",
        self.think_end_token_id: "think_end",
    }
|
||||||
|
|
||||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Report whether the ``</think>`` token id occurs in ``input_ids``."""
    end_id = self.think_end_token_id
    return end_id in input_ids
|
||||||
|
|
||||||
|
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """Return the last id in ``prompt_token_ids`` that is a key of ``self.token_status_mapping``.

    Returns:
        The matching token id, or ``-1`` when no id in the prompt is a key
        of the mapping (including an empty prompt).
    """
    status_by_id = self.token_status_mapping
    # Walk backwards so the first hit is the most recent special token.
    return next((tid for tid in reversed(prompt_token_ids) if tid in status_by_id), -1)
|
||||||
|
|
||||||
|
def get_model_status(self, prompt_token_ids: list[int]):
    """Map the last special token of the prompt to its status name.

    Returns ``"think_start"`` when ``find_last_special_token`` reports ``-1``;
    otherwise the ``token_status_mapping`` entry for the returned id.
    """
    last_marker = self.find_last_special_token(prompt_token_ids)
    return "think_start" if last_marker == -1 else self.token_status_mapping[last_marker]
|
||||||
|
|
||||||
def extract_reasoning_content_streaming(
|
def extract_reasoning_content_streaming(
|
||||||
self,
|
self,
|
||||||
previous_text: str,
|
previous_text: str,
|
||||||
@@ -57,6 +85,7 @@ class ErnieVLReasoningParser(ReasoningParser):
|
|||||||
previous_token_ids: Sequence[int],
|
previous_token_ids: Sequence[int],
|
||||||
current_token_ids: Sequence[int],
|
current_token_ids: Sequence[int],
|
||||||
delta_token_ids: Sequence[int],
|
delta_token_ids: Sequence[int],
|
||||||
|
model_status: str,
|
||||||
) -> Union[DeltaMessage, None]:
|
) -> Union[DeltaMessage, None]:
|
||||||
"""
|
"""
|
||||||
Extract reasoning content from a delta message.
|
Extract reasoning content from a delta message.
|
||||||
@@ -69,18 +98,23 @@ class ErnieVLReasoningParser(ReasoningParser):
|
|||||||
# Skip single special tokens
|
# Skip single special tokens
|
||||||
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
||||||
return None
|
return None
|
||||||
|
if model_status == "think_start":
|
||||||
if self.think_end_token_id in delta_token_ids:
|
if self.think_end_token_id in delta_token_ids:
|
||||||
end_index = delta_text.find(self.think_end_token)
|
end_index = delta_text.find(self.think_end_token)
|
||||||
reasoning_content = delta_text[:end_index]
|
reasoning_content = delta_text[:end_index]
|
||||||
content = delta_text[end_index + len(self.think_end_token) :]
|
content = delta_text[end_index + len(self.think_end_token) :]
|
||||||
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||||
elif self.think_end_token_id in previous_token_ids:
|
if self.think_end_token_id in previous_token_ids:
|
||||||
return DeltaMessage(content=delta_text)
|
return DeltaMessage(content=delta_text)
|
||||||
else:
|
|
||||||
return DeltaMessage(reasoning_content=delta_text)
|
return DeltaMessage(reasoning_content=delta_text)
|
||||||
|
else:
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
def extract_reasoning_content(
|
def extract_reasoning_content(
|
||||||
self, model_output: str, request: ChatCompletionRequest
|
self,
|
||||||
|
model_output: str,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
model_status: str,
|
||||||
) -> tuple[Optional[str], Optional[str]]:
|
) -> tuple[Optional[str], Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Extract reasoning content from the model output.
|
Extract reasoning content from the model output.
|
||||||
@@ -92,11 +126,12 @@ class ErnieVLReasoningParser(ReasoningParser):
|
|||||||
Returns:
|
Returns:
|
||||||
tuple[Optional[str], Optional[str]]: reasoning content and content
|
tuple[Optional[str], Optional[str]]: reasoning content and content
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Check if the model output contains the </think> tokens.
|
# Check if the model output contains the </think> tokens.
|
||||||
|
if model_status == "think_start":
|
||||||
if self.think_end_token not in model_output:
|
if self.think_end_token not in model_output:
|
||||||
return "", model_output
|
return "", model_output
|
||||||
reasoning_content, _, content = model_output.partition(self.think_end_token)
|
reasoning_content, _, content = model_output.partition(self.think_end_token)
|
||||||
|
|
||||||
final_content = content or ""
|
final_content = content or ""
|
||||||
return reasoning_content, final_content
|
return reasoning_content, final_content
|
||||||
|
else:
|
||||||
|
return "", model_output
|
||||||
|
|||||||
@@ -18,19 +18,55 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
|||||||
|
|
||||||
def __init__(self, tokenizer):
    """Register the X1 special tokens and resolve their vocabulary ids.

    Stores each token literal as ``self.<name>`` and its ``self.vocab`` id as
    ``self.<name>_id``, then builds ``self.token_status_mapping`` from id to
    status name. Raises ``ValueError`` when ``self.model_tokenizer`` is falsy
    and ``RuntimeError`` listing every token literal missing from the vocab.
    """
    super().__init__(tokenizer)

    # All special tokens that must be present in the vocabulary.
    special_tokens = {
        "think_start_token": "<think>",
        "think_end_token": "</think>",
        "response_start_token": "<response>",
        "response_end_token": "</response>",
        "tool_call_start_token": "<tool_call>",
        "tool_call_end_token": "</tool_call>",
    }

    if not self.model_tokenizer:
        raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")

    absent = []
    for attr, literal in special_tokens.items():
        setattr(self, attr, literal)
        resolved_id = self.vocab.get(literal)
        setattr(self, f"{attr}_id", resolved_id)
        if resolved_id is None:
            absent.append(literal)

    if absent:
        missing_tokens = ", ".join(absent)
        raise RuntimeError(
            f"ernie x1 reasoning parser could not find the following token ids in tokenizer vocabulary: {missing_tokens}"
        )

    status_names = (
        "think_start",
        "think_end",
        "response_start",
        "response_end",
        "tool_call_start",
        "tool_call_end",
    )
    # Keys are the ids assigned above as ``<status>_token_id``.
    self.token_status_mapping = {getattr(self, f"{status}_token_id"): status for status in status_names}
|
||||||
|
|
||||||
|
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """Return the last id in ``prompt_token_ids`` that is a key of ``self.token_status_mapping``.

    Returns:
        The matching token id, or ``-1`` when no id in the prompt is a key
        of the mapping (including an empty prompt).
    """
    status_by_id = self.token_status_mapping
    # Walk backwards so the first hit is the most recent special token.
    return next((tid for tid in reversed(prompt_token_ids) if tid in status_by_id), -1)
|
||||||
|
|
||||||
|
def get_model_status(self, prompt_token_ids: list[int]):
    """Map the last special token of the prompt to its status name.

    Returns ``"think_start"`` when ``find_last_special_token`` reports ``-1``;
    otherwise the ``token_status_mapping`` entry for the returned id.
    """
    last_marker = self.find_last_special_token(prompt_token_ids)
    return "think_start" if last_marker == -1 else self.token_status_mapping[last_marker]
|
||||||
|
|
||||||
def extract_reasoning_content_streaming(
|
def extract_reasoning_content_streaming(
|
||||||
self,
|
self,
|
||||||
@@ -40,64 +76,82 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
|||||||
previous_token_ids: Sequence[int],
|
previous_token_ids: Sequence[int],
|
||||||
current_token_ids: Sequence[int],
|
current_token_ids: Sequence[int],
|
||||||
delta_token_ids: Sequence[int],
|
delta_token_ids: Sequence[int],
|
||||||
|
model_status: str,
|
||||||
) -> Union[DeltaMessage, None]:
|
) -> Union[DeltaMessage, None]:
|
||||||
# Ignore the single </think> token
|
|
||||||
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
if len(delta_token_ids) == 1 and delta_token_ids[0] in [
|
||||||
|
self.think_end_token_id,
|
||||||
|
self.response_start_token_id,
|
||||||
|
self.response_end_token_id,
|
||||||
|
self.tool_call_start_token_id,
|
||||||
|
self.tool_call_end_token_id,
|
||||||
|
]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# --- Thinking stage handling ---
|
if model_status == "think_start":
|
||||||
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
|
if self.think_end_token in delta_text:
|
||||||
# If delta is </think>, stop thinking, do not return
|
response_content = ""
|
||||||
if delta_text.startswith(self.think_end_token):
|
end_index = delta_text.find(self.think_end_token)
|
||||||
return None
|
reasoning_content = delta_text[:end_index]
|
||||||
# Otherwise, return thinking content (keep \n as-is)
|
response_start_pos = delta_text.find(self.response_start_token)
|
||||||
|
if response_start_pos != -1:
|
||||||
|
response_content = self._extract_response_content(
|
||||||
|
delta_text[response_start_pos + len(self.response_start_token) :]
|
||||||
|
)
|
||||||
|
return DeltaMessage(reasoning_content=reasoning_content, content=response_content)
|
||||||
|
elif self.think_end_token in previous_text:
|
||||||
|
if self.response_start_token in previous_text and self.response_end_token not in previous_text:
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
else:
|
||||||
return DeltaMessage(reasoning_content=delta_text)
|
return DeltaMessage(reasoning_content=delta_text)
|
||||||
|
elif model_status == "think_end":
|
||||||
# --- After thinking ends, check tool_call or response ---
|
if self.response_start_token in previous_text and self.response_end_token not in previous_text:
|
||||||
remaining_text = previous_text + delta_text
|
return DeltaMessage(content=delta_text)
|
||||||
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
|
elif model_status == "response_start":
|
||||||
after_think = after_think.lstrip("\n")
|
if self.response_end_token not in previous_text:
|
||||||
|
|
||||||
# Handle tool_call case: skip it
|
|
||||||
if after_think.startswith(self.tool_call_start_token):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Handle response case
|
|
||||||
if after_think.startswith(self.response_start_token) and self.response_end_token not in after_think:
|
|
||||||
# Do not return when <response> tag itself appears
|
|
||||||
if delta_text == self.response_start_token or delta_text == self.response_end_token:
|
|
||||||
return None
|
|
||||||
return DeltaMessage(content=delta_text)
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
# Default case: return nothing
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def extract_reasoning_content(
    self, model_output: str, request: ChatCompletionRequest, model_status: str
) -> Tuple[str, str]:
    """Split a complete model output into reasoning and response text.

    Newlines inside the reasoning and response text are kept as-is; only the
    newlines directly after ``</think>`` (or at the very start of the output
    when the prompt already closed the thinking block) are stripped before
    looking for ``<response>``.

    Args:
        model_output: Full generated text.
        request: Not read by this method.
        model_status: Status derived from the prompt; ``"think_start"``,
            ``"think_end"`` and ``"response_start"`` are handled, any other
            value yields two empty strings.

    Returns:
        ``(reasoning_content, response_content)``.
    """
    reasoning_content = ""
    response_content = ""

    if model_status in ["think_start", "think_end"]:
        if model_status == "think_start":
            think_end_pos = model_output.find(self.think_end_token)
            if think_end_pos != -1:
                reasoning_content = model_output[:think_end_pos]
                remaining = model_output[think_end_pos + len(self.think_end_token) :].lstrip("\n")
            else:
                # Thinking never closed: everything is reasoning.
                reasoning_content = model_output
                remaining = ""
        else:
            # Prompt already ended the thinking block; output starts after it.
            remaining = model_output.lstrip("\n")

        response_start_pos = remaining.find(self.response_start_token)
        if response_start_pos != -1:
            response_content = self._extract_response_content(
                remaining[response_start_pos + len(self.response_start_token) :]
            )

    elif model_status == "response_start":
        # Prompt already opened <response>; the output is response text.
        response_content = self._extract_response_content(model_output)

    return reasoning_content, response_content
|
||||||
|
|
||||||
|
def _extract_response_content(self, remaining: str) -> str:
|
||||||
|
"""
|
||||||
|
Extracts response content, ensuring that the last newline before
|
||||||
|
the </response> tag is removed.
|
||||||
|
"""
|
||||||
|
response_end_pos = remaining.find(self.response_end_token)
|
||||||
|
if response_end_pos != -1:
|
||||||
|
return remaining[:response_end_pos]
|
||||||
|
return remaining
|
||||||
|
|||||||
@@ -35,22 +35,50 @@ class Qwen3ReasoningParser(ReasoningParser):
|
|||||||
|
|
||||||
def __init__(self, tokenizer):
    """Register ``<think>`` / ``</think>`` and resolve their vocabulary ids.

    Stores each token literal as ``self.<name>`` and its ``self.vocab`` id as
    ``self.<name>_id``, then builds ``self.token_status_mapping`` from id to
    status name. Raises ``ValueError`` when ``self.model_tokenizer`` is falsy
    and ``RuntimeError`` listing every token literal missing from the vocab.
    """
    super().__init__(tokenizer)

    # All special tokens that must be present in the vocabulary.
    special_tokens = {
        "think_start_token": "<think>",
        "think_end_token": "</think>",
    }

    if not self.model_tokenizer:
        raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")

    absent = []
    for attr, literal in special_tokens.items():
        setattr(self, attr, literal)
        resolved_id = self.vocab.get(literal)
        setattr(self, f"{attr}_id", resolved_id)
        if resolved_id is None:
            absent.append(literal)

    if absent:
        missing_tokens = ", ".join(absent)
        raise RuntimeError(
            f"Qwen3 reasoning parser could not find the following token ids in tokenizer vocabulary: {missing_tokens}"
        )

    # Keys are the ids assigned above as ``<status>_token_id``.
    self.token_status_mapping = {
        getattr(self, f"{status}_token_id"): status for status in ("think_start", "think_end")
    }
|
||||||
|
|
||||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Report whether the ``</think>`` token id occurs in ``input_ids``."""
    end_id = self.think_end_token_id
    return end_id in input_ids
|
||||||
|
|
||||||
|
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """Return the last id in ``prompt_token_ids`` that is a key of ``self.token_status_mapping``.

    Returns:
        The matching token id, or ``-1`` when no id in the prompt is a key
        of the mapping (including an empty prompt).
    """
    status_by_id = self.token_status_mapping
    # Walk backwards so the first hit is the most recent special token.
    return next((tid for tid in reversed(prompt_token_ids) if tid in status_by_id), -1)
|
||||||
|
|
||||||
|
def get_model_status(self, prompt_token_ids: list[int]):
    """Map the last special token of the prompt to its status name.

    Returns ``"think_start"`` when ``find_last_special_token`` reports ``-1``;
    otherwise the ``token_status_mapping`` entry for the returned id.
    """
    last_marker = self.find_last_special_token(prompt_token_ids)
    return "think_start" if last_marker == -1 else self.token_status_mapping[last_marker]
|
||||||
|
|
||||||
def extract_reasoning_content_streaming(
|
def extract_reasoning_content_streaming(
|
||||||
self,
|
self,
|
||||||
previous_text: str,
|
previous_text: str,
|
||||||
@@ -59,6 +87,7 @@ class Qwen3ReasoningParser(ReasoningParser):
|
|||||||
previous_token_ids: Sequence[int],
|
previous_token_ids: Sequence[int],
|
||||||
current_token_ids: Sequence[int],
|
current_token_ids: Sequence[int],
|
||||||
delta_token_ids: Sequence[int],
|
delta_token_ids: Sequence[int],
|
||||||
|
model_status: str,
|
||||||
) -> Union[DeltaMessage, None]:
|
) -> Union[DeltaMessage, None]:
|
||||||
"""
|
"""
|
||||||
Extract reasoning content from a delta message.
|
Extract reasoning content from a delta message.
|
||||||
@@ -71,6 +100,7 @@ class Qwen3ReasoningParser(ReasoningParser):
|
|||||||
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]):
|
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if model_status == "think_start":
|
||||||
# </think> in delta
|
# </think> in delta
|
||||||
if self.think_end_token_id in delta_token_ids:
|
if self.think_end_token_id in delta_token_ids:
|
||||||
# <think> in delta, </think> in delta, extract reasoning content
|
# <think> in delta, </think> in delta, extract reasoning content
|
||||||
@@ -101,9 +131,11 @@ class Qwen3ReasoningParser(ReasoningParser):
|
|||||||
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||||
else:
|
else:
|
||||||
return DeltaMessage(reasoning_content=delta_text)
|
return DeltaMessage(reasoning_content=delta_text)
|
||||||
|
else:
|
||||||
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
def extract_reasoning_content(
|
def extract_reasoning_content(
|
||||||
self, model_output: str, request: ChatCompletionRequest
|
self, model_output: str, request: ChatCompletionRequest, model_status: str
|
||||||
) -> tuple[Optional[str], Optional[str]]:
|
) -> tuple[Optional[str], Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Extract reasoning content from the model output.
|
Extract reasoning content from the model output.
|
||||||
@@ -116,6 +148,7 @@ class Qwen3ReasoningParser(ReasoningParser):
|
|||||||
tuple[Optional[str], Optional[str]]: reasoning content and content
|
tuple[Optional[str], Optional[str]]: reasoning content and content
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if model_status == "think_start":
|
||||||
# 检查是否包含结束标签
|
# 检查是否包含结束标签
|
||||||
if self.think_end_token not in model_output:
|
if self.think_end_token not in model_output:
|
||||||
return None, model_output
|
return None, model_output
|
||||||
@@ -149,3 +182,5 @@ class Qwen3ReasoningParser(ReasoningParser):
|
|||||||
return reasoning_content, final_content
|
return reasoning_content, final_content
|
||||||
|
|
||||||
return None, model_output
|
return None, model_output
|
||||||
|
else:
|
||||||
|
return None, model_output
|
||||||
|
|||||||
@@ -482,7 +482,7 @@ def test_chat_with_thinking(openai_client, capsys):
|
|||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
|
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
|
||||||
)
|
)
|
||||||
assert response.choices[0].message.reasoning_content is None
|
assert response.choices[0].message.reasoning_content == ""
|
||||||
assert "</think>" not in response.choices[0].message.content
|
assert "</think>" not in response.choices[0].message.content
|
||||||
|
|
||||||
# test logic
|
# test logic
|
||||||
@@ -957,4 +957,4 @@ def test_thinking_logic_flag(openai_client, capsys):
|
|||||||
"chat_template_kwargs": {"enable_thinking": False},
|
"chat_template_kwargs": {"enable_thinking": False},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response_case_3.choices[0].message.reasoning_content is None
|
assert response_case_3.choices[0].message.reasoning_content == ""
|
||||||
|
|||||||
@@ -545,7 +545,7 @@ def test_chat_with_thinking(openai_client, capsys):
|
|||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
|
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
|
||||||
)
|
)
|
||||||
assert response.choices[0].message.reasoning_content is None
|
assert response.choices[0].message.reasoning_content == ""
|
||||||
assert "</think>" not in response.choices[0].message.content
|
assert "</think>" not in response.choices[0].message.content
|
||||||
|
|
||||||
# test logic
|
# test logic
|
||||||
@@ -716,4 +716,4 @@ def test_thinking_logic_flag(openai_client, capsys):
|
|||||||
"chat_template_kwargs": {"enable_thinking": False},
|
"chat_template_kwargs": {"enable_thinking": False},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response_case_3.choices[0].message.reasoning_content is None
|
assert response_case_3.choices[0].message.reasoning_content == ""
|
||||||
|
|||||||
@@ -312,7 +312,7 @@ def test_chat_with_thinking(openai_client, capsys):
|
|||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
|
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
|
||||||
)
|
)
|
||||||
assert response.choices[0].message.reasoning_content is None
|
assert response.choices[0].message.reasoning_content == ""
|
||||||
assert "</think>" not in response.choices[0].message.content
|
assert "</think>" not in response.choices[0].message.content
|
||||||
|
|
||||||
# test logic
|
# test logic
|
||||||
@@ -404,4 +404,4 @@ def test_thinking_logic_flag(openai_client, capsys):
|
|||||||
"chat_template_kwargs": {"enable_thinking": False},
|
"chat_template_kwargs": {"enable_thinking": False},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response_case_3.choices[0].message.reasoning_content is None
|
assert response_case_3.choices[0].message.reasoning_content == ""
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
|
|||||||
self.multi_modal_processor._check_mm_limits = Mock()
|
self.multi_modal_processor._check_mm_limits = Mock()
|
||||||
self.multi_modal_processor.append_completion_tokens = Mock()
|
self.multi_modal_processor.append_completion_tokens = Mock()
|
||||||
self.multi_modal_processor.pack_outputs = lambda x: x
|
self.multi_modal_processor.pack_outputs = lambda x: x
|
||||||
|
self.multi_modal_processor.reasoning_parser = None
|
||||||
|
self.multi_modal_processor.model_status_dict = {}
|
||||||
|
|
||||||
self.engine_client = Mock()
|
self.engine_client = Mock()
|
||||||
self.engine_client.connection_initialized = False
|
self.engine_client.connection_initialized = False
|
||||||
@@ -258,7 +260,7 @@ class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
|
|||||||
mock_processor_instance = Mock()
|
mock_processor_instance = Mock()
|
||||||
mock_processor_instance.enable_multimodal_content.return_value = True
|
mock_processor_instance.enable_multimodal_content.return_value = True
|
||||||
|
|
||||||
async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
|
async def mock_process_response_chat_async(response, stream, include_stop_str_in_output):
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
mock_processor_instance.process_response_chat = mock_process_response_chat_async
|
mock_processor_instance.process_response_chat = mock_process_response_chat_async
|
||||||
@@ -439,7 +441,7 @@ class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
|
|||||||
mock_processor_instance = Mock()
|
mock_processor_instance = Mock()
|
||||||
mock_processor_instance.enable_multimodal_content.return_value = False
|
mock_processor_instance.enable_multimodal_content.return_value = False
|
||||||
|
|
||||||
async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
|
async def mock_process_response_chat_async(response, stream, include_stop_str_in_output):
|
||||||
if isinstance(response, list):
|
if isinstance(response, list):
|
||||||
for res in response:
|
for res in response:
|
||||||
yield res
|
yield res
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
|
|||||||
|
|
||||||
mock_processor_instance = Mock()
|
mock_processor_instance = Mock()
|
||||||
|
|
||||||
async def mock_process_response_chat_single(response, stream, enable_thinking, include_stop_str_in_output):
|
async def mock_process_response_chat_single(response, stream, include_stop_str_in_output):
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
mock_processor_instance.process_response_chat = mock_process_response_chat_single
|
mock_processor_instance.process_response_chat = mock_process_response_chat_single
|
||||||
@@ -539,7 +539,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
|
|||||||
|
|
||||||
mock_processor_instance = Mock()
|
mock_processor_instance = Mock()
|
||||||
|
|
||||||
async def mock_process_response_chat(response, stream, enable_thinking, include_stop_str_in_output):
|
async def mock_process_response_chat(response, stream, include_stop_str_in_output):
|
||||||
delta_msg_mock = Mock()
|
delta_msg_mock = Mock()
|
||||||
delta_msg_mock.content = response["outputs"]["text"]
|
delta_msg_mock.content = response["outputs"]["text"]
|
||||||
if response["outputs"]["text"] == "a":
|
if response["outputs"]["text"] == "a":
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
|||||||
results = [
|
results = [
|
||||||
r
|
r
|
||||||
async for r in processor.process_response_chat(
|
async for r in processor.process_response_chat(
|
||||||
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
|
request_outputs, stream=False, include_stop_str_in_output=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -69,7 +69,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
|||||||
results = [
|
results = [
|
||||||
r
|
r
|
||||||
async for r in processor.process_response_chat(
|
async for r in processor.process_response_chat(
|
||||||
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
|
request_outputs, stream=True, include_stop_str_in_output=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
|||||||
results = [
|
results = [
|
||||||
r
|
r
|
||||||
async for r in self.processor_mm.process_response_chat(
|
async for r in self.processor_mm.process_response_chat(
|
||||||
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
|
request_outputs, stream=True, include_stop_str_in_output=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -116,7 +116,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
|||||||
results = [
|
results = [
|
||||||
r
|
r
|
||||||
async for r in self.processor_mm.process_response_chat(
|
async for r in self.processor_mm.process_response_chat(
|
||||||
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
|
request_outputs, stream=True, include_stop_str_in_output=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -134,7 +134,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
|
|||||||
results = [
|
results = [
|
||||||
r
|
r
|
||||||
async for r in self.processor_mm.process_response_chat(
|
async for r in self.processor_mm.process_response_chat(
|
||||||
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
|
request_outputs, stream=False, include_stop_str_in_output=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -41,35 +41,6 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase):
|
|||||||
chat_template=None,
|
chat_template=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_enable_thinking(self):
|
|
||||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={})
|
|
||||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
|
||||||
self.assertEqual(enable_thinking, None)
|
|
||||||
|
|
||||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"enable_thinking": True})
|
|
||||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
|
||||||
self.assertEqual(enable_thinking, True)
|
|
||||||
|
|
||||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"enable_thinking": False})
|
|
||||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
|
||||||
self.assertEqual(enable_thinking, False)
|
|
||||||
|
|
||||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "close"}})
|
|
||||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
|
||||||
self.assertEqual(enable_thinking, False)
|
|
||||||
|
|
||||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "false"}})
|
|
||||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
|
||||||
self.assertEqual(enable_thinking, False)
|
|
||||||
|
|
||||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "open"}})
|
|
||||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
|
||||||
self.assertEqual(enable_thinking, True)
|
|
||||||
|
|
||||||
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "123"}})
|
|
||||||
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
|
|
||||||
self.assertEqual(enable_thinking, True)
|
|
||||||
|
|
||||||
def test_build_prompt_logprobs_basic(self):
|
def test_build_prompt_logprobs_basic(self):
|
||||||
"""Test basic functionality of _build_prompt_logprobs"""
|
"""Test basic functionality of _build_prompt_logprobs"""
|
||||||
# Create mock data
|
# Create mock data
|
||||||
|
|||||||
@@ -52,33 +52,12 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
self.assertTrue(result.tools_called)
|
self.assertTrue(result.tools_called)
|
||||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
||||||
|
|
||||||
def test_extract_tool_calls_partial_arguments(self):
|
|
||||||
"""Test partial extraction when arguments incomplete"""
|
|
||||||
output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北"</tool_call>'
|
|
||||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
|
||||||
self.assertFalse(result.tools_called)
|
|
||||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
|
||||||
|
|
||||||
def test_extract_tool_calls_invalid_response_before_toolcall(self):
|
|
||||||
"""Test case where <response> before <tool_call> is invalid"""
|
|
||||||
output = '<response>hello</response><tool_call>{"name": "get_weather", "arguments": {}}</tool_call>'
|
|
||||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
|
||||||
self.assertFalse(result.tools_called)
|
|
||||||
self.assertIn("<response>", result.content)
|
|
||||||
|
|
||||||
def test_extract_tool_calls_no_toolcall(self):
|
def test_extract_tool_calls_no_toolcall(self):
|
||||||
"""Test when no tool_call tags are present"""
|
"""Test when no tool_call tags are present"""
|
||||||
output = "no tool call here"
|
output = "no tool call here"
|
||||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||||
self.assertFalse(result.tools_called)
|
self.assertFalse(result.tools_called)
|
||||||
|
|
||||||
def test_extract_tool_calls_invalid_json(self):
|
|
||||||
"""Test tool_call with badly formatted JSON triggers fallback parser"""
|
|
||||||
output = '<tool_call>"name": "get_weather", "arguments": {</tool_call>'
|
|
||||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
|
||||||
self.assertFalse(result.tools_called)
|
|
||||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
|
||||||
|
|
||||||
def test_extract_tool_calls_exception(self):
|
def test_extract_tool_calls_exception(self):
|
||||||
"""Force exception to cover error branch"""
|
"""Force exception to cover error branch"""
|
||||||
with patch(
|
with patch(
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ class ErnieX1ReasoningParser:
|
|||||||
previous_token_ids,
|
previous_token_ids,
|
||||||
all_token_ids,
|
all_token_ids,
|
||||||
delta_token_ids,
|
delta_token_ids,
|
||||||
|
model_status,
|
||||||
):
|
):
|
||||||
"""Return a simple object with reasoning_content to cover reasoning branch."""
|
"""Return a simple object with reasoning_content to cover reasoning branch."""
|
||||||
|
|
||||||
@@ -161,6 +162,7 @@ class TestErnie4_5Processor(unittest.TestCase):
|
|||||||
tool_cls = MockToolParser if tool else None
|
tool_cls = MockToolParser if tool else None
|
||||||
proc = Ernie4_5Processor("dummy-model", reasoning_parser_obj=reasoning_cls, tool_parser_obj=tool_cls)
|
proc = Ernie4_5Processor("dummy-model", reasoning_parser_obj=reasoning_cls, tool_parser_obj=tool_cls)
|
||||||
proc._apply_default_parameters = lambda req: req
|
proc._apply_default_parameters = lambda req: req
|
||||||
|
proc.model_status_dict = {"req-1": "think_start"}
|
||||||
return proc
|
return proc
|
||||||
|
|
||||||
def test_update_bad_words(self):
|
def test_update_bad_words(self):
|
||||||
|
|||||||
@@ -20,6 +20,11 @@ from unittest.mock import MagicMock, patch
|
|||||||
from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
|
from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
|
||||||
|
|
||||||
|
|
||||||
|
class MockReasoningParser:
|
||||||
|
def get_model_status(self, prompt_token_ids):
|
||||||
|
return "think_start"
|
||||||
|
|
||||||
|
|
||||||
class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# 创建 Ernie4_5Processor 实例的模拟对象
|
# 创建 Ernie4_5Processor 实例的模拟对象
|
||||||
@@ -30,12 +35,13 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
|||||||
# 设置必要的属性
|
# 设置必要的属性
|
||||||
self.processor.tokenizer = MagicMock()
|
self.processor.tokenizer = MagicMock()
|
||||||
self.processor.tokenizer.eos_token_id = 1
|
self.processor.tokenizer.eos_token_id = 1
|
||||||
self.processor.decode_status = {}
|
self.processor.decode_status = {"test": []}
|
||||||
self.processor.reasoning_end_dict = {}
|
self.processor.reasoning_end_dict = {}
|
||||||
self.processor.tool_parser_dict = {}
|
self.processor.tool_parser_dict = {}
|
||||||
self.processor.generation_config = MagicMock()
|
self.processor.generation_config = MagicMock()
|
||||||
self.processor.eos_token_ids = [1]
|
self.processor.eos_token_ids = [1]
|
||||||
self.processor.reasoning_parser = MagicMock()
|
self.processor.reasoning_parser = MockReasoningParser()
|
||||||
|
self.processor.model_status_dict = {"request-id_0": "think_start", "test": "think_start"}
|
||||||
|
|
||||||
# 模拟 ids2tokens 方法
|
# 模拟 ids2tokens 方法
|
||||||
def mock_ids2tokens(token_ids, task_id):
|
def mock_ids2tokens(token_ids, task_id):
|
||||||
@@ -72,7 +78,7 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
|||||||
def test_process_response_dict_streaming_normal_case(self):
|
def test_process_response_dict_streaming_normal_case(self):
|
||||||
"""测试正常情况下的流式响应处理"""
|
"""测试正常情况下的流式响应处理"""
|
||||||
# 准备输入
|
# 准备输入
|
||||||
response_dict = {"finished": False, "request_id": "req1", "outputs": {"token_ids": [4, 5]}}
|
response_dict = {"finished": False, "request_id": "test", "outputs": {"token_ids": [4, 5]}}
|
||||||
kwargs = {"enable_thinking": True}
|
kwargs = {"enable_thinking": True}
|
||||||
|
|
||||||
# 调用方法
|
# 调用方法
|
||||||
@@ -83,6 +89,7 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
|||||||
|
|
||||||
def test_process_request_dict(self):
|
def test_process_request_dict(self):
|
||||||
request_dict = {
|
request_dict = {
|
||||||
|
"request_id": "123",
|
||||||
"messages": [{"role": "user", "content": "Hello!"}],
|
"messages": [{"role": "user", "content": "Hello!"}],
|
||||||
"chat_template_kwargs": {"chat_template": "Hello!"},
|
"chat_template_kwargs": {"chat_template": "Hello!"},
|
||||||
"eos_token_ids": [1],
|
"eos_token_ids": [1],
|
||||||
@@ -118,6 +125,31 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
|||||||
self.assertEqual(result["outputs"]["text"], "Mock final text")
|
self.assertEqual(result["outputs"]["text"], "Mock final text")
|
||||||
self.assertIn("completion_tokens", result["outputs"])
|
self.assertIn("completion_tokens", result["outputs"])
|
||||||
|
|
||||||
|
def test_think_status(self):
|
||||||
|
"""测试 思考机制"""
|
||||||
|
request = {
|
||||||
|
"prompt": "hello",
|
||||||
|
"request_id": "test_1",
|
||||||
|
"prompt_token_ids": [1, 2, 3],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
}
|
||||||
|
self.processor.reasoning_parser = MagicMock()
|
||||||
|
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
|
||||||
|
self.processor.model_status_dict = {}
|
||||||
|
self.processor.process_request_dict(request, max_model_len=512)
|
||||||
|
self.assertEqual(request["enable_thinking"], True)
|
||||||
|
|
||||||
|
request = {
|
||||||
|
"prompt": "hello",
|
||||||
|
"request_id": "test",
|
||||||
|
"prompt_token_ids": [1, 2, 3],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
}
|
||||||
|
self.processor.process_request_dict(request, max_model_len=512)
|
||||||
|
self.assertEqual(request["enable_thinking"], True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -29,7 +29,12 @@ from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
|
|||||||
from fastdeploy.input.utils import IDS_TYPE_FLAG
|
from fastdeploy.input.utils import IDS_TYPE_FLAG
|
||||||
|
|
||||||
|
|
||||||
class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
class MockReasoningParser:
|
||||||
|
def get_model_status(self, prompt_token_ids):
|
||||||
|
return "think_start"
|
||||||
|
|
||||||
|
|
||||||
|
class TestErnie4_5VLProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Create mock object for Ernie4_5Processor instance
|
# Create mock object for Ernie4_5Processor instance
|
||||||
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None) as mock_init:
|
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None) as mock_init:
|
||||||
@@ -39,38 +44,41 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
|||||||
# Set necessary attributes
|
# Set necessary attributes
|
||||||
self.processor.tokenizer = MagicMock()
|
self.processor.tokenizer = MagicMock()
|
||||||
self.processor.tokenizer.eos_token_id = 1
|
self.processor.tokenizer.eos_token_id = 1
|
||||||
self.processor.decode_status = {}
|
self.processor.decode_status = {"test": []}
|
||||||
self.processor.reasoning_end_dict = {}
|
self.processor.reasoning_end_dict = {}
|
||||||
self.processor.tool_parser_dict = {}
|
self.processor.tool_parser_dict = {}
|
||||||
self.processor.generation_config = MagicMock()
|
self.processor.generation_config = MagicMock()
|
||||||
self.processor.eos_token_ids = [1]
|
self.processor.eos_token_ids = [1]
|
||||||
self.processor.reasoning_parser = MagicMock()
|
self.processor.reasoning_parser = MockReasoningParser()
|
||||||
self.processor._check_mm_limits = MagicMock()
|
self.processor.model_status_dict = {"test": "think_start"}
|
||||||
self.processor.ernie4_5_processor = MagicMock()
|
self.processor.ernie4_5_processor = MagicMock()
|
||||||
self.processor.pack_outputs = MagicMock()
|
|
||||||
|
|
||||||
# Mock ids2tokens method
|
# Mock ids2tokens method
|
||||||
def mock_ids2tokens(token_ids, task_id):
|
def mock_ids2tokens(token_ids, task_id):
|
||||||
self.processor.decode_status[task_id] = "mock_decode_status"
|
|
||||||
return "delta_text", [2, 3], "previous_texts"
|
return "delta_text", [2, 3], "previous_texts"
|
||||||
|
|
||||||
self.processor.ids2tokens = mock_ids2tokens
|
self.processor.ids2tokens = mock_ids2tokens
|
||||||
|
|
||||||
def mock_messages2ids(request, **kwargs):
|
def mock_request2ids(request, **kwargs):
|
||||||
if "chat_template" in kwargs:
|
return {"input_ids": np.array([1, 2, 3]), "prompt_token_ids": [0]}
|
||||||
return [1]
|
|
||||||
else:
|
def mock_check_mm_limits(item):
|
||||||
return [0]
|
pass
|
||||||
|
|
||||||
def mock_apply_default_parameters(request):
|
def mock_apply_default_parameters(request):
|
||||||
return request
|
return request
|
||||||
|
|
||||||
|
def mock_pack_outputs(outputs):
|
||||||
|
return outputs
|
||||||
|
|
||||||
self.processor._apply_default_parameters = mock_apply_default_parameters
|
self.processor._apply_default_parameters = mock_apply_default_parameters
|
||||||
|
self.processor._check_mm_limits = mock_check_mm_limits
|
||||||
|
self.processor.ernie4_5_processor.request2ids = mock_request2ids
|
||||||
|
self.processor.pack_outputs = mock_pack_outputs
|
||||||
|
|
||||||
# Mock reasoning parser
|
# Mock reasoning parser
|
||||||
self.mock_reasoning_parser = MagicMock()
|
self.mock_reasoning_parser = MagicMock()
|
||||||
self.mock_reasoning_parser.__class__.__name__ = "ErnieX1ReasoningParser"
|
self.mock_reasoning_parser.extract_reasoning_content_streaming.return_value = None
|
||||||
# self.mock_reasoning_parser.extract_reasoning_content_streaming.return_value = ("reasoning", "text")
|
|
||||||
self.processor.reasoning_parser = self.mock_reasoning_parser
|
self.processor.reasoning_parser = self.mock_reasoning_parser
|
||||||
|
|
||||||
# Mock tool parser
|
# Mock tool parser
|
||||||
@@ -80,82 +88,26 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
|
|||||||
self.mock_tool_parser_obj.return_value = self.mock_tool_parser
|
self.mock_tool_parser_obj.return_value = self.mock_tool_parser
|
||||||
self.processor.tool_parser_obj = self.mock_tool_parser_obj
|
self.processor.tool_parser_obj = self.mock_tool_parser_obj
|
||||||
|
|
||||||
def test_process_request_dict_with_options(self):
|
def test_think_status(self):
|
||||||
request_dict = {
|
"""测试 思考机制"""
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
request = {
|
||||||
"prompt_token_ids": [1, 1, 1],
|
"prompt": "hello",
|
||||||
|
"request_id": "test_1",
|
||||||
|
"prompt_token_ids": [1, 2, 3],
|
||||||
}
|
}
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
self.processor.reasoning_parser = MagicMock()
|
||||||
self.assertEqual(request_dict["enable_thinking"], True)
|
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
|
||||||
|
self.processor.model_status_dict = {}
|
||||||
|
self.processor.process_request_dict(request, max_model_len=512)
|
||||||
|
self.assertEqual(request["enable_thinking"], True)
|
||||||
|
|
||||||
request_dict = {
|
request = {
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
"prompt": "hello",
|
||||||
"chat_template_kwargs": {"enable_thinking": True},
|
"request_id": "test",
|
||||||
"prompt_token_ids": [1, 1, 1],
|
"prompt_token_ids": [1, 2, 3],
|
||||||
}
|
}
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
self.processor.process_request_dict(request, max_model_len=512)
|
||||||
self.assertEqual(request_dict["enable_thinking"], True)
|
self.assertEqual(request["enable_thinking"], True)
|
||||||
|
|
||||||
request_dict = {
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"chat_template_kwargs": {"enable_thinking": False},
|
|
||||||
"prompt_token_ids": [1, 1, 1],
|
|
||||||
}
|
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
|
||||||
self.assertEqual(request_dict["enable_thinking"], True)
|
|
||||||
|
|
||||||
request_dict = {
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"chat_template_kwargs": {"options": {"thinking_mode": "open"}},
|
|
||||||
"prompt_token_ids": [1, 1, 1],
|
|
||||||
}
|
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
|
||||||
self.assertEqual(request_dict["enable_thinking"], True)
|
|
||||||
|
|
||||||
request_dict = {
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"chat_template_kwargs": {"options": {"thinking_mode": "close"}},
|
|
||||||
"prompt_token_ids": [1, 1, 1],
|
|
||||||
}
|
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
|
||||||
self.assertEqual(request_dict["enable_thinking"], True)
|
|
||||||
|
|
||||||
request_dict = {
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"chat_template_kwargs": {"options": {"thinking_mode": "false"}},
|
|
||||||
"prompt_token_ids": [1, 1, 1],
|
|
||||||
}
|
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
|
||||||
self.assertEqual(request_dict["enable_thinking"], True)
|
|
||||||
|
|
||||||
request_dict = {
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"chat_template_kwargs": {"options": {"thinking_mode": "123"}},
|
|
||||||
"prompt_token_ids": [1, 1, 1],
|
|
||||||
}
|
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
|
||||||
self.assertEqual(request_dict["enable_thinking"], True)
|
|
||||||
|
|
||||||
request_dict = {
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"chat_template_kwargs": {"options": {"thinking_mode": "close"}},
|
|
||||||
}
|
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
|
||||||
self.assertEqual(request_dict["enable_thinking"], False)
|
|
||||||
|
|
||||||
request_dict = {
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"chat_template_kwargs": {"options": {"thinking_mode": "false"}},
|
|
||||||
}
|
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
|
||||||
self.assertEqual(request_dict["enable_thinking"], False)
|
|
||||||
|
|
||||||
request_dict = {
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"chat_template_kwargs": {"enable_thinking": False},
|
|
||||||
}
|
|
||||||
self.processor.process_request_dict(request_dict, 100)
|
|
||||||
self.assertEqual(request_dict["enable_thinking"], False)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDataProcessorTargetMethods(unittest.TestCase):
|
class TestDataProcessorTargetMethods(unittest.TestCase):
|
||||||
|
|||||||
@@ -793,6 +793,8 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
|
|||||||
self.processor.processor = MagicMock()
|
self.processor.processor = MagicMock()
|
||||||
self.processor.limit_mm_per_prompt = {"image": 1, "video": 1, "audio": 1}
|
self.processor.limit_mm_per_prompt = {"image": 1, "video": 1, "audio": 1}
|
||||||
self.processor.eos_token_ids = [1]
|
self.processor.eos_token_ids = [1]
|
||||||
|
self.processor.reasoning_parser = None
|
||||||
|
self.processor.model_status_dict = {}
|
||||||
|
|
||||||
# 模拟 _apply_default_parameters
|
# 模拟 _apply_default_parameters
|
||||||
def mock_apply_default_parameters(request_or_dict):
|
def mock_apply_default_parameters(request_or_dict):
|
||||||
@@ -971,6 +973,7 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
|
|||||||
"prompt": "test prompt",
|
"prompt": "test prompt",
|
||||||
"multimodal_data": {"image": ["image1"]},
|
"multimodal_data": {"image": ["image1"]},
|
||||||
"metadata": {"generated_token_ids": []},
|
"metadata": {"generated_token_ids": []},
|
||||||
|
"request_id": "test-request",
|
||||||
}
|
}
|
||||||
request_obj.to_dict.return_value = request_dict
|
request_obj.to_dict.return_value = request_dict
|
||||||
|
|
||||||
@@ -1157,6 +1160,27 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
|
|||||||
self.assertTrue(np.array_equal(result["image_type_ids"], np.array([0])))
|
self.assertTrue(np.array_equal(result["image_type_ids"], np.array([0])))
|
||||||
self.assertTrue(np.array_equal(result["position_ids"], np.array([[0], [1], [2]], dtype=np.int64)))
|
self.assertTrue(np.array_equal(result["position_ids"], np.array([[0], [1], [2]], dtype=np.int64)))
|
||||||
|
|
||||||
|
def test_think_status(self):
|
||||||
|
"""测试 思考机制"""
|
||||||
|
request = {
|
||||||
|
"prompt": "hello",
|
||||||
|
"request_id": "test_1",
|
||||||
|
"prompt_token_ids": [1, 2, 3],
|
||||||
|
}
|
||||||
|
self.processor.reasoning_parser = MagicMock()
|
||||||
|
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
|
||||||
|
self.processor.model_status_dict = {}
|
||||||
|
self.processor.process_request_dict(request, max_model_len=512)
|
||||||
|
self.assertEqual(request["enable_thinking"], True)
|
||||||
|
|
||||||
|
request = {
|
||||||
|
"prompt": "hello",
|
||||||
|
"request_id": "test",
|
||||||
|
"prompt_token_ids": [1, 2, 3],
|
||||||
|
}
|
||||||
|
self.processor.process_request_dict(request, max_model_len=512)
|
||||||
|
self.assertEqual(request["enable_thinking"], True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -365,6 +365,31 @@ class TestQwenVLProcessor(unittest.TestCase):
|
|||||||
# Verify both methods produce identical prompt strings
|
# Verify both methods produce identical prompt strings
|
||||||
self.assertEqual(prompt, prompt2)
|
self.assertEqual(prompt, prompt2)
|
||||||
|
|
||||||
|
def test_think_status(self):
|
||||||
|
"""测试 思考机制"""
|
||||||
|
request = {
|
||||||
|
"prompt": "hello",
|
||||||
|
"request_id": "test_1",
|
||||||
|
"prompt_token_ids": [1, 2, 3],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
}
|
||||||
|
self.processor.reasoning_parser = MagicMock()
|
||||||
|
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
|
||||||
|
self.processor.model_status_dict = {}
|
||||||
|
self.processor.process_request_dict(request, max_model_len=512)
|
||||||
|
self.assertEqual(request["enable_thinking"], True)
|
||||||
|
|
||||||
|
request = {
|
||||||
|
"prompt": "hello",
|
||||||
|
"request_id": "test",
|
||||||
|
"prompt_token_ids": [1, 2, 3],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
}
|
||||||
|
self.processor.process_request_dict(request, max_model_len=512)
|
||||||
|
self.assertEqual(request["enable_thinking"], True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -266,7 +266,7 @@ class DataProcessorTestCase(unittest.TestCase):
|
|||||||
def __init__(self, tokenizer):
|
def __init__(self, tokenizer):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def extract_reasoning_content(self, full_text, response_dict):
|
def extract_reasoning_content(self, full_text, response_dict, model_status):
|
||||||
return reasoning_content, f"{full_text}!"
|
return reasoning_content, f"{full_text}!"
|
||||||
|
|
||||||
return DummyReasoning(tokenizer)
|
return DummyReasoning(tokenizer)
|
||||||
@@ -409,6 +409,7 @@ class DataProcessorTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_process_response_with_reasoning_and_tools(self):
|
def test_process_response_with_reasoning_and_tools(self):
|
||||||
processor = self.processor
|
processor = self.processor
|
||||||
|
processor.model_status_dict = {"resp": "normal"}
|
||||||
|
|
||||||
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer)
|
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer)
|
||||||
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-only")
|
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-only")
|
||||||
@@ -435,7 +436,7 @@ class DataProcessorTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_process_response_dict_normal_with_reasoning(self):
|
def test_process_response_dict_normal_with_reasoning(self):
|
||||||
processor = self.processor
|
processor = self.processor
|
||||||
|
processor.model_status_dict = {"normal": "normal"}
|
||||||
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer, reasoning_content="because")
|
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer, reasoning_content="because")
|
||||||
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-text")
|
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-text")
|
||||||
|
|
||||||
@@ -471,10 +472,10 @@ class DataProcessorTestCase(unittest.TestCase):
|
|||||||
self.addCleanup(lambda: setattr(processor, "process_response_dict_normal", original_normal))
|
self.addCleanup(lambda: setattr(processor, "process_response_dict_normal", original_normal))
|
||||||
|
|
||||||
response = {"outputs": {}, "finished": False, "request_id": "req"}
|
response = {"outputs": {}, "finished": False, "request_id": "req"}
|
||||||
self.assertEqual(processor.process_response_dict(response), "stream")
|
self.assertEqual(processor.process_response_dict(response, stream=True, enable_thinking=True), "stream")
|
||||||
self.assertTrue(calls["stream"]["enable_thinking"])
|
self.assertTrue(calls["stream"]["enable_thinking"])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
processor.process_response_dict(response, stream=False, enable_thinking=None),
|
processor.process_response_dict(response, stream=False, enable_thinking=True),
|
||||||
"normal",
|
"normal",
|
||||||
)
|
)
|
||||||
self.assertTrue(calls["normal"]["enable_thinking"])
|
self.assertTrue(calls["normal"]["enable_thinking"])
|
||||||
|
|||||||
@@ -0,0 +1,197 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
|
from fastdeploy.reasoning.qwen3_reasoning_parsers import Qwen3ReasoningParser
|
||||||
|
|
||||||
|
|
||||||
|
class MockTokenizer:
|
||||||
|
"""Minimal tokenizer with vocab for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.vocab = {
|
||||||
|
"<think>": 100,
|
||||||
|
"</think>": 101,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_vocab(self):
|
||||||
|
"""Return vocab dict for testing."""
|
||||||
|
return self.vocab
|
||||||
|
|
||||||
|
|
||||||
|
class MissingTokenTokenizer:
|
||||||
|
def __init__(self):
|
||||||
|
self.vocab = {
|
||||||
|
"</think>": 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_vocab(self):
|
||||||
|
"""Return vocab dict for testing."""
|
||||||
|
return self.vocab
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwen3ReasoningParser(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.parser = Qwen3ReasoningParser(MockTokenizer())
|
||||||
|
self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
|
||||||
|
self.tokenizer = MockTokenizer()
|
||||||
|
|
||||||
|
def test_missing_token(self):
|
||||||
|
with self.assertRaises(RuntimeError) as context:
|
||||||
|
Qwen3ReasoningParser(MissingTokenTokenizer())
|
||||||
|
exception_message = str(context.exception)
|
||||||
|
expected_message_part = "Qwen3 reasoning parser could not find the following token ids"
|
||||||
|
self.assertIn(expected_message_part, exception_message)
|
||||||
|
|
||||||
|
def test_get_model_status(self):
|
||||||
|
status = self.parser.get_model_status([1, 2, 100])
|
||||||
|
self.assertEqual(status, "think_start")
|
||||||
|
status = self.parser.get_model_status([1, 2, 101])
|
||||||
|
self.assertEqual(status, "think_end")
|
||||||
|
status = self.parser.get_model_status([1])
|
||||||
|
self.assertEqual(status, "think_start")
|
||||||
|
|
||||||
|
def test_streaming_thinking_content(self):
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text="a",
|
||||||
|
delta_text="a",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[200],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.reasoning_content, "a")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text="a</think>b",
|
||||||
|
delta_text="a</think>b",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[99, 101, 102],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.reasoning_content, "a")
|
||||||
|
self.assertEqual(msg.content, "b")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="a</think>",
|
||||||
|
current_text="a</think>b",
|
||||||
|
delta_text="b",
|
||||||
|
previous_token_ids=[1, 101],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[102],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "b")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text="a",
|
||||||
|
delta_text="a",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.reasoning_content, "a")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text="a",
|
||||||
|
delta_text="a",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[200],
|
||||||
|
model_status="think_end",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "a")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello",
|
||||||
|
current_text="hello</think>hi",
|
||||||
|
delta_text="</think>hi",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[101, 200],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "hi")
|
||||||
|
self.assertEqual(msg.reasoning_content, "")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello",
|
||||||
|
current_text="hello</think>hi",
|
||||||
|
delta_text="hi",
|
||||||
|
previous_token_ids=[100],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, None)
|
||||||
|
self.assertEqual(msg.reasoning_content, "hi")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello",
|
||||||
|
current_text="hello</think>hi",
|
||||||
|
delta_text="<think>hi",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[100, 200],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "")
|
||||||
|
self.assertEqual(msg.reasoning_content, "hi")
|
||||||
|
|
||||||
|
def test_none_streaming_thinking_content(self):
|
||||||
|
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="a",
|
||||||
|
request={},
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, None)
|
||||||
|
self.assertEqual(content, "a")
|
||||||
|
|
||||||
|
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="a</think>b",
|
||||||
|
request={},
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, "a")
|
||||||
|
self.assertEqual(content, "b")
|
||||||
|
|
||||||
|
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="a",
|
||||||
|
request={},
|
||||||
|
model_status="think_end",
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, None)
|
||||||
|
self.assertEqual(content, "a")
|
||||||
|
|
||||||
|
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="<think>a",
|
||||||
|
request={},
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, None)
|
||||||
|
self.assertEqual(content, "<think>a")
|
||||||
|
|
||||||
|
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="<think>a</think>b",
|
||||||
|
request={},
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, "a")
|
||||||
|
self.assertEqual(content, "b")
|
||||||
|
|
||||||
|
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="</think>b",
|
||||||
|
request={},
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, "")
|
||||||
|
self.assertEqual(content, "b")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -31,10 +31,25 @@ class DummyTokenizer:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.vocab = {
|
self.vocab = {
|
||||||
"</think>": 100,
|
"</think>": 100,
|
||||||
"<tool_call>": 101,
|
"<think>": 101,
|
||||||
"</tool_call>": 102,
|
"<tool_call>": 102,
|
||||||
"<response>": 103,
|
"</tool_call>": 103,
|
||||||
"</response>": 104,
|
"<response>": 104,
|
||||||
|
"</response>": 105,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_vocab(self):
|
||||||
|
"""Return vocab dict for testing."""
|
||||||
|
return self.vocab
|
||||||
|
|
||||||
|
|
||||||
|
class MissingTokenTokenizer:
|
||||||
|
def __init__(self):
|
||||||
|
self.vocab = {
|
||||||
|
"</think>": 100,
|
||||||
|
"<think>": 101,
|
||||||
|
"<tool_call>": 102,
|
||||||
|
"</tool_call>": 103,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
@@ -132,6 +147,17 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
|
self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
|
||||||
self.tokenizer = DummyTokenizer()
|
self.tokenizer = DummyTokenizer()
|
||||||
|
|
||||||
|
def test_missing_token(self):
|
||||||
|
with self.assertRaises(RuntimeError) as context:
|
||||||
|
ErnieX1ReasoningParser(MissingTokenTokenizer())
|
||||||
|
exception_message = str(context.exception)
|
||||||
|
expected_message_part = "ernie x1 reasoning parser could not find the following token ids"
|
||||||
|
self.assertIn(expected_message_part, exception_message)
|
||||||
|
|
||||||
|
def test_get_model_status(self):
|
||||||
|
model_status = self.parser.get_model_status([88, 99, 104])
|
||||||
|
self.assertEqual(model_status, "response_start")
|
||||||
|
|
||||||
# ---- Streaming parsing ----
|
# ---- Streaming parsing ----
|
||||||
def test_streaming_thinking_content(self):
|
def test_streaming_thinking_content(self):
|
||||||
msg = self.parser.extract_reasoning_content_streaming(
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
@@ -141,6 +167,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[],
|
current_token_ids=[],
|
||||||
delta_token_ids=[200],
|
delta_token_ids=[200],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertEqual(msg.reasoning_content, "a")
|
self.assertEqual(msg.reasoning_content, "a")
|
||||||
|
|
||||||
@@ -152,6 +179,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[],
|
current_token_ids=[],
|
||||||
delta_token_ids=[201],
|
delta_token_ids=[201],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertEqual(msg.reasoning_content, "\n")
|
self.assertEqual(msg.reasoning_content, "\n")
|
||||||
|
|
||||||
@@ -163,6 +191,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[],
|
current_token_ids=[],
|
||||||
delta_token_ids=[self.parser.think_end_token_id],
|
delta_token_ids=[self.parser.think_end_token_id],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsNone(msg)
|
self.assertIsNone(msg)
|
||||||
|
|
||||||
@@ -174,6 +203,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[],
|
current_token_ids=[],
|
||||||
delta_token_ids=[202],
|
delta_token_ids=[202],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertEqual(msg.content, "h")
|
self.assertEqual(msg.content, "h")
|
||||||
|
|
||||||
@@ -185,6 +215,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[],
|
current_token_ids=[],
|
||||||
delta_token_ids=[203],
|
delta_token_ids=[203],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertEqual(msg.content, "\n")
|
self.assertEqual(msg.content, "\n")
|
||||||
|
|
||||||
@@ -197,6 +228,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[],
|
current_token_ids=[],
|
||||||
delta_token_ids=[self.parser.vocab["<response>"]],
|
delta_token_ids=[self.parser.vocab["<response>"]],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -207,6 +239,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[],
|
current_token_ids=[],
|
||||||
delta_token_ids=[204],
|
delta_token_ids=[204],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(msg, DeltaMessage)
|
self.assertIsInstance(msg, DeltaMessage)
|
||||||
self.assertEqual(msg.content, "\n")
|
self.assertEqual(msg.content, "\n")
|
||||||
@@ -219,9 +252,82 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[],
|
current_token_ids=[],
|
||||||
delta_token_ids=[self.parser.vocab["</response>"]],
|
delta_token_ids=[self.parser.vocab["</response>"]],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_extract_reasoning_content_streaming(self):
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello</think>",
|
||||||
|
current_text="hello</think><response>",
|
||||||
|
delta_text="</think><response>",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[100, 200],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "")
|
||||||
|
self.assertEqual(msg.reasoning_content, "")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello</think>",
|
||||||
|
current_text="hello</think><response>hi",
|
||||||
|
delta_text="</think><response>hi",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[100, 200],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "hi")
|
||||||
|
self.assertEqual(msg.reasoning_content, "")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text="hello</think><response>hi",
|
||||||
|
delta_text="hello</think><response>hi",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[100, 200],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "hi")
|
||||||
|
self.assertEqual(msg.reasoning_content, "hello")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello</think><response>",
|
||||||
|
current_text="hello</think><response>hi",
|
||||||
|
delta_text="hi",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[100, 200],
|
||||||
|
model_status="think_end",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "hi")
|
||||||
|
self.assertEqual(msg.reasoning_content, None)
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello</think><response>",
|
||||||
|
current_text="hello</think><response>hi",
|
||||||
|
delta_text="hi",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[100, 200],
|
||||||
|
model_status="response_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "hi")
|
||||||
|
self.assertEqual(msg.reasoning_content, None)
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello</think><response>hi</response>",
|
||||||
|
current_text="hello</think><response>hi</response>end",
|
||||||
|
delta_text="end",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[100, 200],
|
||||||
|
model_status="response_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg, None)
|
||||||
|
|
||||||
def test_streaming_tool_call(self):
|
def test_streaming_tool_call(self):
|
||||||
msg = self.parser.extract_reasoning_content_streaming(
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
previous_text="</think>",
|
previous_text="</think>",
|
||||||
@@ -230,40 +336,48 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[],
|
current_token_ids=[],
|
||||||
delta_token_ids=[self.parser.vocab["<tool_call>"]],
|
delta_token_ids=[self.parser.vocab["<tool_call>"]],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsNone(msg)
|
self.assertIsNone(msg)
|
||||||
|
|
||||||
# ---- Batch parsing ----
|
# ---- Batch parsing ----
|
||||||
def test_batch_reasoning_and_response(self):
|
def test_batch_reasoning_and_response(self):
|
||||||
text = "abc\n</think>\n<response>hello\nworld</response>"
|
text = "abc\n</think>\n<response>hello\nworld</response>"
|
||||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||||
self.assertEqual(reasoning, "abc\n")
|
self.assertEqual(reasoning, "abc\n")
|
||||||
self.assertEqual(response, "hello\nworld")
|
self.assertEqual(response, "hello\nworld")
|
||||||
|
|
||||||
def test_batch_reasoning_and_tool_call(self):
|
def test_batch_reasoning_and_tool_call(self):
|
||||||
text = "abc</think><tool_call>call_here"
|
text = "abc</think><tool_call>call_here"
|
||||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||||
self.assertEqual(reasoning, "abc")
|
self.assertEqual(reasoning, "abc")
|
||||||
self.assertEqual(response, "")
|
self.assertEqual(response, "")
|
||||||
|
|
||||||
def test_batch_no_thinking_tag(self):
|
def test_batch_no_thinking_tag(self):
|
||||||
text = "no_thinking_here"
|
text = "no_thinking_here"
|
||||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||||
self.assertEqual(reasoning, "no_thinking_here")
|
self.assertEqual(reasoning, "no_thinking_here")
|
||||||
self.assertEqual(response, "")
|
self.assertEqual(response, "")
|
||||||
|
|
||||||
def test_batch_response_without_end_tag(self):
|
def test_batch_response_without_end_tag(self):
|
||||||
text = "abc</think><response>partial response"
|
text = "abc</think><response>partial response"
|
||||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||||
self.assertEqual(reasoning, "abc")
|
self.assertEqual(reasoning, "abc")
|
||||||
self.assertEqual(response, "partial response")
|
self.assertEqual(response, "partial response")
|
||||||
|
|
||||||
def test_batch_preserve_all_newlines(self):
|
def test_batch_preserve_all_newlines(self):
|
||||||
text = "abc\n</think>\n<response>line1\nline2\n</response>"
|
text = "abc\n</think>\n<response>line1\nline2\n</response>"
|
||||||
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
|
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
|
||||||
self.assertEqual(reasoning, "abc\n")
|
self.assertEqual(reasoning, "abc\n")
|
||||||
self.assertEqual(response, "line1\nline2\n")
|
self.assertEqual(response, "line1\nline2\n")
|
||||||
|
|
||||||
|
def test_extract_reasoning_content(self):
|
||||||
|
reasoning_content, response_content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="hello", request=self.request, model_status="response_start"
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, "")
|
||||||
|
self.assertEqual(response_content, "hello")
|
||||||
|
|
||||||
|
|
||||||
class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -272,6 +386,9 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
self.test_request = ChatCompletionRequest(
|
self.test_request = ChatCompletionRequest(
|
||||||
model="ernie-test", messages=[{"role": "user", "content": "test prompt"}]
|
model="ernie-test", messages=[{"role": "user", "content": "test prompt"}]
|
||||||
)
|
)
|
||||||
|
self.parser.token_status_mapping = {
|
||||||
|
100: "think_start",
|
||||||
|
}
|
||||||
|
|
||||||
def test_streaming_non_reasoning(self):
|
def test_streaming_non_reasoning(self):
|
||||||
result = self.parser.extract_reasoning_content_streaming(
|
result = self.parser.extract_reasoning_content_streaming(
|
||||||
@@ -281,6 +398,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[],
|
previous_token_ids=[],
|
||||||
current_token_ids=[200],
|
current_token_ids=[200],
|
||||||
delta_token_ids=[200],
|
delta_token_ids=[200],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertEqual(result.reasoning_content, "a")
|
self.assertEqual(result.reasoning_content, "a")
|
||||||
@@ -294,6 +412,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[200, 201],
|
previous_token_ids=[200, 201],
|
||||||
current_token_ids=[200, 201, 100],
|
current_token_ids=[200, 201, 100],
|
||||||
delta_token_ids=[100],
|
delta_token_ids=[100],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(result)
|
||||||
|
|
||||||
@@ -305,6 +424,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[200, 201],
|
previous_token_ids=[200, 201],
|
||||||
current_token_ids=[200, 201, 100, 300, 400],
|
current_token_ids=[200, 201, 100, 300, 400],
|
||||||
delta_token_ids=[100, 300, 400],
|
delta_token_ids=[100, 300, 400],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertIsNone(result.reasoning_content)
|
self.assertIsNone(result.reasoning_content)
|
||||||
@@ -318,6 +438,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[200, 201, 202],
|
previous_token_ids=[200, 201, 202],
|
||||||
current_token_ids=[200, 201, 202, 100],
|
current_token_ids=[200, 201, 202, 100],
|
||||||
delta_token_ids=[100],
|
delta_token_ids=[100],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(result)
|
||||||
|
|
||||||
@@ -329,9 +450,10 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[200, 201, 202],
|
previous_token_ids=[200, 201, 202],
|
||||||
current_token_ids=[200, 201, 202, 100, 200, 101],
|
current_token_ids=[200, 201, 202, 100, 200, 101],
|
||||||
delta_token_ids=[100, 200, 101],
|
delta_token_ids=[100, 200, 101],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertEqual(result.reasoning_content, "")
|
self.assertEqual(result.reasoning_content, None)
|
||||||
|
|
||||||
def test_streaming_with_reasoning_and_illegal_tool(self):
|
def test_streaming_with_reasoning_and_illegal_tool(self):
|
||||||
result = self.parser.extract_reasoning_content_streaming(
|
result = self.parser.extract_reasoning_content_streaming(
|
||||||
@@ -341,6 +463,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[200, 201, 202],
|
previous_token_ids=[200, 201, 202],
|
||||||
current_token_ids=[200, 201, 202, 100, 200, 101],
|
current_token_ids=[200, 201, 202, 100, 200, 101],
|
||||||
delta_token_ids=[109, 200, 101],
|
delta_token_ids=[109, 200, 101],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertEqual(result.content, "\n\nhello<tool_call>")
|
self.assertEqual(result.content, "\n\nhello<tool_call>")
|
||||||
@@ -353,6 +476,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[200, 201, 202],
|
previous_token_ids=[200, 201, 202],
|
||||||
current_token_ids=[200, 201, 202, 100, 200, 110],
|
current_token_ids=[200, 201, 202, 100, 200, 110],
|
||||||
delta_token_ids=[100, 200, 110],
|
delta_token_ids=[100, 200, 110],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertEqual(result.reasoning_content, "hello")
|
self.assertEqual(result.reasoning_content, "hello")
|
||||||
@@ -366,6 +490,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[100],
|
previous_token_ids=[100],
|
||||||
current_token_ids=[100, 110, 111],
|
current_token_ids=[100, 110, 111],
|
||||||
delta_token_ids=[110, 111],
|
delta_token_ids=[110, 111],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertIsNone(result.reasoning_content)
|
self.assertIsNone(result.reasoning_content)
|
||||||
@@ -379,52 +504,140 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[101],
|
previous_token_ids=[101],
|
||||||
current_token_ids=[101, 110],
|
current_token_ids=[101, 110],
|
||||||
delta_token_ids=[110],
|
delta_token_ids=[110],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertEqual(result.reasoning_content, "hello")
|
self.assertEqual(result.reasoning_content, "hello")
|
||||||
|
|
||||||
|
def test_think_end_status_streaming(self):
|
||||||
|
result = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="<tool_call>",
|
||||||
|
current_text="<tool_call>hello",
|
||||||
|
delta_text="hello",
|
||||||
|
previous_token_ids=[101],
|
||||||
|
current_token_ids=[101, 110],
|
||||||
|
delta_token_ids=[110],
|
||||||
|
model_status="think_end",
|
||||||
|
)
|
||||||
|
self.assertIs(result, None)
|
||||||
|
|
||||||
|
result = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello, ",
|
||||||
|
current_text="hello, hi",
|
||||||
|
delta_text="hi",
|
||||||
|
previous_token_ids=[101],
|
||||||
|
current_token_ids=[101, 110],
|
||||||
|
delta_token_ids=[110],
|
||||||
|
model_status="think_end",
|
||||||
|
)
|
||||||
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
|
self.assertEqual(result.content, "hi")
|
||||||
|
|
||||||
|
def test_other_status_streaming(self):
|
||||||
|
result = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="hello, ",
|
||||||
|
current_text="hello, hi",
|
||||||
|
delta_text="hi",
|
||||||
|
previous_token_ids=[101],
|
||||||
|
current_token_ids=[101, 110],
|
||||||
|
delta_token_ids=[110],
|
||||||
|
model_status="tool_call_start",
|
||||||
|
)
|
||||||
|
self.assertIs(result, None)
|
||||||
|
|
||||||
def test_batch_no_think_end(self):
|
def test_batch_no_think_end(self):
|
||||||
reasoning, content = self.parser.extract_reasoning_content(
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
model_output="direct response", request=self.test_request
|
model_output="direct response", request=self.test_request, model_status="think_start"
|
||||||
)
|
)
|
||||||
self.assertEqual(reasoning, "direct response")
|
self.assertEqual(reasoning, "direct response")
|
||||||
self.assertEqual(content, "")
|
self.assertEqual(content, "")
|
||||||
|
|
||||||
def test_batch_no_think_end_with_tool(self):
|
def test_batch_no_think_end_with_tool(self):
|
||||||
reasoning, content = self.parser.extract_reasoning_content(
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
model_output="direct response<tool_call>abc", request=self.test_request
|
model_output="direct response<tool_call>abc", request=self.test_request, model_status="think_start"
|
||||||
)
|
)
|
||||||
self.assertEqual(reasoning, "direct response<tool_call>abc")
|
self.assertEqual(reasoning, "direct response<tool_call>abc")
|
||||||
self.assertEqual(content, "")
|
self.assertEqual(content, "")
|
||||||
|
|
||||||
def test_batch_think_end_normal_content(self):
|
def test_batch_think_end_normal_content(self):
|
||||||
reasoning, content = self.parser.extract_reasoning_content(
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
model_output="reasoning</think>\nresponse", request=self.test_request
|
model_output="reasoning</think>\nresponse", request=self.test_request, model_status="think_start"
|
||||||
)
|
)
|
||||||
self.assertEqual(reasoning, "reasoning")
|
self.assertEqual(reasoning, "reasoning")
|
||||||
self.assertEqual(content, "\nresponse")
|
self.assertEqual(content, "\nresponse")
|
||||||
|
|
||||||
def test_batch_think_end_with_tool(self):
|
def test_batch_think_end_with_tool(self):
|
||||||
reasoning, content = self.parser.extract_reasoning_content(
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
model_output="reasoning</think>\n<tool_call>tool params</tool_call>", request=self.test_request
|
model_output="reasoning</think>\n<tool_call>tool params</tool_call>",
|
||||||
|
request=self.test_request,
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertEqual(reasoning, "reasoning")
|
self.assertEqual(reasoning, "reasoning")
|
||||||
self.assertEqual(content, "")
|
self.assertEqual(content, "")
|
||||||
|
|
||||||
def test_batch_think_end_with_illegal_tool(self):
|
def test_batch_think_end_with_illegal_tool(self):
|
||||||
reasoning, content = self.parser.extract_reasoning_content(
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
model_output="reasoning</think>\nABC\n<tool_call>tool params</tool_call>", request=self.test_request
|
model_output="reasoning</think>\nABC\n<tool_call>tool params</tool_call>",
|
||||||
|
request=self.test_request,
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertEqual(reasoning, "reasoning")
|
self.assertEqual(reasoning, "reasoning")
|
||||||
self.assertEqual(content, "\nABC\n<tool_call>tool params</tool_call>")
|
self.assertEqual(content, "\nABC\n<tool_call>tool params</tool_call>")
|
||||||
|
|
||||||
def test_batch_think_end_content_with_newline(self):
|
def test_batch_think_end_content_with_newline(self):
|
||||||
reasoning, content = self.parser.extract_reasoning_content(
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
model_output="reasoning</think>\n\n actual response", request=self.test_request
|
model_output="reasoning</think>\n\n actual response",
|
||||||
|
request=self.test_request,
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertEqual(reasoning, "reasoning")
|
self.assertEqual(reasoning, "reasoning")
|
||||||
self.assertEqual(content, "\n\n actual response")
|
self.assertEqual(content, "\n\n actual response")
|
||||||
|
|
||||||
|
def test_think_end_status_non_streaming(self):
|
||||||
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="response", request=self.test_request, model_status="think_end"
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning, "")
|
||||||
|
self.assertEqual(content, "response")
|
||||||
|
|
||||||
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="<tool_call>response", request=self.test_request, model_status="think_end"
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning, "")
|
||||||
|
self.assertEqual(content, "")
|
||||||
|
|
||||||
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="\n 1<tool_call>response", request=self.test_request, model_status="think_end"
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning, "")
|
||||||
|
self.assertEqual(content, "\n 1<tool_call>response")
|
||||||
|
|
||||||
|
def test_other_status_non_streaming(self):
|
||||||
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="response", request=self.test_request, model_status="tool_call_start"
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning, "")
|
||||||
|
self.assertEqual(content, "")
|
||||||
|
|
||||||
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="response", request=self.test_request, model_status="tool_call_end"
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning, "")
|
||||||
|
self.assertEqual(content, "")
|
||||||
|
|
||||||
|
def test_find_last_special_token(self):
|
||||||
|
result = self.parser.find_last_special_token([100, 110, 120, 130])
|
||||||
|
self.assertEqual(result, 100)
|
||||||
|
result = self.parser.find_last_special_token([0])
|
||||||
|
self.assertEqual(result, -1)
|
||||||
|
|
||||||
|
def test_get_model_status(self):
|
||||||
|
result = self.parser.get_model_status([100, 110, 120, 130])
|
||||||
|
self.assertEqual(result, "think_start")
|
||||||
|
|
||||||
|
result = self.parser.get_model_status([0])
|
||||||
|
self.assertEqual(result, "think_start")
|
||||||
|
|
||||||
|
|
||||||
class TestErnieVLReasoningParser(unittest.TestCase):
|
class TestErnieVLReasoningParser(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -442,6 +655,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[200, 201, 202],
|
previous_token_ids=[200, 201, 202],
|
||||||
current_token_ids=[200, 201, 202, 100, 110, 120, 130],
|
current_token_ids=[200, 201, 202, 100, 110, 120, 130],
|
||||||
delta_token_ids=[100, 110, 120, 130],
|
delta_token_ids=[100, 110, 120, 130],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertEqual(result.reasoning_content, "")
|
self.assertEqual(result.reasoning_content, "")
|
||||||
@@ -455,6 +669,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[200, 201, 202, 100],
|
previous_token_ids=[200, 201, 202, 100],
|
||||||
current_token_ids=[200, 201, 202, 100, 110, 120, 130],
|
current_token_ids=[200, 201, 202, 100, 110, 120, 130],
|
||||||
delta_token_ids=[110, 120, 130],
|
delta_token_ids=[110, 120, 130],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertIsNone(result.reasoning_content)
|
self.assertIsNone(result.reasoning_content)
|
||||||
@@ -468,6 +683,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
|
|||||||
previous_token_ids=[200, 201, 202],
|
previous_token_ids=[200, 201, 202],
|
||||||
current_token_ids=[200, 201, 202, 110, 120, 130],
|
current_token_ids=[200, 201, 202, 110, 120, 130],
|
||||||
delta_token_ids=[110, 120, 130],
|
delta_token_ids=[110, 120, 130],
|
||||||
|
model_status="think_start",
|
||||||
)
|
)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsInstance(result, DeltaMessage)
|
||||||
self.assertIsNone(result.content)
|
self.assertIsNone(result.content)
|
||||||
@@ -475,7 +691,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
|
|||||||
|
|
||||||
def test_extract_reasoning_content(self):
|
def test_extract_reasoning_content(self):
|
||||||
reasoning, content = self.parser.extract_reasoning_content(
|
reasoning, content = self.parser.extract_reasoning_content(
|
||||||
model_output="reasoning</think>\nactual response", request=self.test_request
|
model_output="reasoning</think>\nactual response", request=self.test_request, model_status="think_start"
|
||||||
)
|
)
|
||||||
self.assertEqual(reasoning, "reasoning")
|
self.assertEqual(reasoning, "reasoning")
|
||||||
self.assertEqual(content, "\nactual response")
|
self.assertEqual(content, "\nactual response")
|
||||||
|
|||||||
@@ -0,0 +1,135 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
|
from fastdeploy.reasoning.ernie_vl_reasoning_parsers import ErnieVLReasoningParser
|
||||||
|
|
||||||
|
|
||||||
|
class MockTokenizer:
|
||||||
|
"""Minimal tokenizer with vocab for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.vocab = {
|
||||||
|
"<think>": 100,
|
||||||
|
"</think>": 101,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_vocab(self):
|
||||||
|
"""Return vocab dict for testing."""
|
||||||
|
return self.vocab
|
||||||
|
|
||||||
|
|
||||||
|
class TestErnieVLReasoningParser(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.parser = ErnieVLReasoningParser(MockTokenizer())
|
||||||
|
self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
|
||||||
|
self.tokenizer = MockTokenizer()
|
||||||
|
|
||||||
|
def test_get_model_status(self):
|
||||||
|
status = self.parser.get_model_status([1, 2, 100])
|
||||||
|
self.assertEqual(status, "think_start")
|
||||||
|
status = self.parser.get_model_status([1, 2, 101])
|
||||||
|
self.assertEqual(status, "think_end")
|
||||||
|
status = self.parser.get_model_status([1])
|
||||||
|
self.assertEqual(status, "think_start")
|
||||||
|
|
||||||
|
def test_streaming_thinking_content(self):
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text="a",
|
||||||
|
delta_text="a",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[200],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.reasoning_content, "a")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text="a</think>b",
|
||||||
|
delta_text="a</think>b",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[100, 101, 102],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.reasoning_content, "a")
|
||||||
|
self.assertEqual(msg.content, "b")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="a</think>",
|
||||||
|
current_text="a</think>b",
|
||||||
|
delta_text="b",
|
||||||
|
previous_token_ids=[1, 101],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[102],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "b")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text="a",
|
||||||
|
delta_text="a",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[],
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.reasoning_content, "a")
|
||||||
|
|
||||||
|
msg = self.parser.extract_reasoning_content_streaming(
|
||||||
|
previous_text="",
|
||||||
|
current_text="a",
|
||||||
|
delta_text="a",
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=[200],
|
||||||
|
model_status="think_end",
|
||||||
|
)
|
||||||
|
self.assertEqual(msg.content, "a")
|
||||||
|
|
||||||
|
def test_none_streaming_thinking_content(self):
|
||||||
|
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="a",
|
||||||
|
request={},
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, "")
|
||||||
|
self.assertEqual(content, "a")
|
||||||
|
|
||||||
|
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="a</think>b",
|
||||||
|
request={},
|
||||||
|
model_status="think_start",
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, "a")
|
||||||
|
self.assertEqual(content, "b")
|
||||||
|
|
||||||
|
reasoning_content, content = self.parser.extract_reasoning_content(
|
||||||
|
model_output="a",
|
||||||
|
request={},
|
||||||
|
model_status="think_end",
|
||||||
|
)
|
||||||
|
self.assertEqual(reasoning_content, "")
|
||||||
|
self.assertEqual(content, "a")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user