[Feature]Optimization of Thinking Pattern Framework (#4302)

* add model status in vl

* add x1 parser

* add model_status

* fix parser

* fix parser

* fix parser

* fix parser

* Revert "fix parser"

This reverts commit 300f446d8a.

* fix parser

* fix

* fix

* fix

* fix

* fix parser

* fix unit test

* fix unit test

* add unit test

* fix

* fix

* add unit test

* fix unit test

* add unit test

* add unit test

* fix unit test

* fix unit test

* fix bug

* fix unit test

* x1 tool parser

* fix unit test

* fix unit test

* fix unit test

* fix n

* fix unit test

* add unit test

* add unit test

* remove print
This commit is contained in:
luukunn
2025-12-10 16:17:06 +08:00
committed by GitHub
parent 1bffac866b
commit fbc9bce1e9
28 changed files with 1199 additions and 458 deletions
@@ -72,13 +72,12 @@ class ChatResponseProcessor:
else:
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output):
async def process_response_chat(self, request_outputs, stream, include_stop_str_in_output):
"""
Process a list of responses into a generator that yields each processed response as it's generated.
Args:
request_outputs: The list of outputs to be processed.
stream: Whether or not to stream the output.
enable_thinking: Whether or not to show thinking messages.
include_stop_str_in_output: Whether or not to include stop strings in the output.
"""
for request_output in request_outputs:
@@ -99,7 +98,6 @@ class ChatResponseProcessor:
response = await self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
audio_tokens=all_audio_tokens,
tts=tts,
@@ -108,7 +106,6 @@ class ChatResponseProcessor:
response = self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
audio_tokens=all_audio_tokens,
tts=tts,
@@ -127,7 +124,6 @@ class ChatResponseProcessor:
yield self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
elif stream:
@@ -156,14 +152,12 @@ class ChatResponseProcessor:
await self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
else:
self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
text = {"type": "text", "text": request_output["outputs"]["text"]}
@@ -185,14 +179,12 @@ class ChatResponseProcessor:
await self.data_processor.process_response_dict(
response_dict=part["request_output"],
stream=False,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
else:
self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
@@ -220,8 +220,6 @@ class OpenAIServingChat:
max_streaming_response_tokens = max(1, max_streaming_response_tokens)
enable_thinking = self._get_thinking_status(request)
include_stop_str_in_output = request.include_stop_str_in_output
stream_options = request.stream_options
@@ -277,7 +275,6 @@ class OpenAIServingChat:
generator = response_processor.process_response_chat(
response,
stream=True,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
@@ -507,7 +504,6 @@ class OpenAIServingChat:
"""
created_time = int(time.time())
num_choices = 1 if request.n is None else request.n
enable_thinking = self._get_thinking_status(request)
include_stop_str_in_output = request.include_stop_str_in_output
try:
@@ -562,7 +558,6 @@ class OpenAIServingChat:
generator = response_processor.process_response_chat(
response,
stream=False,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
async for data in generator:
@@ -824,23 +819,6 @@ class OpenAIServingChat:
api_server_logger.error(error_msg)
return None
def _get_thinking_status(self, request: ChatCompletionRequest) -> bool:
    """
    Resolve whether "thinking" output is enabled for this request.

    Resolution order:
      1. ``request.chat_template_kwargs["enable_thinking"]`` if set;
      2. otherwise ``request.metadata["enable_thinking"]`` if set;
      3. ``request.chat_template_kwargs["options"]["thinking_mode"]``,
         when present and truthy, overrides both: the strings "close"
         and "false" disable thinking, any other value enables it.

    Returns:
        The resolved flag; note this is ``None`` (not ``False``) when no
        source sets it, despite the ``bool`` annotation.
    """
    enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
    if enable_thinking is None:
        # Fall back to the legacy location on request.metadata.
        enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
    options = request.chat_template_kwargs.get("options") if request.chat_template_kwargs else None
    if options:
        thinking_mode = options.get("thinking_mode")
        if thinking_mode:
            # thinking_mode, when provided, wins over enable_thinking.
            if thinking_mode == "close" or thinking_mode == "false":
                enable_thinking = False
            else:
                enable_thinking = True
    return enable_thinking
def _build_prompt_logprobs(
self,
prompt_logprobs_tensors: LogprobsTensors,
+45 -16
View File
@@ -61,6 +61,7 @@ class Ernie4_5Processor(BaseDataProcessor):
self.decode_status = dict()
self.tool_parser_dict = dict()
self.thinking_parser_dict = dict()
self.model_status_dict = dict()
self._load_tokenizer()
data_processor_logger.info(
f"tokenizer information: bos_token is {self.tokenizer.bos_token} \
@@ -148,8 +149,18 @@ class Ernie4_5Processor(BaseDataProcessor):
request.set("temperature", 1)
if request.get("top_p") < _SAMPLING_EPS:
request.set("top_p", _SAMPLING_EPS)
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
request.enable_thinking = True
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
parts = request.request_id.split("_")
if len(parts) > 1:
real_req_id = parts[0]
index = int(parts[1])
n = request.get("n", 1)
for idx in range(index * n, (index + 1) * n):
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
else:
self.model_status_dict[request.request_id] = model_status
request.enable_thinking = model_status == "think_start"
data_processor_logger.info(f"Processed request: {request}")
return request
@@ -222,9 +233,19 @@ class Ernie4_5Processor(BaseDataProcessor):
request["temperature"] = 1
if request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
request["enable_thinking"] = True
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
parts = request["request_id"].split("_")
if len(parts) > 1:
real_req_id = parts[0]
index = int(parts[1])
n = request.get("n", 1)
for idx in range(index * n, (index + 1) * n):
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
else:
self.model_status_dict[request["request_id"]] = model_status
request["enable_thinking"] = model_status == "think_start"
data_processor_logger.info(f"Processed request dict: {request}")
return request
@@ -246,7 +267,11 @@ class Ernie4_5Processor(BaseDataProcessor):
token_ids = token_ids[:-1]
full_text = self.tokenizer.decode(token_ids)
if self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
full_text,
response_dict,
self.model_status_dict[req_id],
)
response_dict.outputs.text = text
response_dict.outputs.reasoning_content = reasoning_content
else:
@@ -257,6 +282,8 @@ class Ernie4_5Processor(BaseDataProcessor):
if tool_call_info.tools_called:
response_dict.outputs.tool_calls = tool_call_info.tool_calls
response_dict.outputs.text = tool_call_info.content
if req_id in self.model_status_dict:
del self.model_status_dict[req_id]
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
return None
@@ -287,7 +314,6 @@ class Ernie4_5Processor(BaseDataProcessor):
Returns:
Dict: response contain text fields
"""
enable_thinking = kwargs.get("enable_thinking")
token_ids = response_dict["outputs"]["token_ids"]
is_end = response_dict["finished"]
req_id = response_dict["request_id"]
@@ -297,16 +323,17 @@ class Ernie4_5Processor(BaseDataProcessor):
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
if is_end:
full_text = previous_texts + delta_text
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
response_dict["outputs"]["text"] = full_text
if self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
full_text,
response_dict,
self.model_status_dict[req_id],
)
response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
else:
response_dict["outputs"]["text"] = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
@@ -316,6 +343,8 @@ class Ernie4_5Processor(BaseDataProcessor):
response_dict["outputs"]["completion_tokens"] = full_text
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.model_status_dict:
del self.model_status_dict[req_id]
return response_dict
def process_response_dict_streaming(self, response_dict, **kwargs):
@@ -328,7 +357,6 @@ class Ernie4_5Processor(BaseDataProcessor):
Returns:
Dict: response contain text fields
"""
enable_thinking = kwargs.get("enable_thinking")
is_end = response_dict["finished"]
req_id = response_dict["request_id"]
token_ids = response_dict["outputs"]["token_ids"]
@@ -338,9 +366,7 @@ class Ernie4_5Processor(BaseDataProcessor):
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
response_dict["outputs"]["completion_tokens"] = delta_text
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
):
if self.reasoning_parser:
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
@@ -348,6 +374,7 @@ class Ernie4_5Processor(BaseDataProcessor):
previous_token_ids,
previous_token_ids + token_ids,
token_ids,
self.model_status_dict[req_id],
)
response_dict["outputs"]["delta_message"] = reasoning_delta_message
reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None
@@ -374,6 +401,8 @@ class Ernie4_5Processor(BaseDataProcessor):
del self.decode_status[req_id]
if req_id in self.tool_parser_dict:
del self.tool_parser_dict[req_id]
if req_id in self.model_status_dict:
del self.model_status_dict[req_id]
return response_dict
def messages2ids(self, request_or_messages, **kwargs):
@@ -58,6 +58,7 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
self.tool_parser_dict = dict()
self.decode_status = dict()
self.model_status_dict = dict()
self._load_tokenizer()
# Generation config
@@ -242,15 +243,6 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
request[k] = v
else:
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
options = chat_template_kwargs.get("options")
if options:
thinking_mode = options.get("thinking_mode")
if thinking_mode:
if thinking_mode == "close" or thinking_mode == "false":
request["enable_thinking"] = False
else:
request["enable_thinking"] = True
request.setdefault("enable_thinking", True)
outputs = self.ernie4_5_processor.request2ids(request)
else:
raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
@@ -279,6 +271,18 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
data_processor_logger.info(f"Processed request {request}")
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
parts = request["request_id"].split("_")
if len(parts) > 1:
real_req_id = parts[0]
index = int(parts[1])
n = request.get("n", 1)
for idx in range(index * n, (index + 1) * n):
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
else:
self.model_status_dict[request["request_id"]] = model_status
request["enable_thinking"] = model_status == "think_start"
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
@@ -314,21 +318,3 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
outs["position_ids"] = np.array(outs["position_ids"], dtype=np.int64)
return outs
def process_response_dict(self, response_dict, stream, **kwargs):
"""
Preprocess the response
Args:
response_dict (Dict): response for engine, contain ids fields
Returns:
Dict: response contain text fields
"""
enable_thinking = kwargs.pop("enable_thinking", True)
if enable_thinking is None:
enable_thinking = True
if stream:
return self.process_response_dict_streaming(response_dict, enable_thinking=enable_thinking, **kwargs)
else:
return self.process_response_dict_normal(response_dict, enable_thinking=enable_thinking, **kwargs)
@@ -254,6 +254,19 @@ class PaddleOCRVLProcessor(TextProcessor):
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
parts = request["request_id"].split("_")
if len(parts) > 1:
real_req_id = parts[0]
index = int(parts[1])
n = request.get("n", 1)
for idx in range(index * n, (index + 1) * n):
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
else:
self.model_status_dict[request["request_id"]] = model_status
request["enable_thinking"] = model_status == "think_start"
return request
def append_generated_tokens(self, multimodal_inputs, generated_token_ids):
@@ -267,6 +267,18 @@ class QwenVLProcessor(TextProcessor):
# Set default max_tokens if not specified
if request.get("max_tokens") is None:
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
parts = request["request_id"].split("_")
if len(parts) > 1:
real_req_id = parts[0]
index = int(parts[1])
n = request.get("n", 1)
for idx in range(index * n, (index + 1) * n):
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
else:
self.model_status_dict[request["request_id"]] = model_status
request["enable_thinking"] = model_status == "think_start"
data_processor_logger.info(f"Processed request {request}")
return request
+45 -20
View File
@@ -176,6 +176,7 @@ class DataProcessor(BaseDataProcessor):
self.generation_config = None
self.decode_status = dict()
self.model_status_dict = dict()
self.tool_parser_dict = dict()
self.tokenizer = self._load_tokenizer()
data_processor_logger.info(
@@ -267,6 +268,18 @@ class DataProcessor(BaseDataProcessor):
request.set("temperature", 1)
if request.get("top_p") < _SAMPLING_EPS:
request.set("top_p", _SAMPLING_EPS)
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
parts = request.request_id.split("_")
if len(parts) > 1:
real_req_id = parts[0]
index = int(parts[1])
n = request.get("n", 1)
for idx in range(index * n, (index + 1) * n):
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
else:
self.model_status_dict[request.request_id] = model_status
request.enable_thinking = model_status == "think_start"
data_processor_logger.info(f"Processed request: {request}")
return request
@@ -340,6 +353,18 @@ class DataProcessor(BaseDataProcessor):
request["temperature"] = 1
if request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request["prompt_token_ids"])
parts = request["request_id"].split("_")
if len(parts) > 1:
real_req_id = parts[0]
index = int(parts[1])
n = request.get("n", 1)
for idx in range(index * n, (index + 1) * n):
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
else:
self.model_status_dict[request["request_id"]] = model_status
request["enable_thinking"] = model_status == "think_start"
data_processor_logger.info(f"Processed request dict: {request}")
return request
@@ -363,21 +388,21 @@ class DataProcessor(BaseDataProcessor):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
full_text = self.tokenizer.decode(token_ids)
# The model supports thinking and thinking is enabled
response_dict.outputs.text = full_text
if self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
full_text, response_dict, self.model_status_dict[req_id]
)
response_dict.outputs.text = text
response_dict.outputs.reasoning_content = reasoning_content
else:
# The model does not support thinking, and enable_thinking was not explicitly set to false
response_dict.outputs.text = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict.outputs.tool_calls = tool_call_info.tool_calls
response_dict.outputs.text = tool_call_info.content
if req_id in self.model_status_dict:
del self.model_status_dict[req_id]
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
return response_dict
@@ -392,7 +417,6 @@ class DataProcessor(BaseDataProcessor):
Returns:
Dict: response contain text fields
"""
enable_thinking = kwargs.get("enable_thinking")
token_ids = response_dict["outputs"]["token_ids"]
is_end = response_dict["finished"]
req_id = response_dict["request_id"]
@@ -403,14 +427,17 @@ class DataProcessor(BaseDataProcessor):
if is_end:
full_text = previous_texts + delta_text
response_dict["outputs"]["completion_tokens"] = full_text
if enable_thinking and self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
response_dict["outputs"]["text"] = full_text
if self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
full_text,
response_dict,
self.model_status_dict[req_id],
)
response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
response_dict["outputs"]["reasoning_token_num"] = len(reasoning_tokens)
else:
response_dict["outputs"]["text"] = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
@@ -419,6 +446,8 @@ class DataProcessor(BaseDataProcessor):
response_dict["outputs"]["text"] = tool_call_info.content
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.model_status_dict:
del self.model_status_dict[req_id]
return response_dict
def process_response_dict_streaming(self, response_dict, **kwargs):
@@ -431,7 +460,6 @@ class DataProcessor(BaseDataProcessor):
Returns:
Dict: response contain text fields
"""
enable_thinking = kwargs.get("enable_thinking")
is_end = response_dict["finished"]
req_id = response_dict["request_id"]
token_ids = response_dict["outputs"]["token_ids"]
@@ -441,9 +469,7 @@ class DataProcessor(BaseDataProcessor):
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
response_dict["outputs"]["completion_tokens"] = delta_text
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
):
if self.reasoning_parser:
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
@@ -451,6 +477,7 @@ class DataProcessor(BaseDataProcessor):
previous_token_ids,
previous_token_ids + token_ids,
token_ids,
self.model_status_dict[req_id],
)
response_dict["outputs"]["delta_message"] = reasoning_delta_message
reasoning_content = reasoning_delta_message.reasoning_content if reasoning_delta_message else None
@@ -477,6 +504,8 @@ class DataProcessor(BaseDataProcessor):
del self.decode_status[req_id]
if req_id in self.tool_parser_dict:
del self.tool_parser_dict[req_id]
if req_id in self.model_status_dict:
del self.model_status_dict[req_id]
return response_dict
def process_response_dict(self, response_dict, **kwargs):
@@ -489,16 +518,12 @@ class DataProcessor(BaseDataProcessor):
Returns:
Dict: response contain text fields
"""
enable_thinking = kwargs.pop("enable_thinking", True)
if enable_thinking is None:
enable_thinking = True
stream = kwargs.get("stream", True)
if stream:
return self.process_response_dict_streaming(response_dict, enable_thinking=enable_thinking, **kwargs)
return self.process_response_dict_streaming(response_dict, **kwargs)
else:
return self.process_response_dict_normal(
response_dict=response_dict,
enable_thinking=enable_thinking,
**kwargs,
)
@@ -35,25 +35,53 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.think_end_token = "</think>"
self.tool_begin_token = "<tool_call>"
token_definitions = {
"think_start_token": "<think>",
"think_end_token": "</think>",
"tool_call_start_token": "<tool_call>",
"tool_call_end_token": "</tool_call>",
}
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
)
missing_tokens = []
for name, token_value in token_definitions.items():
setattr(self, name, token_value)
token_id = self.vocab.get(token_value)
setattr(self, f"{name}_id", token_id)
if token_id is None:
missing_tokens.append(f"{name.replace('_', ' ')} token")
self.think_end_token_id = self.vocab.get(self.think_end_token)
self.tool_begin_token_id = self.vocab.get(self.tool_begin_token)
if self.tool_begin_token_id is None:
self.tool_begin_token_id = -1
if self.think_end_token_id is None:
raise RuntimeError("Test reasoning parser could not locate think end tokens in the tokenizer!")
if missing_tokens:
raise RuntimeError(
f"ernie vl reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
)
self.token_status_mapping = {
self.think_start_token_id: "think_start",
self.think_end_token_id: "think_end",
self.tool_call_start_token_id: "tool_call_start",
self.tool_call_end_token_id: "tool_call_end",
}
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Return True when the </think> token id appears in *input_ids*."""
    return self.think_end_token_id in input_ids

def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """
    Scan *prompt_token_ids* from the end and return the last token id
    that is a known status token (a key of ``self.token_status_mapping``),
    or -1 when none is present.
    """
    for i in range(len(prompt_token_ids) - 1, -1, -1):
        if prompt_token_ids[i] in self.token_status_mapping:
            return prompt_token_ids[i]
    return -1

def get_model_status(self, prompt_token_ids: list[int]):
    """
    Derive the model's generation phase from the prompt's trailing
    special token (e.g. "think_start", "think_end", ...).

    Defaults to "think_start" when the prompt contains no status token,
    i.e. the model is assumed to begin by thinking.
    """
    special_token_id = self.find_last_special_token(prompt_token_ids)
    if special_token_id == -1:
        return "think_start"
    return self.token_status_mapping[special_token_id]
def extract_reasoning_content_streaming(
self,
previous_text: str,
@@ -62,6 +90,7 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
model_status: str,
) -> Union[DeltaMessage, None]:
"""
Extract reasoning content from a delta message.
@@ -71,6 +100,7 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
- 'abc' goes to reasoning_content
- 'xyz' goes to content
"""
if model_status == "think_start":
if self.think_end_token not in current_text:
return DeltaMessage(reasoning_content=delta_text)
# Skip single special tokens
@@ -98,9 +128,18 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
if len(striped_suffix) == 0:
return None
return DeltaMessage(content=delta_text)
elif model_status == "think_end":
if current_text.lstrip("\n").startswith(self.tool_call_start_token):
return None
return DeltaMessage(content=delta_text)
else:
return None
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
self,
model_output: str,
request: ChatCompletionRequest,
model_status: str,
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from the model output.
@@ -114,23 +153,30 @@ class Ernie45VLThinkingReasoningParser(ReasoningParser):
"""
# Check if the model output contains the </think> tokens.
if model_status == "think_start":
if self.think_end_token not in model_output:
return model_output, ""
reasoning_content, _, content = model_output.partition(self.think_end_token)
if self.tool_begin_token in content:
prefix, _, _ = content.partition(self.tool_begin_token)
if self.tool_call_start_token in content:
prefix, _, _ = content.partition(self.tool_call_start_token)
prefix_strip = prefix.lstrip("\n")
if len(prefix_strip) > 0:
return reasoning_content, content
return reasoning_content, ""
return reasoning_content, content
elif model_status == "think_end":
if model_output.lstrip("\n").startswith(self.tool_call_start_token):
return "", ""
return "", model_output
else:
return "", ""
def _is_with_tool(self, current_text: str, current_token_ids: Sequence[int]) -> bool:
think_end_index = current_text.find(self.think_end_token)
think_end = think_end_index + len(self.think_end_token)
middle_str = current_text[think_end:]
if self.tool_begin_token_id in current_token_ids:
prefix, _, _ = middle_str.partition(self.tool_begin_token)
if self.tool_call_start_token_id in current_token_ids:
prefix, _, _ = middle_str.partition(self.tool_call_start_token)
striped_prefix = prefix.strip("\n")
if len(striped_prefix) > 0:
return False
@@ -35,20 +35,48 @@ class ErnieVLReasoningParser(ReasoningParser):
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.think_end_token = "</think>"
token_definitions = {
"think_start_token": "<think>",
"think_end_token": "</think>",
}
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
)
raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")
self.think_end_token_id = self.vocab.get(self.think_end_token)
if self.think_end_token_id is None:
raise RuntimeError("Ernie VL reasoning parser could not locate think end " "tokens in the tokenizer!")
missing_tokens = []
for name, token_value in token_definitions.items():
setattr(self, name, token_value)
token_id = self.vocab.get(token_value)
setattr(self, f"{name}_id", token_id)
if token_id is None:
missing_tokens.append(f"{name.replace('_', ' ')} token")
if missing_tokens:
raise RuntimeError(
f"ernie vl reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
)
self.token_status_mapping = {
self.think_start_token_id: "think_start",
self.think_end_token_id: "think_end",
}
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Return True when the </think> token id appears in *input_ids*."""
    return self.think_end_token_id in input_ids

def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """
    Walk *prompt_token_ids* backwards and return the last id that maps to
    a status in ``self.token_status_mapping``; -1 when none is found.
    """
    for i in range(len(prompt_token_ids) - 1, -1, -1):
        if prompt_token_ids[i] in self.token_status_mapping:
            return prompt_token_ids[i]
    return -1

def get_model_status(self, prompt_token_ids: list[int]):
    """
    Map the prompt's last special token to a model status string
    ("think_start" / "think_end"); "think_start" when absent, meaning
    generation starts in the thinking phase by default.
    """
    special_token_id = self.find_last_special_token(prompt_token_ids)
    if special_token_id == -1:
        return "think_start"
    return self.token_status_mapping[special_token_id]
def extract_reasoning_content_streaming(
self,
previous_text: str,
@@ -57,6 +85,7 @@ class ErnieVLReasoningParser(ReasoningParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
model_status: str,
) -> Union[DeltaMessage, None]:
"""
Extract reasoning content from a delta message.
@@ -69,18 +98,23 @@ class ErnieVLReasoningParser(ReasoningParser):
# Skip single special tokens
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
return None
if model_status == "think_start":
if self.think_end_token_id in delta_token_ids:
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token) :]
return DeltaMessage(reasoning_content=reasoning_content, content=content)
elif self.think_end_token_id in previous_token_ids:
if self.think_end_token_id in previous_token_ids:
return DeltaMessage(content=delta_text)
else:
return DeltaMessage(reasoning_content=delta_text)
else:
return DeltaMessage(content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
self,
model_output: str,
request: ChatCompletionRequest,
model_status: str,
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from the model output.
@@ -92,11 +126,12 @@ class ErnieVLReasoningParser(ReasoningParser):
Returns:
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
# Check if the model output contains the </think> tokens.
if model_status == "think_start":
if self.think_end_token not in model_output:
return "", model_output
reasoning_content, _, content = model_output.partition(self.think_end_token)
final_content = content or ""
return reasoning_content, final_content
else:
return "", model_output
@@ -18,19 +18,55 @@ class ErnieX1ReasoningParser(ReasoningParser):
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.think_end_token = "</think>"
self.response_start_token = "<response>"
self.response_end_token = "</response>"
self.tool_call_start_token = "<tool_call>"
self.tool_call_end_token = "</tool_call>"
# Define all tokens that need to be checked
token_definitions = {
"think_start_token": "<think>",
"think_end_token": "</think>",
"response_start_token": "<response>",
"response_end_token": "</response>",
"tool_call_start_token": "<tool_call>",
"tool_call_end_token": "</tool_call>",
}
if not self.model_tokenizer:
raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")
self.think_end_token_id = self.vocab.get("</think>")
if self.think_end_token_id is None:
raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
self.tool_call_start_token_id = self.vocab.get("<tool_call>")
missing_tokens = []
for name, token_value in token_definitions.items():
setattr(self, name, token_value)
token_id = self.vocab.get(token_value)
setattr(self, f"{name}_id", token_id)
if token_id is None:
missing_tokens.append(token_value)
if missing_tokens:
raise RuntimeError(
f"ernie x1 reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
)
self.token_status_mapping = {
self.think_start_token_id: "think_start",
self.think_end_token_id: "think_end",
self.response_start_token_id: "response_start",
self.response_end_token_id: "response_end",
self.tool_call_start_token_id: "tool_call_start",
self.tool_call_end_token_id: "tool_call_end",
}
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """
    Return the last token id in *prompt_token_ids* that is one of the
    parser's tracked special tokens (keys of ``self.token_status_mapping``),
    scanning from the end; -1 when the prompt contains none.
    """
    for i in range(len(prompt_token_ids) - 1, -1, -1):
        if prompt_token_ids[i] in self.token_status_mapping:
            return prompt_token_ids[i]
    return -1

def get_model_status(self, prompt_token_ids: list[int]):
    """
    Derive the generation phase from the prompt's trailing special token
    (e.g. "think_start", "response_start", "tool_call_start", ...).
    Defaults to "think_start" when no special token is present.
    """
    special_token_id = self.find_last_special_token(prompt_token_ids)
    if special_token_id == -1:
        return "think_start"
    return self.token_status_mapping[special_token_id]
def extract_reasoning_content_streaming(
self,
@@ -40,64 +76,82 @@ class ErnieX1ReasoningParser(ReasoningParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
model_status: str,
) -> Union[DeltaMessage, None]:
# Ignore the single </think> token
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
if len(delta_token_ids) == 1 and delta_token_ids[0] in [
self.think_end_token_id,
self.response_start_token_id,
self.response_end_token_id,
self.tool_call_start_token_id,
self.tool_call_end_token_id,
]:
return None
# --- Thinking stage handling ---
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
# If delta is </think>, stop thinking, do not return
if delta_text.startswith(self.think_end_token):
return None
# Otherwise, return thinking content (keep \n as-is)
if model_status == "think_start":
if self.think_end_token in delta_text:
response_content = ""
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:end_index]
response_start_pos = delta_text.find(self.response_start_token)
if response_start_pos != -1:
response_content = self._extract_response_content(
delta_text[response_start_pos + len(self.response_start_token) :]
)
return DeltaMessage(reasoning_content=reasoning_content, content=response_content)
elif self.think_end_token in previous_text:
if self.response_start_token in previous_text and self.response_end_token not in previous_text:
return DeltaMessage(content=delta_text)
else:
return DeltaMessage(reasoning_content=delta_text)
# --- After thinking ends, check tool_call or response ---
remaining_text = previous_text + delta_text
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
after_think = after_think.lstrip("\n")
# Handle tool_call case: skip it
if after_think.startswith(self.tool_call_start_token):
return None
# Handle response case
if after_think.startswith(self.response_start_token) and self.response_end_token not in after_think:
# Do not return when <response> tag itself appears
if delta_text == self.response_start_token or delta_text == self.response_end_token:
return None
elif model_status == "think_end":
if self.response_start_token in previous_text and self.response_end_token not in previous_text:
return DeltaMessage(content=delta_text)
elif model_status == "response_start":
if self.response_end_token not in previous_text:
return DeltaMessage(content=delta_text)
# Default case: return nothing
return None
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest, model_status: str
) -> Tuple[str, str]:
"""
优化版解析器。保留推理和响应内容中的换行符,
仅删除闭合标签前的单个换行符。
"""
reasoning_content = ""
response_content = ""
if model_status in ["think_start", "think_end"]:
if model_status == "think_start":
think_end_pos = model_output.find(self.think_end_token)
if think_end_pos != -1:
reasoning_content = model_output[:think_end_pos]
remaining = model_output[think_end_pos + len(self.think_end_token) :]
# find <response> or <tool>
response_pos = remaining.find(self.response_start_token)
tool_pos = remaining.find(self.tool_call_start_token)
# <response> first
if response_pos != -1 and (tool_pos == -1 or response_pos < tool_pos):
# The content after the response_start position
remaining_response = remaining[response_pos + len(self.response_start_token) :]
response_end_pos = remaining_response.find(self.response_end_token)
if response_end_pos != -1:
response_content = remaining_response[:response_end_pos]
else:
response_content = remaining_response
# The content after the response_start position is tool_call
remaining = model_output[think_end_pos + len(self.think_end_token) :].lstrip("\n")
else:
reasoning_content = model_output
response_content = ""
remaining = ""
else:
remaining = model_output.lstrip("\n")
response_start_pos = remaining.find(self.response_start_token)
if response_start_pos != -1:
response_content = self._extract_response_content(
remaining[response_start_pos + len(self.response_start_token) :]
)
elif model_status == "response_start":
response_content = self._extract_response_content(model_output)
return reasoning_content, response_content
def _extract_response_content(self, remaining: str) -> str:
"""
Extracts response content, ensuring that the last newline before
the </response> tag is removed.
"""
response_end_pos = remaining.find(self.response_end_token)
if response_end_pos != -1:
return remaining[:response_end_pos]
return remaining
+45 -10
View File
@@ -35,22 +35,50 @@ class Qwen3ReasoningParser(ReasoningParser):
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
# 定义所有需要检查的token
token_definitions = {
"think_start_token": "<think>",
"think_end_token": "</think>",
}
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
)
raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")
self.think_start_token_id = self.vocab.get(self.think_start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if self.think_end_token_id is None:
raise RuntimeError("Qwen3 reasoning parser could not locate think end " "tokens in the tokenizer!")
missing_tokens = []
for name, token_value in token_definitions.items():
setattr(self, name, token_value)
token_id = self.vocab.get(token_value)
setattr(self, f"{name}_id", token_id)
if token_id is None:
missing_tokens.append(token_value)
if missing_tokens:
raise RuntimeError(
f"Qwen3 reasoning parser could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
)
self.token_status_mapping = {
self.think_start_token_id: "think_start",
self.think_end_token_id: "think_end",
}
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Return True once the ``</think>`` token id has been emitted."""
    return any(token == self.think_end_token_id for token in input_ids)
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """Return the id of the last special token in the prompt, or -1 if absent."""
    matches = (tid for tid in reversed(prompt_token_ids) if tid in self.token_status_mapping)
    return next(matches, -1)
def get_model_status(self, prompt_token_ids: list[int]):
    """Map the prompt's trailing special token to a status string.

    ``find_last_special_token`` yields either -1 (never a vocab id, so
    absent from the mapping) or an id guaranteed to be a mapping key;
    the default therefore fires exactly when no special token exists.
    """
    return self.token_status_mapping.get(self.find_last_special_token(prompt_token_ids), "think_start")
def extract_reasoning_content_streaming(
self,
previous_text: str,
@@ -59,6 +87,7 @@ class Qwen3ReasoningParser(ReasoningParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
model_status: str,
) -> Union[DeltaMessage, None]:
"""
Extract reasoning content from a delta message.
@@ -71,6 +100,7 @@ class Qwen3ReasoningParser(ReasoningParser):
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]):
return None
if model_status == "think_start":
# </think> in delta
if self.think_end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
@@ -101,9 +131,11 @@ class Qwen3ReasoningParser(ReasoningParser):
return DeltaMessage(reasoning_content=reasoning_content, content=content)
else:
return DeltaMessage(reasoning_content=delta_text)
else:
return DeltaMessage(content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
self, model_output: str, request: ChatCompletionRequest, model_status: str
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from the model output.
@@ -116,6 +148,7 @@ class Qwen3ReasoningParser(ReasoningParser):
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
if model_status == "think_start":
# 检查是否包含结束标签
if self.think_end_token not in model_output:
return None, model_output
@@ -149,3 +182,5 @@ class Qwen3ReasoningParser(ReasoningParser):
return reasoning_content, final_content
return None, model_output
else:
return None, model_output
@@ -482,7 +482,7 @@ def test_chat_with_thinking(openai_client, capsys):
max_tokens=10,
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)
assert response.choices[0].message.reasoning_content is None
assert response.choices[0].message.reasoning_content == ""
assert "</think>" not in response.choices[0].message.content
# test logic
@@ -957,4 +957,4 @@ def test_thinking_logic_flag(openai_client, capsys):
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response_case_3.choices[0].message.reasoning_content is None
assert response_case_3.choices[0].message.reasoning_content == ""
+2 -2
View File
@@ -545,7 +545,7 @@ def test_chat_with_thinking(openai_client, capsys):
max_tokens=10,
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)
assert response.choices[0].message.reasoning_content is None
assert response.choices[0].message.reasoning_content == ""
assert "</think>" not in response.choices[0].message.content
# test logic
@@ -716,4 +716,4 @@ def test_thinking_logic_flag(openai_client, capsys):
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response_case_3.choices[0].message.reasoning_content is None
assert response_case_3.choices[0].message.reasoning_content == ""
+2 -2
View File
@@ -312,7 +312,7 @@ def test_chat_with_thinking(openai_client, capsys):
max_tokens=10,
extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)
assert response.choices[0].message.reasoning_content is None
assert response.choices[0].message.reasoning_content == ""
assert "</think>" not in response.choices[0].message.content
# test logic
@@ -404,4 +404,4 @@ def test_thinking_logic_flag(openai_client, capsys):
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response_case_3.choices[0].message.reasoning_content is None
assert response_case_3.choices[0].message.reasoning_content == ""
@@ -59,6 +59,8 @@ class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
self.multi_modal_processor._check_mm_limits = Mock()
self.multi_modal_processor.append_completion_tokens = Mock()
self.multi_modal_processor.pack_outputs = lambda x: x
self.multi_modal_processor.reasoning_parser = None
self.multi_modal_processor.model_status_dict = {}
self.engine_client = Mock()
self.engine_client.connection_initialized = False
@@ -258,7 +260,7 @@ class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
mock_processor_instance = Mock()
mock_processor_instance.enable_multimodal_content.return_value = True
async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
async def mock_process_response_chat_async(response, stream, include_stop_str_in_output):
yield response
mock_processor_instance.process_response_chat = mock_process_response_chat_async
@@ -439,7 +441,7 @@ class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
mock_processor_instance = Mock()
mock_processor_instance.enable_multimodal_content.return_value = False
async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
async def mock_process_response_chat_async(response, stream, include_stop_str_in_output):
if isinstance(response, list):
for res in response:
yield res
@@ -165,7 +165,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
mock_processor_instance = Mock()
async def mock_process_response_chat_single(response, stream, enable_thinking, include_stop_str_in_output):
async def mock_process_response_chat_single(response, stream, include_stop_str_in_output):
yield response
mock_processor_instance.process_response_chat = mock_process_response_chat_single
@@ -539,7 +539,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
mock_processor_instance = Mock()
async def mock_process_response_chat(response, stream, enable_thinking, include_stop_str_in_output):
async def mock_process_response_chat(response, stream, include_stop_str_in_output):
delta_msg_mock = Mock()
delta_msg_mock.content = response["outputs"]["text"]
if response["outputs"]["text"] == "a":
@@ -48,7 +48,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
results = [
r
async for r in processor.process_response_chat(
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
request_outputs, stream=False, include_stop_str_in_output=False
)
]
@@ -69,7 +69,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
results = [
r
async for r in processor.process_response_chat(
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
request_outputs, stream=True, include_stop_str_in_output=False
)
]
@@ -89,7 +89,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
results = [
r
async for r in self.processor_mm.process_response_chat(
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
request_outputs, stream=True, include_stop_str_in_output=False
)
]
@@ -116,7 +116,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
results = [
r
async for r in self.processor_mm.process_response_chat(
request_outputs, stream=True, enable_thinking=False, include_stop_str_in_output=False
request_outputs, stream=True, include_stop_str_in_output=False
)
]
@@ -134,7 +134,7 @@ class TestChatResponseProcessor(unittest.IsolatedAsyncioTestCase):
results = [
r
async for r in self.processor_mm.process_response_chat(
request_outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
request_outputs, stream=False, include_stop_str_in_output=False
)
]
@@ -41,35 +41,6 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase):
chat_template=None,
)
def test_enable_thinking(self):
request = ChatCompletionRequest(messages=[], chat_template_kwargs={})
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
self.assertEqual(enable_thinking, None)
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"enable_thinking": True})
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
self.assertEqual(enable_thinking, True)
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"enable_thinking": False})
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
self.assertEqual(enable_thinking, False)
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "close"}})
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
self.assertEqual(enable_thinking, False)
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "false"}})
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
self.assertEqual(enable_thinking, False)
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "open"}})
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
self.assertEqual(enable_thinking, True)
request = ChatCompletionRequest(messages=[], chat_template_kwargs={"options": {"thinking_mode": "123"}})
enable_thinking = self.chat_completion_handler._get_thinking_status(request)
self.assertEqual(enable_thinking, True)
def test_build_prompt_logprobs_basic(self):
"""Test basic functionality of _build_prompt_logprobs"""
# Create mock data
@@ -52,33 +52,12 @@ class TestErnieX1ToolParser(unittest.TestCase):
self.assertTrue(result.tools_called)
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
def test_extract_tool_calls_partial_arguments(self):
"""Test partial extraction when arguments incomplete"""
output = '<tool_call>{"name": "get_weather", "arguments": {"location": ""</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
def test_extract_tool_calls_invalid_response_before_toolcall(self):
"""Test case where <response> before <tool_call> is invalid"""
output = '<response>hello</response><tool_call>{"name": "get_weather", "arguments": {}}</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
self.assertIn("<response>", result.content)
def test_extract_tool_calls_no_toolcall(self):
"""Test when no tool_call tags are present"""
output = "no tool call here"
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
def test_extract_tool_calls_invalid_json(self):
"""Test tool_call with badly formatted JSON triggers fallback parser"""
output = '<tool_call>"name": "get_weather", "arguments": {</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
def test_extract_tool_calls_exception(self):
"""Force exception to cover error branch"""
with patch(
+2
View File
@@ -89,6 +89,7 @@ class ErnieX1ReasoningParser:
previous_token_ids,
all_token_ids,
delta_token_ids,
model_status,
):
"""Return a simple object with reasoning_content to cover reasoning branch."""
@@ -161,6 +162,7 @@ class TestErnie4_5Processor(unittest.TestCase):
tool_cls = MockToolParser if tool else None
proc = Ernie4_5Processor("dummy-model", reasoning_parser_obj=reasoning_cls, tool_parser_obj=tool_cls)
proc._apply_default_parameters = lambda req: req
proc.model_status_dict = {"req-1": "think_start"}
return proc
def test_update_bad_words(self):
+35 -3
View File
@@ -20,6 +20,11 @@ from unittest.mock import MagicMock, patch
from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
class MockReasoningParser:
    """Test double: always reports the model as starting to think."""

    def get_model_status(self, prompt_token_ids):
        # The prompt contents are irrelevant for this stub.
        return "think_start"
class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
def setUp(self):
# 创建 Ernie4_5Processor 实例的模拟对象
@@ -30,12 +35,13 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
# 设置必要的属性
self.processor.tokenizer = MagicMock()
self.processor.tokenizer.eos_token_id = 1
self.processor.decode_status = {}
self.processor.decode_status = {"test": []}
self.processor.reasoning_end_dict = {}
self.processor.tool_parser_dict = {}
self.processor.generation_config = MagicMock()
self.processor.eos_token_ids = [1]
self.processor.reasoning_parser = MagicMock()
self.processor.reasoning_parser = MockReasoningParser()
self.processor.model_status_dict = {"request-id_0": "think_start", "test": "think_start"}
# 模拟 ids2tokens 方法
def mock_ids2tokens(token_ids, task_id):
@@ -72,7 +78,7 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
def test_process_response_dict_streaming_normal_case(self):
"""测试正常情况下的流式响应处理"""
# 准备输入
response_dict = {"finished": False, "request_id": "req1", "outputs": {"token_ids": [4, 5]}}
response_dict = {"finished": False, "request_id": "test", "outputs": {"token_ids": [4, 5]}}
kwargs = {"enable_thinking": True}
# 调用方法
@@ -83,6 +89,7 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
def test_process_request_dict(self):
request_dict = {
"request_id": "123",
"messages": [{"role": "user", "content": "Hello!"}],
"chat_template_kwargs": {"chat_template": "Hello!"},
"eos_token_ids": [1],
@@ -118,6 +125,31 @@ class TestErnie4_5ProcessorProcessResponseDictStreaming(unittest.TestCase):
self.assertEqual(result["outputs"]["text"], "Mock final text")
self.assertIn("completion_tokens", result["outputs"])
def test_think_status(self):
    """Exercise the thinking-status bookkeeping in process_request_dict."""
    # Unknown request_id ("test_1") with an empty model_status_dict: the
    # processor should consult the reasoning parser for the model status.
    request = {
        "prompt": "hello",
        "request_id": "test_1",
        "prompt_token_ids": [1, 2, 3],
        "temperature": 0.7,
        "top_p": 0.9,
    }
    self.processor.reasoning_parser = MagicMock()
    self.processor.reasoning_parser.get_model_status.return_value = "think_start"
    self.processor.model_status_dict = {}
    self.processor.process_request_dict(request, max_model_len=512)
    # "think_start" status is expected to switch thinking on.
    self.assertEqual(request["enable_thinking"], True)
    # Second request reuses id "test"; NOTE(review): this appears to rely on
    # "test" already being present in model_status_dict (seeded in setUp) —
    # confirm the cached-status path is what is exercised here.
    request = {
        "prompt": "hello",
        "request_id": "test",
        "prompt_token_ids": [1, 2, 3],
        "temperature": 0.7,
        "top_p": 0.9,
    }
    self.processor.process_request_dict(request, max_model_len=512)
    self.assertEqual(request["enable_thinking"], True)
if __name__ == "__main__":
unittest.main()
+38 -86
View File
@@ -29,7 +29,12 @@ from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
from fastdeploy.input.utils import IDS_TYPE_FLAG
class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
class MockReasoningParser:
    """Minimal reasoning-parser stand-in for the processor tests."""

    def get_model_status(self, prompt_token_ids):
        """Ignore the prompt and report a fresh thinking segment."""
        return "think_start"
class TestErnie4_5VLProcessorProcessResponseDictStreaming(unittest.TestCase):
def setUp(self):
# Create mock object for Ernie4_5Processor instance
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None) as mock_init:
@@ -39,38 +44,41 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
# Set necessary attributes
self.processor.tokenizer = MagicMock()
self.processor.tokenizer.eos_token_id = 1
self.processor.decode_status = {}
self.processor.decode_status = {"test": []}
self.processor.reasoning_end_dict = {}
self.processor.tool_parser_dict = {}
self.processor.generation_config = MagicMock()
self.processor.eos_token_ids = [1]
self.processor.reasoning_parser = MagicMock()
self.processor._check_mm_limits = MagicMock()
self.processor.reasoning_parser = MockReasoningParser()
self.processor.model_status_dict = {"test": "think_start"}
self.processor.ernie4_5_processor = MagicMock()
self.processor.pack_outputs = MagicMock()
# Mock ids2tokens method
def mock_ids2tokens(token_ids, task_id):
self.processor.decode_status[task_id] = "mock_decode_status"
return "delta_text", [2, 3], "previous_texts"
self.processor.ids2tokens = mock_ids2tokens
def mock_messages2ids(request, **kwargs):
if "chat_template" in kwargs:
return [1]
else:
return [0]
def mock_request2ids(request, **kwargs):
return {"input_ids": np.array([1, 2, 3]), "prompt_token_ids": [0]}
def mock_check_mm_limits(item):
pass
def mock_apply_default_parameters(request):
return request
def mock_pack_outputs(outputs):
return outputs
self.processor._apply_default_parameters = mock_apply_default_parameters
self.processor._check_mm_limits = mock_check_mm_limits
self.processor.ernie4_5_processor.request2ids = mock_request2ids
self.processor.pack_outputs = mock_pack_outputs
# Mock reasoning parser
self.mock_reasoning_parser = MagicMock()
self.mock_reasoning_parser.__class__.__name__ = "ErnieX1ReasoningParser"
# self.mock_reasoning_parser.extract_reasoning_content_streaming.return_value = ("reasoning", "text")
self.mock_reasoning_parser.extract_reasoning_content_streaming.return_value = None
self.processor.reasoning_parser = self.mock_reasoning_parser
# Mock tool parser
@@ -80,82 +88,26 @@ class TestErnie4_5_vl_ProcessorProcessResponseDictStreaming(unittest.TestCase):
self.mock_tool_parser_obj.return_value = self.mock_tool_parser
self.processor.tool_parser_obj = self.mock_tool_parser_obj
def test_process_request_dict_with_options(self):
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"prompt_token_ids": [1, 1, 1],
def test_think_status(self):
"""测试 思考机制"""
request = {
"prompt": "hello",
"request_id": "test_1",
"prompt_token_ids": [1, 2, 3],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], True)
self.processor.reasoning_parser = MagicMock()
self.processor.reasoning_parser.get_model_status.return_value = "think_start"
self.processor.model_status_dict = {}
self.processor.process_request_dict(request, max_model_len=512)
self.assertEqual(request["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"enable_thinking": True},
"prompt_token_ids": [1, 1, 1],
request = {
"prompt": "hello",
"request_id": "test",
"prompt_token_ids": [1, 2, 3],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"enable_thinking": False},
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"options": {"thinking_mode": "open"}},
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"options": {"thinking_mode": "close"}},
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"options": {"thinking_mode": "false"}},
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"options": {"thinking_mode": "123"}},
"prompt_token_ids": [1, 1, 1],
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], True)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"options": {"thinking_mode": "close"}},
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"options": {"thinking_mode": "false"}},
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)
request_dict = {
"messages": [{"role": "user", "content": "Hello"}],
"chat_template_kwargs": {"enable_thinking": False},
}
self.processor.process_request_dict(request_dict, 100)
self.assertEqual(request_dict["enable_thinking"], False)
self.processor.process_request_dict(request, max_model_len=512)
self.assertEqual(request["enable_thinking"], True)
class TestDataProcessorTargetMethods(unittest.TestCase):
@@ -793,6 +793,8 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
self.processor.processor = MagicMock()
self.processor.limit_mm_per_prompt = {"image": 1, "video": 1, "audio": 1}
self.processor.eos_token_ids = [1]
self.processor.reasoning_parser = None
self.processor.model_status_dict = {}
# 模拟 _apply_default_parameters
def mock_apply_default_parameters(request_or_dict):
@@ -971,6 +973,7 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
"prompt": "test prompt",
"multimodal_data": {"image": ["image1"]},
"metadata": {"generated_token_ids": []},
"request_id": "test-request",
}
request_obj.to_dict.return_value = request_dict
@@ -1157,6 +1160,27 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
self.assertTrue(np.array_equal(result["image_type_ids"], np.array([0])))
self.assertTrue(np.array_equal(result["position_ids"], np.array([[0], [1], [2]], dtype=np.int64)))
def test_think_status(self):
    """Exercise the thinking-status bookkeeping in process_request_dict."""
    # Unknown request_id ("test_1") with an empty model_status_dict: the
    # processor should consult the reasoning parser for the model status.
    request = {
        "prompt": "hello",
        "request_id": "test_1",
        "prompt_token_ids": [1, 2, 3],
    }
    self.processor.reasoning_parser = MagicMock()
    self.processor.reasoning_parser.get_model_status.return_value = "think_start"
    self.processor.model_status_dict = {}
    self.processor.process_request_dict(request, max_model_len=512)
    # "think_start" status is expected to switch thinking on.
    self.assertEqual(request["enable_thinking"], True)
    # Second request with id "test"; NOTE(review): presumably hits the cached
    # model_status_dict path rather than re-querying the parser — confirm.
    request = {
        "prompt": "hello",
        "request_id": "test",
        "prompt_token_ids": [1, 2, 3],
    }
    self.processor.process_request_dict(request, max_model_len=512)
    self.assertEqual(request["enable_thinking"], True)
if __name__ == "__main__":
unittest.main()
+25
View File
@@ -365,6 +365,31 @@ class TestQwenVLProcessor(unittest.TestCase):
# Verify both methods produce identical prompt strings
self.assertEqual(prompt, prompt2)
def test_think_status(self):
    """Exercise the thinking-status bookkeeping in process_request_dict."""
    # Unknown request_id ("test_1") with an empty model_status_dict: the
    # processor should consult the reasoning parser for the model status.
    request = {
        "prompt": "hello",
        "request_id": "test_1",
        "prompt_token_ids": [1, 2, 3],
        "temperature": 0.7,
        "top_p": 0.9,
    }
    self.processor.reasoning_parser = MagicMock()
    self.processor.reasoning_parser.get_model_status.return_value = "think_start"
    self.processor.model_status_dict = {}
    self.processor.process_request_dict(request, max_model_len=512)
    # "think_start" status is expected to switch thinking on.
    self.assertEqual(request["enable_thinking"], True)
    # Second request with id "test"; NOTE(review): presumably hits the cached
    # model_status_dict path rather than re-querying the parser — confirm.
    request = {
        "prompt": "hello",
        "request_id": "test",
        "prompt_token_ids": [1, 2, 3],
        "temperature": 0.7,
        "top_p": 0.9,
    }
    self.processor.process_request_dict(request, max_model_len=512)
    self.assertEqual(request["enable_thinking"], True)
if __name__ == "__main__":
unittest.main()
+5 -4
View File
@@ -266,7 +266,7 @@ class DataProcessorTestCase(unittest.TestCase):
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def extract_reasoning_content(self, full_text, response_dict):
def extract_reasoning_content(self, full_text, response_dict, model_status):
return reasoning_content, f"{full_text}!"
return DummyReasoning(tokenizer)
@@ -409,6 +409,7 @@ class DataProcessorTestCase(unittest.TestCase):
def test_process_response_with_reasoning_and_tools(self):
processor = self.processor
processor.model_status_dict = {"resp": "normal"}
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer)
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-only")
@@ -435,7 +436,7 @@ class DataProcessorTestCase(unittest.TestCase):
def test_process_response_dict_normal_with_reasoning(self):
processor = self.processor
processor.model_status_dict = {"normal": "normal"}
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer, reasoning_content="because")
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-text")
@@ -471,10 +472,10 @@ class DataProcessorTestCase(unittest.TestCase):
self.addCleanup(lambda: setattr(processor, "process_response_dict_normal", original_normal))
response = {"outputs": {}, "finished": False, "request_id": "req"}
self.assertEqual(processor.process_response_dict(response), "stream")
self.assertEqual(processor.process_response_dict(response, stream=True, enable_thinking=True), "stream")
self.assertTrue(calls["stream"]["enable_thinking"])
self.assertEqual(
processor.process_response_dict(response, stream=False, enable_thinking=None),
processor.process_response_dict(response, stream=False, enable_thinking=True),
"normal",
)
self.assertTrue(calls["normal"]["enable_thinking"])
@@ -0,0 +1,197 @@
import unittest
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
from fastdeploy.reasoning.qwen3_reasoning_parsers import Qwen3ReasoningParser
class MockTokenizer:
    """Minimal tokenizer exposing just the vocab the parser needs."""

    def __init__(self):
        # Both thinking delimiters are present, so parser construction succeeds.
        self.vocab = {
            "<think>": 100,
            "</think>": 101,
        }

    def get_vocab(self):
        """Return the token -> id mapping."""
        return self.vocab
class MissingTokenTokenizer:
    """Tokenizer whose vocab lacks ``<think>``, used to trigger the init error."""

    def __init__(self):
        # Deliberately omit "<think>" so the parser reports a missing token.
        self.vocab = {"</think>": 100}

    def get_vocab(self):
        """Return the (incomplete) token -> id mapping."""
        return self.vocab
class TestQwen3ReasoningParser(unittest.TestCase):
def setUp(self):
    """Build a parser over the minimal mock tokenizer plus a sample request."""
    # Parser under test, backed by a vocab containing <think>/</think>.
    self.parser = Qwen3ReasoningParser(MockTokenizer())
    # Representative request object passed to the extraction APIs.
    self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
    self.tokenizer = MockTokenizer()
def test_missing_token(self):
    """Constructing the parser without <think> in the vocab must raise RuntimeError."""
    with self.assertRaises(RuntimeError) as context:
        Qwen3ReasoningParser(MissingTokenTokenizer())
    # The error message should name the missing-token failure mode.
    exception_message = str(context.exception)
    expected_message_part = "Qwen3 reasoning parser could not find the following token ids"
    self.assertIn(expected_message_part, exception_message)
def test_get_model_status(self):
    """Status follows the last special token id; default is think_start."""
    # 100 == <think> in MockTokenizer's vocab -> still thinking.
    status = self.parser.get_model_status([1, 2, 100])
    self.assertEqual(status, "think_start")
    # 101 == </think> -> thinking finished.
    status = self.parser.get_model_status([1, 2, 101])
    self.assertEqual(status, "think_end")
    # No special token at all -> defaults to think_start.
    status = self.parser.get_model_status([1])
    self.assertEqual(status, "think_start")
def test_streaming_thinking_content(self):
msg = self.parser.extract_reasoning_content_streaming(
previous_text="",
current_text="a",
delta_text="a",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[200],
model_status="think_start",
)
self.assertEqual(msg.reasoning_content, "a")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="",
current_text="a</think>b",
delta_text="a</think>b",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[99, 101, 102],
model_status="think_start",
)
self.assertEqual(msg.reasoning_content, "a")
self.assertEqual(msg.content, "b")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="a</think>",
current_text="a</think>b",
delta_text="b",
previous_token_ids=[1, 101],
current_token_ids=[],
delta_token_ids=[102],
model_status="think_start",
)
self.assertEqual(msg.content, "b")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="",
current_text="a",
delta_text="a",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
model_status="think_start",
)
self.assertEqual(msg.reasoning_content, "a")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="",
current_text="a",
delta_text="a",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[200],
model_status="think_end",
)
self.assertEqual(msg.content, "a")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="hello",
current_text="hello</think>hi",
delta_text="</think>hi",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[101, 200],
model_status="think_start",
)
self.assertEqual(msg.content, "hi")
self.assertEqual(msg.reasoning_content, "")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="hello",
current_text="hello</think>hi",
delta_text="hi",
previous_token_ids=[100],
current_token_ids=[],
delta_token_ids=[],
model_status="think_start",
)
self.assertEqual(msg.content, None)
self.assertEqual(msg.reasoning_content, "hi")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="hello",
current_text="hello</think>hi",
delta_text="<think>hi",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[100, 200],
model_status="think_start",
)
self.assertEqual(msg.content, "")
self.assertEqual(msg.reasoning_content, "hi")
def test_none_streaming_thinking_content(self):
reasoning_content, content = self.parser.extract_reasoning_content(
model_output="a",
request={},
model_status="think_start",
)
self.assertEqual(reasoning_content, None)
self.assertEqual(content, "a")
reasoning_content, content = self.parser.extract_reasoning_content(
model_output="a</think>b",
request={},
model_status="think_start",
)
self.assertEqual(reasoning_content, "a")
self.assertEqual(content, "b")
reasoning_content, content = self.parser.extract_reasoning_content(
model_output="a",
request={},
model_status="think_end",
)
self.assertEqual(reasoning_content, None)
self.assertEqual(content, "a")
reasoning_content, content = self.parser.extract_reasoning_content(
model_output="<think>a",
request={},
model_status="think_start",
)
self.assertEqual(reasoning_content, None)
self.assertEqual(content, "<think>a")
reasoning_content, content = self.parser.extract_reasoning_content(
model_output="<think>a</think>b",
request={},
model_status="think_start",
)
self.assertEqual(reasoning_content, "a")
self.assertEqual(content, "b")
reasoning_content, content = self.parser.extract_reasoning_content(
model_output="</think>b",
request={},
model_status="think_start",
)
self.assertEqual(reasoning_content, "")
self.assertEqual(content, "b")
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()
+233 -17
View File
@@ -31,10 +31,25 @@ class DummyTokenizer:
def __init__(self):
self.vocab = {
"</think>": 100,
"<tool_call>": 101,
"</tool_call>": 102,
"<response>": 103,
"</response>": 104,
"<think>": 101,
"<tool_call>": 102,
"</tool_call>": 103,
"<response>": 104,
"</response>": 105,
}
def get_vocab(self):
"""Return vocab dict for testing."""
return self.vocab
class MissingTokenTokenizer:
def __init__(self):
self.vocab = {
"</think>": 100,
"<think>": 101,
"<tool_call>": 102,
"</tool_call>": 103,
}
def get_vocab(self):
@@ -132,6 +147,17 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
self.tokenizer = DummyTokenizer()
def test_missing_token(self):
with self.assertRaises(RuntimeError) as context:
ErnieX1ReasoningParser(MissingTokenTokenizer())
exception_message = str(context.exception)
expected_message_part = "ernie x1 reasoning parser could not find the following token ids"
self.assertIn(expected_message_part, exception_message)
def test_get_model_status(self):
model_status = self.parser.get_model_status([88, 99, 104])
self.assertEqual(model_status, "response_start")
# ---- Streaming parsing ----
def test_streaming_thinking_content(self):
msg = self.parser.extract_reasoning_content_streaming(
@@ -141,6 +167,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[200],
model_status="think_start",
)
self.assertEqual(msg.reasoning_content, "a")
@@ -152,6 +179,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[201],
model_status="think_start",
)
self.assertEqual(msg.reasoning_content, "\n")
@@ -163,6 +191,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[self.parser.think_end_token_id],
model_status="think_start",
)
self.assertIsNone(msg)
@@ -174,6 +203,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[202],
model_status="think_start",
)
self.assertEqual(msg.content, "h")
@@ -185,6 +215,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[203],
model_status="think_start",
)
self.assertEqual(msg.content, "\n")
@@ -197,6 +228,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[self.parser.vocab["<response>"]],
model_status="think_start",
)
)
@@ -207,6 +239,7 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[204],
model_status="think_start",
)
self.assertIsInstance(msg, DeltaMessage)
self.assertEqual(msg.content, "\n")
@@ -219,9 +252,82 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[self.parser.vocab["</response>"]],
model_status="think_start",
)
)
def test_extract_reasoning_content_streaming(self):
msg = self.parser.extract_reasoning_content_streaming(
previous_text="hello</think>",
current_text="hello</think><response>",
delta_text="</think><response>",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[100, 200],
model_status="think_start",
)
self.assertEqual(msg.content, "")
self.assertEqual(msg.reasoning_content, "")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="hello</think>",
current_text="hello</think><response>hi",
delta_text="</think><response>hi",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[100, 200],
model_status="think_start",
)
self.assertEqual(msg.content, "hi")
self.assertEqual(msg.reasoning_content, "")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="",
current_text="hello</think><response>hi",
delta_text="hello</think><response>hi",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[100, 200],
model_status="think_start",
)
self.assertEqual(msg.content, "hi")
self.assertEqual(msg.reasoning_content, "hello")
msg = self.parser.extract_reasoning_content_streaming(
previous_text="hello</think><response>",
current_text="hello</think><response>hi",
delta_text="hi",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[100, 200],
model_status="think_end",
)
self.assertEqual(msg.content, "hi")
self.assertEqual(msg.reasoning_content, None)
msg = self.parser.extract_reasoning_content_streaming(
previous_text="hello</think><response>",
current_text="hello</think><response>hi",
delta_text="hi",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[100, 200],
model_status="response_start",
)
self.assertEqual(msg.content, "hi")
self.assertEqual(msg.reasoning_content, None)
msg = self.parser.extract_reasoning_content_streaming(
previous_text="hello</think><response>hi</response>",
current_text="hello</think><response>hi</response>end",
delta_text="end",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[100, 200],
model_status="response_start",
)
self.assertEqual(msg, None)
def test_streaming_tool_call(self):
msg = self.parser.extract_reasoning_content_streaming(
previous_text="</think>",
@@ -230,40 +336,48 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[self.parser.vocab["<tool_call>"]],
model_status="think_start",
)
self.assertIsNone(msg)
# ---- Batch parsing ----
def test_batch_reasoning_and_response(self):
text = "abc\n</think>\n<response>hello\nworld</response>"
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
self.assertEqual(reasoning, "abc\n")
self.assertEqual(response, "hello\nworld")
def test_batch_reasoning_and_tool_call(self):
text = "abc</think><tool_call>call_here"
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
self.assertEqual(reasoning, "abc")
self.assertEqual(response, "")
def test_batch_no_thinking_tag(self):
text = "no_thinking_here"
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
self.assertEqual(reasoning, "no_thinking_here")
self.assertEqual(response, "")
def test_batch_response_without_end_tag(self):
text = "abc</think><response>partial response"
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
self.assertEqual(reasoning, "abc")
self.assertEqual(response, "partial response")
def test_batch_preserve_all_newlines(self):
text = "abc\n</think>\n<response>line1\nline2\n</response>"
reasoning, response = self.parser.extract_reasoning_content(text, self.request)
reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
self.assertEqual(reasoning, "abc\n")
self.assertEqual(response, "line1\nline2\n")
def test_extract_reasoning_content(self):
reasoning_content, response_content = self.parser.extract_reasoning_content(
model_output="hello", request=self.request, model_status="response_start"
)
self.assertEqual(reasoning_content, "")
self.assertEqual(response_content, "hello")
class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
def setUp(self):
@@ -272,6 +386,9 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
self.test_request = ChatCompletionRequest(
model="ernie-test", messages=[{"role": "user", "content": "test prompt"}]
)
self.parser.token_status_mapping = {
100: "think_start",
}
def test_streaming_non_reasoning(self):
result = self.parser.extract_reasoning_content_streaming(
@@ -281,6 +398,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
previous_token_ids=[],
current_token_ids=[200],
delta_token_ids=[200],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.reasoning_content, "a")
@@ -294,6 +412,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
previous_token_ids=[200, 201],
current_token_ids=[200, 201, 100],
delta_token_ids=[100],
model_status="think_start",
)
self.assertIsNone(result)
@@ -305,6 +424,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
previous_token_ids=[200, 201],
current_token_ids=[200, 201, 100, 300, 400],
delta_token_ids=[100, 300, 400],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertIsNone(result.reasoning_content)
@@ -318,6 +438,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 100],
delta_token_ids=[100],
model_status="think_start",
)
self.assertIsNone(result)
@@ -329,9 +450,10 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 100, 200, 101],
delta_token_ids=[100, 200, 101],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.reasoning_content, "")
self.assertEqual(result.reasoning_content, None)
def test_streaming_with_reasoning_and_illegal_tool(self):
result = self.parser.extract_reasoning_content_streaming(
@@ -341,6 +463,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 100, 200, 101],
delta_token_ids=[109, 200, 101],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.content, "\n\nhello<tool_call>")
@@ -353,6 +476,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 100, 200, 110],
delta_token_ids=[100, 200, 110],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.reasoning_content, "hello")
@@ -366,6 +490,7 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
previous_token_ids=[100],
current_token_ids=[100, 110, 111],
delta_token_ids=[110, 111],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertIsNone(result.reasoning_content)
@@ -379,52 +504,140 @@ class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
previous_token_ids=[101],
current_token_ids=[101, 110],
delta_token_ids=[110],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.reasoning_content, "hello")
def test_think_end_status_streaming(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="<tool_call>",
current_text="<tool_call>hello",
delta_text="hello",
previous_token_ids=[101],
current_token_ids=[101, 110],
delta_token_ids=[110],
model_status="think_end",
)
self.assertIs(result, None)
result = self.parser.extract_reasoning_content_streaming(
previous_text="hello, ",
current_text="hello, hi",
delta_text="hi",
previous_token_ids=[101],
current_token_ids=[101, 110],
delta_token_ids=[110],
model_status="think_end",
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.content, "hi")
def test_other_status_streaming(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="hello, ",
current_text="hello, hi",
delta_text="hi",
previous_token_ids=[101],
current_token_ids=[101, 110],
delta_token_ids=[110],
model_status="tool_call_start",
)
self.assertIs(result, None)
def test_batch_no_think_end(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="direct response", request=self.test_request
model_output="direct response", request=self.test_request, model_status="think_start"
)
self.assertEqual(reasoning, "direct response")
self.assertEqual(content, "")
def test_batch_no_think_end_with_tool(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="direct response<tool_call>abc", request=self.test_request
model_output="direct response<tool_call>abc", request=self.test_request, model_status="think_start"
)
self.assertEqual(reasoning, "direct response<tool_call>abc")
self.assertEqual(content, "")
def test_batch_think_end_normal_content(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="reasoning</think>\nresponse", request=self.test_request
model_output="reasoning</think>\nresponse", request=self.test_request, model_status="think_start"
)
self.assertEqual(reasoning, "reasoning")
self.assertEqual(content, "\nresponse")
def test_batch_think_end_with_tool(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="reasoning</think>\n<tool_call>tool params</tool_call>", request=self.test_request
model_output="reasoning</think>\n<tool_call>tool params</tool_call>",
request=self.test_request,
model_status="think_start",
)
self.assertEqual(reasoning, "reasoning")
self.assertEqual(content, "")
def test_batch_think_end_with_illegal_tool(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="reasoning</think>\nABC\n<tool_call>tool params</tool_call>", request=self.test_request
model_output="reasoning</think>\nABC\n<tool_call>tool params</tool_call>",
request=self.test_request,
model_status="think_start",
)
self.assertEqual(reasoning, "reasoning")
self.assertEqual(content, "\nABC\n<tool_call>tool params</tool_call>")
def test_batch_think_end_content_with_newline(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="reasoning</think>\n\n actual response", request=self.test_request
model_output="reasoning</think>\n\n actual response",
request=self.test_request,
model_status="think_start",
)
self.assertEqual(reasoning, "reasoning")
self.assertEqual(content, "\n\n actual response")
def test_think_end_status_non_streaming(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="response", request=self.test_request, model_status="think_end"
)
self.assertEqual(reasoning, "")
self.assertEqual(content, "response")
reasoning, content = self.parser.extract_reasoning_content(
model_output="<tool_call>response", request=self.test_request, model_status="think_end"
)
self.assertEqual(reasoning, "")
self.assertEqual(content, "")
reasoning, content = self.parser.extract_reasoning_content(
model_output="\n 1<tool_call>response", request=self.test_request, model_status="think_end"
)
self.assertEqual(reasoning, "")
self.assertEqual(content, "\n 1<tool_call>response")
def test_other_status_non_streaming(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="response", request=self.test_request, model_status="tool_call_start"
)
self.assertEqual(reasoning, "")
self.assertEqual(content, "")
reasoning, content = self.parser.extract_reasoning_content(
model_output="response", request=self.test_request, model_status="tool_call_end"
)
self.assertEqual(reasoning, "")
self.assertEqual(content, "")
def test_find_last_special_token(self):
result = self.parser.find_last_special_token([100, 110, 120, 130])
self.assertEqual(result, 100)
result = self.parser.find_last_special_token([0])
self.assertEqual(result, -1)
def test_get_model_status(self):
result = self.parser.get_model_status([100, 110, 120, 130])
self.assertEqual(result, "think_start")
result = self.parser.get_model_status([0])
self.assertEqual(result, "think_start")
class TestErnieVLReasoningParser(unittest.TestCase):
def setUp(self):
@@ -442,6 +655,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 100, 110, 120, 130],
delta_token_ids=[100, 110, 120, 130],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.reasoning_content, "")
@@ -455,6 +669,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
previous_token_ids=[200, 201, 202, 100],
current_token_ids=[200, 201, 202, 100, 110, 120, 130],
delta_token_ids=[110, 120, 130],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertIsNone(result.reasoning_content)
@@ -468,6 +683,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 110, 120, 130],
delta_token_ids=[110, 120, 130],
model_status="think_start",
)
self.assertIsInstance(result, DeltaMessage)
self.assertIsNone(result.content)
@@ -475,7 +691,7 @@ class TestErnieVLReasoningParser(unittest.TestCase):
def test_extract_reasoning_content(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="reasoning</think>\nactual response", request=self.test_request
model_output="reasoning</think>\nactual response", request=self.test_request, model_status="think_start"
)
self.assertEqual(reasoning, "reasoning")
self.assertEqual(content, "\nactual response")
+135
View File
@@ -0,0 +1,135 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
from fastdeploy.reasoning.ernie_vl_reasoning_parsers import ErnieVLReasoningParser
class MockTokenizer:
    """Stub tokenizer supplying the two think markers the parser needs."""

    def __init__(self):
        # Fixed ids: 100 opens a thinking span, 101 closes it.
        self.vocab = {"<think>": 100, "</think>": 101}

    def get_vocab(self):
        """Return the stub vocabulary mapping."""
        return self.vocab
class TestErnieVLReasoningParser(unittest.TestCase):
    """Tests for ErnieVLReasoningParser: model-status detection plus
    streaming and non-streaming reasoning extraction."""

    def setUp(self):
        """Build a parser over the stub tokenizer plus a minimal chat request."""
        self.parser = ErnieVLReasoningParser(MockTokenizer())
        self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
        self.tokenizer = MockTokenizer()

    def test_get_model_status(self):
        """Model status follows the last think marker seen in the prompt ids."""
        # Trailing <think> (100) -> still thinking.
        status = self.parser.get_model_status([1, 2, 100])
        self.assertEqual(status, "think_start")
        # Trailing </think> (101) -> thinking finished.
        status = self.parser.get_model_status([1, 2, 101])
        self.assertEqual(status, "think_end")
        # No marker at all defaults to "think_start".
        status = self.parser.get_model_status([1])
        self.assertEqual(status, "think_start")

    def test_streaming_thinking_content(self):
        """Streaming extraction routes deltas to reasoning vs. content fields."""
        # Plain delta while thinking -> reasoning_content.
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="a",
            delta_text="a",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[200],
            model_status="think_start",
        )
        self.assertEqual(msg.reasoning_content, "a")
        # Delta straddling </think> (id 101) splits into reasoning + content.
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="a</think>b",
            delta_text="a</think>b",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[100, 101, 102],
            model_status="think_start",
        )
        self.assertEqual(msg.reasoning_content, "a")
        self.assertEqual(msg.content, "b")
        # </think> already seen in previous ids -> delta is pure content.
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="a</think>",
            current_text="a</think>b",
            delta_text="b",
            previous_token_ids=[1, 101],
            current_token_ids=[],
            delta_token_ids=[102],
            model_status="think_start",
        )
        self.assertEqual(msg.content, "b")
        # Empty delta ids still yields reasoning text while thinking.
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="a",
            delta_text="a",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            model_status="think_start",
        )
        self.assertEqual(msg.reasoning_content, "a")
        # After think_end, everything is plain content.
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="a",
            delta_text="a",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[200],
            model_status="think_end",
        )
        self.assertEqual(msg.content, "a")

    def test_none_streaming_thinking_content(self):
        """Batch extraction splits full output into (reasoning, content)."""
        # No </think> marker -> empty reasoning, all content.
        reasoning_content, content = self.parser.extract_reasoning_content(
            model_output="a",
            request={},
            model_status="think_start",
        )
        self.assertEqual(reasoning_content, "")
        self.assertEqual(content, "a")
        # Text around </think> splits cleanly.
        reasoning_content, content = self.parser.extract_reasoning_content(
            model_output="a</think>b",
            request={},
            model_status="think_start",
        )
        self.assertEqual(reasoning_content, "a")
        self.assertEqual(content, "b")
        # Already past thinking -> output is content only.
        reasoning_content, content = self.parser.extract_reasoning_content(
            model_output="a",
            request={},
            model_status="think_end",
        )
        self.assertEqual(reasoning_content, "")
        self.assertEqual(content, "a")
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()