mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization]Unified data processing for online and offline (#6891)
* remove process_request * fix chat * fix unit test * remove process response * fix unit test * fix offline decode * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix sampling_params --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
+21
-14
@@ -266,19 +266,24 @@ class LLMEngine:
|
|||||||
# TODO 输入输出长度确认
|
# TODO 输入输出长度确认
|
||||||
|
|
||||||
if sampling_params is not None:
|
if sampling_params is not None:
|
||||||
task.update(asdict(sampling_params))
|
if sampling_params.temperature is not None and abs(sampling_params.temperature) < 1e-06:
|
||||||
|
sampling_params.temperature = 1e-06
|
||||||
|
task.update({k: v for k, v in asdict(sampling_params).items() if v is not None})
|
||||||
|
|
||||||
|
# Prepare chat_template_kwargs before calling process_request_dict
|
||||||
|
chat_template_kwargs = kwargs.get("chat_template_kwargs") or {}
|
||||||
|
chat_template_kwargs["chat_template"] = kwargs.get("chat_template")
|
||||||
|
task["chat_template_kwargs"] = chat_template_kwargs
|
||||||
|
|
||||||
|
# Use dict to call process_request_dict
|
||||||
|
task = self.engine.data_processor.process_request_dict(task, self.cfg.model_config.max_model_len)
|
||||||
|
|
||||||
|
# Create Request struct after processing
|
||||||
request = Request.from_dict(task)
|
request = Request.from_dict(task)
|
||||||
request.metrics.scheduler_recv_req_time = time.time()
|
request.metrics.scheduler_recv_req_time = time.time()
|
||||||
llm_logger.info(f"Receive request {request}")
|
llm_logger.info(f"Receive request {request}")
|
||||||
if sampling_params is not None:
|
|
||||||
if sampling_params.temperature is not None and abs(sampling_params.temperature) < 1e-06:
|
|
||||||
sampling_params.temperature = 1e-06
|
|
||||||
request.sampling_params = sampling_params
|
|
||||||
request.metrics.preprocess_start_time = time.time()
|
request.metrics.preprocess_start_time = time.time()
|
||||||
chat_template_kwargs = kwargs.get("chat_template_kwargs") or {}
|
|
||||||
chat_template_kwargs["chat_template"] = kwargs.get("chat_template")
|
|
||||||
kwargs["chat_template_kwargs"] = chat_template_kwargs
|
|
||||||
request = self.engine.data_processor.process_request(request, self.cfg.model_config.max_model_len, **kwargs)
|
|
||||||
request.prompt_token_ids_len = len(request.prompt_token_ids)
|
request.prompt_token_ids_len = len(request.prompt_token_ids)
|
||||||
request.need_prefill_tokens = request.prompt_token_ids_len
|
request.need_prefill_tokens = request.prompt_token_ids_len
|
||||||
input_ids_len = request.prompt_token_ids_len
|
input_ids_len = request.prompt_token_ids_len
|
||||||
@@ -716,16 +721,18 @@ class LLMEngine:
|
|||||||
for result in self._get_generated_tokens(req_id):
|
for result in self._get_generated_tokens(req_id):
|
||||||
is_end = result.finished
|
is_end = result.finished
|
||||||
if stream and not is_end:
|
if stream and not is_end:
|
||||||
processed = self.engine.data_processor.process_response(result)
|
output = self.engine.data_processor.process_response_dict(
|
||||||
if processed is None:
|
result.to_dict(), stream=False, include_stop_str_in_output=False
|
||||||
|
)
|
||||||
|
if output is None:
|
||||||
continue
|
continue
|
||||||
output = processed.to_dict()
|
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
# Exit loop if termination condition is met
|
# Exit loop if termination condition is met
|
||||||
if is_end:
|
if is_end:
|
||||||
processed = self.engine.data_processor.process_response(result)
|
output = self.engine.data_processor.process_response_dict(
|
||||||
output = processed.to_dict()
|
result.to_dict(), stream=False, include_stop_str_in_output=False, direct_decode=not stream
|
||||||
|
)
|
||||||
llm_logger.debug(f"Generate result: {output}")
|
llm_logger.debug(f"Generate result: {output}")
|
||||||
if not stream:
|
if not stream:
|
||||||
yield output
|
yield output
|
||||||
|
|||||||
@@ -728,7 +728,6 @@ class CompletionOutput:
|
|||||||
delta_message: Optional[DeltaMessage] = None
|
delta_message: Optional[DeltaMessage] = None
|
||||||
multipart: Optional[list[Any]] = None
|
multipart: Optional[list[Any]] = None
|
||||||
num_image_tokens: Optional[int] = None
|
num_image_tokens: Optional[int] = None
|
||||||
enable_parser: bool = False
|
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from fastdeploy.engine.args_utils import EngineArgs
|
from fastdeploy.engine.args_utils import EngineArgs
|
||||||
from fastdeploy.engine.engine import LLMEngine
|
from fastdeploy.engine.engine import LLMEngine
|
||||||
|
from fastdeploy.engine.request import RequestOutput
|
||||||
from fastdeploy.engine.sampling_params import SamplingParams
|
from fastdeploy.engine.sampling_params import SamplingParams
|
||||||
from fastdeploy.entrypoints.chat_utils import load_chat_template
|
from fastdeploy.entrypoints.chat_utils import load_chat_template
|
||||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionToolsParam
|
from fastdeploy.entrypoints.openai.protocol import ChatCompletionToolsParam
|
||||||
@@ -565,7 +566,11 @@ class LLM:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
result = self.req_output.pop(req_id)
|
result = self.req_output.pop(req_id)
|
||||||
result = self.llm_engine.data_processor.process_response(result)
|
result_dict = result.to_dict()
|
||||||
|
result_dict = self.llm_engine.data_processor.process_response_dict(
|
||||||
|
result_dict, stream=False, include_stop_str_in_output=False, direct_decode=True
|
||||||
|
)
|
||||||
|
result = RequestOutput.from_dict(result_dict)
|
||||||
|
|
||||||
# filter logprobs
|
# filter logprobs
|
||||||
if result.outputs.top_logprobs is not None and topk_logprobs is not None:
|
if result.outputs.top_logprobs is not None and topk_logprobs is not None:
|
||||||
|
|||||||
@@ -81,90 +81,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
if reasoning_parser_obj:
|
if reasoning_parser_obj:
|
||||||
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
|
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
|
||||||
|
|
||||||
def process_request(self, request, max_model_len=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Preprocess the request
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request (Dict): may contain text and messages fields
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: Whether preprocessing is successful
|
|
||||||
str: error message
|
|
||||||
"""
|
|
||||||
data_processor_logger.info(f"Start processing request: {request}")
|
|
||||||
request = self._apply_default_parameters(request)
|
|
||||||
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
|
|
||||||
request.eos_token_ids = self.eos_token_ids
|
|
||||||
|
|
||||||
# processing stop_sequences and stop_token_ids
|
|
||||||
process_stop_token_ids(request, self.update_stop_seq)
|
|
||||||
|
|
||||||
# processing bad_words
|
|
||||||
bad_words = request.get("bad_words")
|
|
||||||
bad_words_token_ids = request.get("bad_words_token_ids")
|
|
||||||
if bad_words:
|
|
||||||
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
|
|
||||||
request["bad_words_token_ids"] = bad_words_token_ids
|
|
||||||
|
|
||||||
# processing prompt_token_ids
|
|
||||||
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
|
|
||||||
if request.prompt is not None:
|
|
||||||
prompt = request.prompt
|
|
||||||
tokens = self.tokenizer.tokenize(prompt)
|
|
||||||
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
|
||||||
request.prompt_token_ids = token_ids
|
|
||||||
data_processor_logger.debug(
|
|
||||||
f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
|
|
||||||
)
|
|
||||||
elif request.messages is not None:
|
|
||||||
task = request.to_dict()
|
|
||||||
chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
|
|
||||||
if chat_template_kwargs:
|
|
||||||
if isinstance(chat_template_kwargs, dict):
|
|
||||||
for k, v in chat_template_kwargs.items():
|
|
||||||
if k not in task or task[k] is None:
|
|
||||||
task[k] = v
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
|
|
||||||
request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
|
|
||||||
|
|
||||||
if len(request.prompt_token_ids) == 0:
|
|
||||||
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
|
|
||||||
|
|
||||||
# truncate prompts that exceed the length limit
|
|
||||||
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
|
|
||||||
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
|
|
||||||
max_tokens = max_model_len - len(request.prompt_token_ids)
|
|
||||||
if request.get("max_tokens") is None:
|
|
||||||
request.set("max_tokens", max(1, max_tokens))
|
|
||||||
else:
|
|
||||||
request.set("max_tokens", min(max_tokens, request.get("max_tokens")))
|
|
||||||
if request.get("temperature") < _SAMPLING_EPS:
|
|
||||||
# zero temperature means greedy decoding: set top_k=1 to force argmax
|
|
||||||
request.set("temperature", 1)
|
|
||||||
request.set("top_k", 1)
|
|
||||||
if request.get("top_p") < _SAMPLING_EPS:
|
|
||||||
request.set("top_p", _SAMPLING_EPS)
|
|
||||||
request.set("top_k", 1)
|
|
||||||
if self.reasoning_parser:
|
|
||||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
|
||||||
parts = request.request_id.split("_")
|
|
||||||
if len(parts) > 1:
|
|
||||||
real_req_id = parts[0]
|
|
||||||
index = int(parts[1])
|
|
||||||
n = request.get("n", 1)
|
|
||||||
for idx in range(index * n, (index + 1) * n):
|
|
||||||
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
|
||||||
else:
|
|
||||||
self.model_status_dict[request.request_id] = model_status
|
|
||||||
request.enable_thinking = model_status == "think_start"
|
|
||||||
|
|
||||||
data_processor_logger.info(f"Processed request: {request}")
|
|
||||||
return request
|
|
||||||
|
|
||||||
def process_request_dict(self, request, max_model_len=None):
|
def process_request_dict(self, request, max_model_len=None):
|
||||||
"""
|
"""
|
||||||
Preprocess the request
|
Preprocess the request
|
||||||
@@ -257,46 +173,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
data_processor_logger.info(f"Processed request dict: {request}")
|
data_processor_logger.info(f"Processed request dict: {request}")
|
||||||
return request
|
return request
|
||||||
|
|
||||||
def process_response(self, response_dict, **kwargs):
|
|
||||||
"""
|
|
||||||
Preprocess the response
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_dict (Dict): response for engine, contain ids fields
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: response contain text fields
|
|
||||||
"""
|
|
||||||
req_id = response_dict.request_id
|
|
||||||
token_ids = response_dict.outputs.token_ids
|
|
||||||
|
|
||||||
response_dict.usage = {"completion_tokens": response_dict.outputs.index + 1}
|
|
||||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
|
||||||
token_ids = token_ids[:-1]
|
|
||||||
full_text = self.tokenizer.decode(token_ids)
|
|
||||||
if self.reasoning_parser:
|
|
||||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
|
||||||
full_text,
|
|
||||||
response_dict,
|
|
||||||
self.model_status_dict[req_id],
|
|
||||||
)
|
|
||||||
response_dict.outputs.text = text
|
|
||||||
response_dict.outputs.reasoning_content = reasoning_content
|
|
||||||
else:
|
|
||||||
response_dict.outputs.text = full_text
|
|
||||||
if self.tool_parser_obj:
|
|
||||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
|
||||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
|
||||||
if tool_call_info.tools_called:
|
|
||||||
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
|
||||||
response_dict.outputs.text = tool_call_info.content
|
|
||||||
if req_id in self.model_status_dict:
|
|
||||||
del self.model_status_dict[req_id]
|
|
||||||
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
|
|
||||||
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
|
|
||||||
return None
|
|
||||||
return response_dict
|
|
||||||
|
|
||||||
def process_response_dict(self, response_dict, stream, **kwargs):
|
def process_response_dict(self, response_dict, stream, **kwargs):
|
||||||
"""
|
"""
|
||||||
Preprocess the response
|
Preprocess the response
|
||||||
@@ -326,11 +202,15 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
is_end = response_dict["finished"]
|
is_end = response_dict["finished"]
|
||||||
req_id = response_dict["request_id"]
|
req_id = response_dict["request_id"]
|
||||||
request = kwargs.get("request", None)
|
request = kwargs.get("request", None)
|
||||||
|
direct_decode = kwargs.get("direct_decode", False)
|
||||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||||
token_ids = token_ids[:-1]
|
token_ids = token_ids[:-1]
|
||||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
if direct_decode:
|
||||||
response_dict["outputs"]["enable_parser"] = False
|
delta_text = self.tokenizer.decode(token_ids)
|
||||||
|
previous_texts = ""
|
||||||
|
else:
|
||||||
|
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||||
if is_end:
|
if is_end:
|
||||||
full_text = previous_texts + delta_text
|
full_text = previous_texts + delta_text
|
||||||
response_dict["outputs"]["text"] = full_text
|
response_dict["outputs"]["text"] = full_text
|
||||||
@@ -351,8 +231,9 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
|
response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
|
||||||
response_dict["outputs"]["text"] = tool_call_info.content
|
response_dict["outputs"]["text"] = tool_call_info.content
|
||||||
response_dict["outputs"]["completion_tokens"] = full_text
|
response_dict["outputs"]["completion_tokens"] = full_text
|
||||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
if req_id in self.decode_status:
|
||||||
del self.decode_status[req_id]
|
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||||
|
del self.decode_status[req_id]
|
||||||
if req_id in self.model_status_dict:
|
if req_id in self.model_status_dict:
|
||||||
del self.model_status_dict[req_id]
|
del self.model_status_dict[req_id]
|
||||||
return response_dict
|
return response_dict
|
||||||
@@ -371,7 +252,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
req_id = response_dict["request_id"]
|
req_id = response_dict["request_id"]
|
||||||
token_ids = response_dict["outputs"]["token_ids"]
|
token_ids = response_dict["outputs"]["token_ids"]
|
||||||
request = kwargs.get("request", None)
|
request = kwargs.get("request", None)
|
||||||
response_dict["outputs"]["enable_parser"] = False
|
|
||||||
|
|
||||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||||
@@ -402,7 +282,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
if not is_end:
|
if not is_end:
|
||||||
response_dict["outputs"]["skipped"] = True
|
response_dict["outputs"]["skipped"] = True
|
||||||
if self.tool_parser_obj:
|
if self.tool_parser_obj:
|
||||||
response_dict["outputs"]["enable_parser"] = True
|
|
||||||
if req_id not in self.tool_parser_dict:
|
if req_id not in self.tool_parser_dict:
|
||||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||||
tool_parser = self.tool_parser_dict[req_id]
|
tool_parser = self.tool_parser_dict[req_id]
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from collections.abc import Mapping
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from paddleformers.generation import GenerationConfig
|
from paddleformers.generation import GenerationConfig
|
||||||
|
|
||||||
from fastdeploy.engine.request import Request
|
|
||||||
from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
|
from fastdeploy.input.ernie4_5_processor import Ernie4_5Processor
|
||||||
from fastdeploy.input.utils import IDS_TYPE_FLAG, process_stop_token_ids
|
from fastdeploy.input.utils import IDS_TYPE_FLAG, process_stop_token_ids
|
||||||
from fastdeploy.utils import data_processor_logger
|
from fastdeploy.utils import data_processor_logger
|
||||||
@@ -120,16 +119,6 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
|||||||
set_value(request, "presence_penalty", 0.0)
|
set_value(request, "presence_penalty", 0.0)
|
||||||
return request
|
return request
|
||||||
|
|
||||||
def process_request(self, request, max_model_len=None, **kwargs):
|
|
||||||
"""process the input data"""
|
|
||||||
task = request.to_dict()
|
|
||||||
task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
|
|
||||||
self.process_request_dict(task, max_model_len)
|
|
||||||
request = Request.from_dict(task)
|
|
||||||
request = self._apply_default_parameters(request)
|
|
||||||
|
|
||||||
return request
|
|
||||||
|
|
||||||
def _parse_processor_kwargs(self, kwargs):
|
def _parse_processor_kwargs(self, kwargs):
|
||||||
"""解析多模态处理器参数配置"""
|
"""解析多模态处理器参数配置"""
|
||||||
if not kwargs:
|
if not kwargs:
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from fastdeploy.engine.request import Request
|
|
||||||
from fastdeploy.input.text_processor import DataProcessor as TextProcessor
|
from fastdeploy.input.text_processor import DataProcessor as TextProcessor
|
||||||
from fastdeploy.utils import data_processor_logger
|
from fastdeploy.utils import data_processor_logger
|
||||||
|
|
||||||
@@ -76,25 +75,6 @@ class PaddleOCRVLProcessor(TextProcessor):
|
|||||||
self.image_patch_id = self.processor.image_patch_id
|
self.image_patch_id = self.processor.image_patch_id
|
||||||
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
|
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
|
||||||
|
|
||||||
def process_request(self, request, max_model_len=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Process incoming request and generate model inputs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Input request object
|
|
||||||
max_model_len (int, optional): Maximum context length
|
|
||||||
**kwargs: Additional processing parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Request: Processed request with model inputs
|
|
||||||
"""
|
|
||||||
task = request.to_dict()
|
|
||||||
task["enable_thinking"] = kwargs.get("enable_thinking", False)
|
|
||||||
self.process_request_dict(task, max_model_len)
|
|
||||||
request = Request.from_dict(task)
|
|
||||||
request = self._apply_default_parameters(request)
|
|
||||||
return request
|
|
||||||
|
|
||||||
def _parse_processor_kwargs(self, kwargs):
|
def _parse_processor_kwargs(self, kwargs):
|
||||||
"""
|
"""
|
||||||
Parse and validate multimodal processor arguments.
|
Parse and validate multimodal processor arguments.
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from fastdeploy.engine.request import Request
|
|
||||||
from fastdeploy.input.text_processor import DataProcessor as TextProcessor
|
from fastdeploy.input.text_processor import DataProcessor as TextProcessor
|
||||||
from fastdeploy.utils import data_processor_logger
|
from fastdeploy.utils import data_processor_logger
|
||||||
|
|
||||||
@@ -74,25 +73,6 @@ class Qwen3VLProcessor(TextProcessor):
|
|||||||
self.image_patch_id = self.processor.image_token_id
|
self.image_patch_id = self.processor.image_token_id
|
||||||
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
|
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
|
||||||
|
|
||||||
def process_request(self, request, max_model_len=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Process incoming request and generate model inputs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Input request object
|
|
||||||
max_model_len (int, optional): Maximum context length
|
|
||||||
**kwargs: Additional processing parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Request: Processed request with model inputs
|
|
||||||
"""
|
|
||||||
task = request.to_dict()
|
|
||||||
task["enable_thinking"] = kwargs.get("enable_thinking", False)
|
|
||||||
self.process_request_dict(task, max_model_len)
|
|
||||||
request = Request.from_dict(task)
|
|
||||||
request = self._apply_default_parameters(request)
|
|
||||||
return request
|
|
||||||
|
|
||||||
def _parse_processor_kwargs(self, kwargs):
|
def _parse_processor_kwargs(self, kwargs):
|
||||||
"""
|
"""
|
||||||
Parse and validate multimodal processor arguments.
|
Parse and validate multimodal processor arguments.
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from fastdeploy.engine.request import Request
|
|
||||||
from fastdeploy.input.text_processor import DataProcessor as TextProcessor
|
from fastdeploy.input.text_processor import DataProcessor as TextProcessor
|
||||||
from fastdeploy.input.utils import process_stop_token_ids
|
from fastdeploy.input.utils import process_stop_token_ids
|
||||||
from fastdeploy.utils import data_processor_logger
|
from fastdeploy.utils import data_processor_logger
|
||||||
@@ -75,25 +74,6 @@ class QwenVLProcessor(TextProcessor):
|
|||||||
self.image_patch_id = self.processor.image_token_id
|
self.image_patch_id = self.processor.image_token_id
|
||||||
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
|
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
|
||||||
|
|
||||||
def process_request(self, request, max_model_len=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Process incoming request and generate model inputs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Input request object
|
|
||||||
max_model_len (int, optional): Maximum context length
|
|
||||||
**kwargs: Additional processing parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Request: Processed request with model inputs
|
|
||||||
"""
|
|
||||||
task = request.to_dict()
|
|
||||||
task["enable_thinking"] = kwargs.get("enable_thinking", False)
|
|
||||||
self.process_request_dict(task, max_model_len)
|
|
||||||
request = Request.from_dict(task)
|
|
||||||
request = self._apply_default_parameters(request)
|
|
||||||
return request
|
|
||||||
|
|
||||||
def _parse_processor_kwargs(self, kwargs):
|
def _parse_processor_kwargs(self, kwargs):
|
||||||
"""
|
"""
|
||||||
Parse and validate multimodal processor arguments.
|
Parse and validate multimodal processor arguments.
|
||||||
|
|||||||
@@ -76,34 +76,6 @@ class BaseDataProcessor(ABC):
|
|||||||
set_value(request, "presence_penalty", 0.0)
|
set_value(request, "presence_penalty", 0.0)
|
||||||
return request
|
return request
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def process_request(self, request, **kwargs):
|
|
||||||
"""
|
|
||||||
Preprocess the request
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request (Dict): may contain text and messages fields
|
|
||||||
**kwargs: others
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: Whether preprocessing is successful
|
|
||||||
str: error message
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def process_response(self, response_dict):
|
|
||||||
"""
|
|
||||||
Preprocess the response
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_dict (Dict): response for engine, contain ids fields
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: response contain text fields
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def text2ids(self, text, max_model_len=None):
|
def text2ids(self, text, max_model_len=None):
|
||||||
"""
|
"""
|
||||||
text to token ids
|
text to token ids
|
||||||
@@ -294,113 +266,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
logits_processors_args.pop("think_prompt_last_token_id", None)
|
logits_processors_args.pop("think_prompt_last_token_id", None)
|
||||||
return logits_processors_args
|
return logits_processors_args
|
||||||
|
|
||||||
def process_request(self, request, max_model_len=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Preprocess the request
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request (Dict): may contain text and messages fields
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: Whether preprocessing is successful
|
|
||||||
str: error message
|
|
||||||
"""
|
|
||||||
data_processor_logger.info(f"Start processing request: {request}")
|
|
||||||
request = self._apply_default_parameters(request)
|
|
||||||
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
|
|
||||||
request.eos_token_ids = self.eos_token_ids
|
|
||||||
|
|
||||||
# processing stop_sequences and stop_token_ids
|
|
||||||
process_stop_token_ids(request, self.update_stop_seq)
|
|
||||||
|
|
||||||
# processing bad_words
|
|
||||||
bad_words = request.get("bad_words")
|
|
||||||
bad_words_token_ids = request.get("bad_words_token_ids")
|
|
||||||
if bad_words:
|
|
||||||
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
|
|
||||||
request["bad_words_token_ids"] = bad_words_token_ids
|
|
||||||
|
|
||||||
logits_processors_args = request.get("logits_processors_args") or {}
|
|
||||||
think_stop_sentence = logits_processors_args.get("think_stop_sentence")
|
|
||||||
if isinstance(think_stop_sentence, str) and think_stop_sentence:
|
|
||||||
newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
|
|
||||||
sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
|
|
||||||
logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
|
|
||||||
logits_processors_args.pop("think_stop_sentence", None)
|
|
||||||
request["logits_processors_args"] = logits_processors_args
|
|
||||||
|
|
||||||
# processing prompt_token_ids
|
|
||||||
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
|
|
||||||
if request.prompt is not None:
|
|
||||||
prompt = request.prompt
|
|
||||||
add_special_tokens = request.get("add_special_tokens", False)
|
|
||||||
assert isinstance(prompt, str) or (
|
|
||||||
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
|
|
||||||
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
|
|
||||||
if isinstance(prompt, list): # if prompt is a token id list
|
|
||||||
request.prompt_token_ids = prompt
|
|
||||||
else:
|
|
||||||
request.prompt_token_ids = self.text2ids(
|
|
||||||
request.prompt, max_model_len, add_special_tokens=add_special_tokens
|
|
||||||
)
|
|
||||||
elif request.messages is not None:
|
|
||||||
if self.tokenizer.chat_template is None:
|
|
||||||
raise ValueError("This model does not support chat_template.")
|
|
||||||
task = request.to_dict()
|
|
||||||
chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
|
|
||||||
if chat_template_kwargs:
|
|
||||||
if isinstance(chat_template_kwargs, dict):
|
|
||||||
for k, v in chat_template_kwargs.items():
|
|
||||||
if k not in task or task[k] is None:
|
|
||||||
task[k] = v
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
|
|
||||||
task.setdefault("enable_thinking", True)
|
|
||||||
request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
|
|
||||||
|
|
||||||
if len(request.prompt_token_ids) == 0:
|
|
||||||
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
|
|
||||||
|
|
||||||
# truncate prompts that exceed the length limit
|
|
||||||
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
|
|
||||||
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
|
|
||||||
|
|
||||||
logits_processors_args = request.get("logits_processors_args") or {}
|
|
||||||
logits_processors_args = self._update_thinking_prompt_state(request.prompt_token_ids, logits_processors_args)
|
|
||||||
request["logits_processors_args"] = logits_processors_args
|
|
||||||
|
|
||||||
max_tokens = max_model_len - len(request.prompt_token_ids)
|
|
||||||
if request.get("max_tokens") is None:
|
|
||||||
request.set("max_tokens", max(1, max_tokens))
|
|
||||||
else:
|
|
||||||
request.set("max_tokens", min(max_tokens, request.get("max_tokens")))
|
|
||||||
if request.get("temperature") < _SAMPLING_EPS:
|
|
||||||
# zero temperature means greedy decoding: set top_k=1 to force argmax
|
|
||||||
request.set("temperature", 1)
|
|
||||||
request.set("top_k", 1)
|
|
||||||
if request.get("top_p") < _SAMPLING_EPS:
|
|
||||||
request.set("top_p", _SAMPLING_EPS)
|
|
||||||
request.set("top_k", 1)
|
|
||||||
if self.reasoning_parser:
|
|
||||||
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
|
|
||||||
parts = request.request_id.split("_")
|
|
||||||
if len(parts) > 1:
|
|
||||||
real_req_id = parts[0]
|
|
||||||
index = int(parts[1])
|
|
||||||
n = request.get("n", 1)
|
|
||||||
for idx in range(index * n, (index + 1) * n):
|
|
||||||
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
|
|
||||||
else:
|
|
||||||
self.model_status_dict[request.request_id] = model_status
|
|
||||||
request.enable_thinking = model_status == "think_start"
|
|
||||||
|
|
||||||
if request.get("response_max_tokens") is not None and request.enable_thinking is False:
|
|
||||||
request["max_tokens"] = min(request["response_max_tokens"], request["max_tokens"])
|
|
||||||
data_processor_logger.info(f"Processed request: {request}")
|
|
||||||
return request
|
|
||||||
|
|
||||||
def process_request_dict(self, request, max_model_len=None, **kwargs):
|
def process_request_dict(self, request, max_model_len=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Preprocess the request
|
Preprocess the request
|
||||||
@@ -439,10 +304,17 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
# processing prompt_token_ids
|
# processing prompt_token_ids
|
||||||
if not request.get("prompt_token_ids"):
|
if not request.get("prompt_token_ids"):
|
||||||
if request.get("prompt"):
|
if request.get("prompt"):
|
||||||
add_special_tokens = request.get("add_special_tokens", False)
|
prompt = request["prompt"]
|
||||||
request["prompt_token_ids"] = self.text2ids(
|
assert isinstance(prompt, str) or (
|
||||||
request["prompt"], max_model_len, add_special_tokens=add_special_tokens
|
isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
|
||||||
).tolist()
|
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
|
||||||
|
if isinstance(prompt, list):
|
||||||
|
request["prompt_token_ids"] = prompt
|
||||||
|
else:
|
||||||
|
add_special_tokens = request.get("add_special_tokens", False)
|
||||||
|
request["prompt_token_ids"] = self.text2ids(
|
||||||
|
prompt, max_model_len, add_special_tokens=add_special_tokens
|
||||||
|
).tolist()
|
||||||
elif request.get("messages"):
|
elif request.get("messages"):
|
||||||
if self.tokenizer.chat_template is None:
|
if self.tokenizer.chat_template is None:
|
||||||
raise ValueError("This model does not support chat_template.")
|
raise ValueError("This model does not support chat_template.")
|
||||||
@@ -504,39 +376,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
full_text = self.tokenizer.decode(token_ids, **kwargs)
|
full_text = self.tokenizer.decode(token_ids, **kwargs)
|
||||||
return full_text
|
return full_text
|
||||||
|
|
||||||
def process_response(self, response_dict, **kwargs):
|
|
||||||
"""
|
|
||||||
Preprocess the response
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_dict (Dict): response for engine, contain ids fields
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: response contain text fields
|
|
||||||
"""
|
|
||||||
req_id = response_dict.request_id
|
|
||||||
token_ids = response_dict.outputs.token_ids
|
|
||||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
|
||||||
token_ids = token_ids[:-1]
|
|
||||||
full_text = self.tokenizer.decode(token_ids)
|
|
||||||
response_dict.outputs.text = full_text
|
|
||||||
if self.reasoning_parser:
|
|
||||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
|
||||||
full_text, response_dict, self.model_status_dict[req_id]
|
|
||||||
)
|
|
||||||
response_dict.outputs.text = text
|
|
||||||
response_dict.outputs.reasoning_content = reasoning_content
|
|
||||||
if self.tool_parser_obj:
|
|
||||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
|
||||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
|
||||||
if tool_call_info.tools_called:
|
|
||||||
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
|
||||||
if req_id in self.model_status_dict:
|
|
||||||
del self.model_status_dict[req_id]
|
|
||||||
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
|
|
||||||
|
|
||||||
return response_dict
|
|
||||||
|
|
||||||
def process_response_dict_normal(self, response_dict, **kwargs):
|
def process_response_dict_normal(self, response_dict, **kwargs):
|
||||||
"""
|
"""
|
||||||
Preprocess the response
|
Preprocess the response
|
||||||
@@ -551,10 +390,15 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
is_end = response_dict["finished"]
|
is_end = response_dict["finished"]
|
||||||
req_id = response_dict["request_id"]
|
req_id = response_dict["request_id"]
|
||||||
request = kwargs.get("request", None)
|
request = kwargs.get("request", None)
|
||||||
|
direct_decode = kwargs.get("direct_decode", False)
|
||||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||||
if token_ids[-1] in self.eos_token_ids:
|
if token_ids[-1] in self.eos_token_ids:
|
||||||
token_ids = token_ids[:-1]
|
token_ids = token_ids[:-1]
|
||||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
if direct_decode:
|
||||||
|
delta_text = self.tokenizer.decode(token_ids)
|
||||||
|
previous_texts = ""
|
||||||
|
else:
|
||||||
|
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||||
if is_end:
|
if is_end:
|
||||||
full_text = previous_texts + delta_text
|
full_text = previous_texts + delta_text
|
||||||
response_dict["outputs"]["completion_tokens"] = full_text
|
response_dict["outputs"]["completion_tokens"] = full_text
|
||||||
@@ -574,8 +418,9 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
tool_call_info = tool_parser.extract_tool_calls(full_text, request)
|
tool_call_info = tool_parser.extract_tool_calls(full_text, request)
|
||||||
if tool_call_info.tools_called:
|
if tool_call_info.tools_called:
|
||||||
response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
|
response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
|
||||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
if req_id in self.decode_status:
|
||||||
del self.decode_status[req_id]
|
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||||
|
del self.decode_status[req_id]
|
||||||
if req_id in self.model_status_dict:
|
if req_id in self.model_status_dict:
|
||||||
del self.model_status_dict[req_id]
|
del self.model_status_dict[req_id]
|
||||||
return response_dict
|
return response_dict
|
||||||
@@ -594,7 +439,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
req_id = response_dict["request_id"]
|
req_id = response_dict["request_id"]
|
||||||
token_ids = response_dict["outputs"]["token_ids"]
|
token_ids = response_dict["outputs"]["token_ids"]
|
||||||
request = kwargs.get("request", None)
|
request = kwargs.get("request", None)
|
||||||
response_dict["outputs"]["enable_parser"] = False
|
|
||||||
|
|
||||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||||
if token_ids[-1] in self.eos_token_ids:
|
if token_ids[-1] in self.eos_token_ids:
|
||||||
@@ -606,7 +450,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
response_dict["outputs"]["tool_calls"] = None
|
response_dict["outputs"]["tool_calls"] = None
|
||||||
response_dict["outputs"]["reasoning_content"] = ""
|
response_dict["outputs"]["reasoning_content"] = ""
|
||||||
if self.reasoning_parser:
|
if self.reasoning_parser:
|
||||||
response_dict["outputs"]["enable_parser"] = True
|
|
||||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||||
previous_texts,
|
previous_texts,
|
||||||
previous_texts + delta_text,
|
previous_texts + delta_text,
|
||||||
@@ -626,7 +469,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
if not is_end:
|
if not is_end:
|
||||||
response_dict["outputs"]["skipped"] = True
|
response_dict["outputs"]["skipped"] = True
|
||||||
if self.tool_parser_obj:
|
if self.tool_parser_obj:
|
||||||
response_dict["outputs"]["enable_parser"] = True
|
|
||||||
if req_id not in self.tool_parser_dict:
|
if req_id not in self.tool_parser_dict:
|
||||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||||
tool_parser = self.tool_parser_dict[req_id]
|
tool_parser = self.tool_parser_dict[req_id]
|
||||||
|
|||||||
@@ -351,7 +351,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
full_text = previous_texts + delta_text
|
full_text = previous_texts + delta_text
|
||||||
response_obj.outputs.text = full_text
|
response_obj.outputs.text = full_text
|
||||||
if self.reasoning_parser:
|
if self.reasoning_parser:
|
||||||
response_obj.outputs.enable_parser = True
|
|
||||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||||
full_text,
|
full_text,
|
||||||
request,
|
request,
|
||||||
@@ -362,7 +361,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
|
reasoning_tokens = self.tokenizer.tokenize(reasoning_content)
|
||||||
response_obj.outputs.reasoning_token_num = len(reasoning_tokens)
|
response_obj.outputs.reasoning_token_num = len(reasoning_tokens)
|
||||||
if self.tool_parser_obj:
|
if self.tool_parser_obj:
|
||||||
response_obj.outputs.enable_parser = True
|
|
||||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||||
tool_call_info = tool_parser.extract_tool_calls(full_text, request)
|
tool_call_info = tool_parser.extract_tool_calls(full_text, request)
|
||||||
if tool_call_info.tools_called:
|
if tool_call_info.tools_called:
|
||||||
@@ -396,7 +394,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||||
response_obj.outputs.completion_tokens = delta_text
|
response_obj.outputs.completion_tokens = delta_text
|
||||||
if self.reasoning_parser:
|
if self.reasoning_parser:
|
||||||
response_obj.outputs.enable_parser = True
|
|
||||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||||
previous_texts,
|
previous_texts,
|
||||||
previous_texts + delta_text,
|
previous_texts + delta_text,
|
||||||
@@ -420,7 +417,6 @@ class Ernie4_5Processor(BaseDataProcessor):
|
|||||||
else:
|
else:
|
||||||
response_obj.outputs.text = delta_text
|
response_obj.outputs.text = delta_text
|
||||||
if self.tool_parser_obj:
|
if self.tool_parser_obj:
|
||||||
response_obj.outputs.enable_parser = True
|
|
||||||
if req_id not in self.tool_parser_dict:
|
if req_id not in self.tool_parser_dict:
|
||||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||||
tool_parser = self.tool_parser_dict[req_id]
|
tool_parser = self.tool_parser_dict[req_id]
|
||||||
|
|||||||
@@ -555,7 +555,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
response_obj.outputs.completion_tokens = full_text
|
response_obj.outputs.completion_tokens = full_text
|
||||||
response_obj.outputs.text = full_text
|
response_obj.outputs.text = full_text
|
||||||
if self.reasoning_parser:
|
if self.reasoning_parser:
|
||||||
response_obj.outputs.enable_parser = True
|
|
||||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
|
||||||
full_text,
|
full_text,
|
||||||
request,
|
request,
|
||||||
@@ -566,7 +565,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else []
|
reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else []
|
||||||
response_obj.outputs.reasoning_token_num = len(reasoning_tokens)
|
response_obj.outputs.reasoning_token_num = len(reasoning_tokens)
|
||||||
if self.tool_parser_obj:
|
if self.tool_parser_obj:
|
||||||
response_obj.outputs.enable_parser = True
|
|
||||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||||
tool_call_info = tool_parser.extract_tool_calls(full_text, request)
|
tool_call_info = tool_parser.extract_tool_calls(full_text, request)
|
||||||
if tool_call_info.tools_called:
|
if tool_call_info.tools_called:
|
||||||
@@ -600,7 +598,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||||
response_obj.outputs.completion_tokens = delta_text
|
response_obj.outputs.completion_tokens = delta_text
|
||||||
if self.reasoning_parser:
|
if self.reasoning_parser:
|
||||||
response_obj.outputs.enable_parser = True
|
|
||||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||||
previous_texts,
|
previous_texts,
|
||||||
previous_texts + delta_text,
|
previous_texts + delta_text,
|
||||||
@@ -615,7 +612,6 @@ class DataProcessor(BaseDataProcessor):
|
|||||||
reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else []
|
reasoning_tokens = self.tokenizer.tokenize(reasoning_content) if reasoning_content else []
|
||||||
response_obj.outputs.reasoning_token_num = len(reasoning_tokens)
|
response_obj.outputs.reasoning_token_num = len(reasoning_tokens)
|
||||||
if self.tool_parser_obj:
|
if self.tool_parser_obj:
|
||||||
response_obj.outputs.enable_parser = True
|
|
||||||
if req_id not in self.tool_parser_dict:
|
if req_id not in self.tool_parser_dict:
|
||||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||||
tool_parser = self.tool_parser_dict[req_id]
|
tool_parser = self.tool_parser_dict[req_id]
|
||||||
|
|||||||
@@ -42,8 +42,8 @@ class DummyDataProcessor:
|
|||||||
def process_logprob_response(self, token_ids, clean_up_tokenization_spaces: bool = False):
|
def process_logprob_response(self, token_ids, clean_up_tokenization_spaces: bool = False):
|
||||||
return f"tok_{token_ids[0]}"
|
return f"tok_{token_ids[0]}"
|
||||||
|
|
||||||
def process_response(self, result):
|
def process_response_dict(self, response_dict, **kwargs):
|
||||||
return result
|
return response_dict
|
||||||
|
|
||||||
def process_response_dict_streaming(self, response_dict, stream, enable_thinking, include_stop_str_in_output):
|
def process_response_dict_streaming(self, response_dict, stream, enable_thinking, include_stop_str_in_output):
|
||||||
tokens = "".join(f"tok_{tid}" for tid in response_dict["outputs"]["token_ids"])
|
tokens = "".join(f"tok_{tid}" for tid in response_dict["outputs"]["token_ids"])
|
||||||
@@ -61,6 +61,18 @@ class DummyResult:
|
|||||||
def add(self, other):
|
def add(self, other):
|
||||||
self.added = True
|
self.added = True
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"request_id": self.request_id,
|
||||||
|
"finished": self.finished,
|
||||||
|
"prompt_logprobs": self.prompt_logprobs,
|
||||||
|
"outputs": {
|
||||||
|
"token_ids": self.outputs.token_ids,
|
||||||
|
"top_logprobs": self.outputs.top_logprobs,
|
||||||
|
"logprobs": self.outputs.logprobs,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _make_engine(vocab_size=5, max_logprobs=5, enable_logprob=True, enable_prefix_caching=False, is_master=True):
|
def _make_engine(vocab_size=5, max_logprobs=5, enable_logprob=True, enable_prefix_caching=False, is_master=True):
|
||||||
cfg = SimpleNamespace(
|
cfg = SimpleNamespace(
|
||||||
@@ -255,7 +267,8 @@ def test_run_engine_and_streaming(monkeypatch):
|
|||||||
|
|
||||||
llm_module.tqdm = DummyTqdm
|
llm_module.tqdm = DummyTqdm
|
||||||
out = llm._run_engine(["r1"], use_tqdm=True, topk_logprobs=-1, num_prompt_logprobs=-1)
|
out = llm._run_engine(["r1"], use_tqdm=True, topk_logprobs=-1, num_prompt_logprobs=-1)
|
||||||
assert out[0] is result
|
assert out[0].request_id == result.request_id
|
||||||
|
assert out[0].finished == result.finished
|
||||||
|
|
||||||
current = DummyResult(
|
current = DummyResult(
|
||||||
"r2",
|
"r2",
|
||||||
|
|||||||
@@ -261,38 +261,25 @@ class TestErnie4_5Processor(unittest.TestCase):
|
|||||||
self.assertEqual(len(stop_seqs2), 2)
|
self.assertEqual(len(stop_seqs2), 2)
|
||||||
self.assertEqual(len(stop_lens2), 2)
|
self.assertEqual(len(stop_lens2), 2)
|
||||||
|
|
||||||
def test_process_request_chat_template_kwargs(self):
|
def test_process_request_dict_with_chat_template_kwargs(self):
|
||||||
"""Test chat_template_kwargs application inside process_request."""
|
"""Test chat_template_kwargs application inside process_request_dict."""
|
||||||
|
|
||||||
proc = self._make_processor()
|
proc = self._make_processor()
|
||||||
|
|
||||||
class ReqObj(dict):
|
request = {
|
||||||
"""Mock request object supporting attributes, set(), and to_dict()."""
|
"messages": [{"role": "user", "content": "hello"}],
|
||||||
|
"temperature": 0.5,
|
||||||
|
"top_p": 0.5,
|
||||||
|
"chat_template_kwargs": {"extra": "VALUE"},
|
||||||
|
}
|
||||||
|
|
||||||
def set(self, k, v):
|
processed = proc.process_request_dict(request, max_model_len=20)
|
||||||
self[k] = v
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
self.assertEqual(processed["eos_token_ids"], [proc.tokenizer.eos_token_id])
|
||||||
return self.get(item, None)
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
return dict(self)
|
|
||||||
|
|
||||||
request = ReqObj(
|
|
||||||
{
|
|
||||||
"messages": [{"role": "user", "content": "hello"}],
|
|
||||||
"temperature": 0.5,
|
|
||||||
"top_p": 0.5,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
processed = proc.process_request(request, max_model_len=20, chat_template_kwargs={"extra": "VALUE"})
|
|
||||||
|
|
||||||
self.assertEqual(processed.eos_token_ids, [proc.tokenizer.eos_token_id])
|
|
||||||
|
|
||||||
expected_ids = proc.tokenizer.convert_tokens_to_ids(proc.tokenizer.tokenize("hello"))
|
expected_ids = proc.tokenizer.convert_tokens_to_ids(proc.tokenizer.tokenize("hello"))
|
||||||
self.assertIsNotNone(processed.prompt_token_ids)
|
self.assertIsNotNone(processed["prompt_token_ids"])
|
||||||
self.assertEqual(processed.prompt_token_ids, expected_ids)
|
self.assertEqual(processed["prompt_token_ids"], expected_ids)
|
||||||
|
|
||||||
self.assertIn("max_tokens", processed)
|
self.assertIn("max_tokens", processed)
|
||||||
self.assertEqual(processed["max_tokens"], max(1, 20 - len(expected_ids)))
|
self.assertEqual(processed["max_tokens"], max(1, 20 - len(expected_ids)))
|
||||||
@@ -322,21 +309,14 @@ class TestErnie4_5Processor(unittest.TestCase):
|
|||||||
def test_process_response_with_tool_parser(self):
|
def test_process_response_with_tool_parser(self):
|
||||||
"""Verify tool_call extraction in process_response."""
|
"""Verify tool_call extraction in process_response."""
|
||||||
proc = self._make_processor(tool=True)
|
proc = self._make_processor(tool=True)
|
||||||
|
resp = {
|
||||||
class RespObj:
|
"request_id": "reqx",
|
||||||
"""Mock response carrying token_ids and index for testing."""
|
"outputs": {"token_ids": [9, proc.tokenizer.eos_token_id], "index": 0},
|
||||||
|
"finished": True,
|
||||||
def __init__(self):
|
}
|
||||||
self.request_id = "reqx"
|
result = proc.process_response_dict(resp, False)
|
||||||
self.outputs = MagicMock()
|
assert "tool_calls" in result["outputs"]
|
||||||
self.outputs.token_ids = [9, proc.tokenizer.eos_token_id]
|
self.assertEqual(result["outputs"]["tool_calls"][0]["name"], "fake_tool")
|
||||||
self.outputs.index = 0
|
|
||||||
|
|
||||||
resp = RespObj()
|
|
||||||
result = proc.process_response(resp)
|
|
||||||
|
|
||||||
self.assertTrue(hasattr(result.outputs, "tool_calls"))
|
|
||||||
self.assertEqual(result.outputs.tool_calls[0]["name"], "fake_tool")
|
|
||||||
|
|
||||||
def test_process_response_dict_normal_with_tool(self):
|
def test_process_response_dict_normal_with_tool(self):
|
||||||
"""Verify tool_call extraction in normal (non-streaming) response mode."""
|
"""Verify tool_call extraction in normal (non-streaming) response mode."""
|
||||||
|
|||||||
@@ -264,33 +264,6 @@ class TestErnie4_5VLProcessorProcessResponseDictStreaming(unittest.TestCase):
|
|||||||
processor._check_mm_limits(mm_data)
|
processor._check_mm_limits(mm_data)
|
||||||
self.assertIn("Too many image items", str(context.exception))
|
self.assertIn("Too many image items", str(context.exception))
|
||||||
|
|
||||||
def test_process_request(self):
|
|
||||||
"""Test process_request method"""
|
|
||||||
from fastdeploy.engine.request import Request
|
|
||||||
|
|
||||||
# Mock the process_request_dict method
|
|
||||||
self.processor.process_request_dict = MagicMock()
|
|
||||||
|
|
||||||
# Create a mock Request object
|
|
||||||
mock_request = MagicMock(spec=Request)
|
|
||||||
mock_request.to_dict.return_value = {"messages": [{"role": "user", "content": "Hello"}]}
|
|
||||||
|
|
||||||
# Mock Request.from_dict to return a mock request
|
|
||||||
with patch.object(Request, "from_dict") as mock_from_dict:
|
|
||||||
mock_result_request = MagicMock(spec=Request)
|
|
||||||
mock_from_dict.return_value = mock_result_request
|
|
||||||
|
|
||||||
self.processor.process_request(mock_request, max_model_len=100, chat_template_kwargs={"key": "value"})
|
|
||||||
|
|
||||||
# Verify to_dict was called
|
|
||||||
mock_request.to_dict.assert_called_once()
|
|
||||||
|
|
||||||
# Verify process_request_dict was called
|
|
||||||
self.processor.process_request_dict.assert_called_once()
|
|
||||||
|
|
||||||
# Verify from_dict was called
|
|
||||||
mock_from_dict.assert_called_once()
|
|
||||||
|
|
||||||
def test_get_pad_id(self):
|
def test_get_pad_id(self):
|
||||||
"""Test get_pad_id method"""
|
"""Test get_pad_id method"""
|
||||||
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
|
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
|
||||||
|
|||||||
@@ -965,40 +965,6 @@ class TestPaddleOCRVLProcessor(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.processor._check_mm_limits(item_exceeded)
|
self.processor._check_mm_limits(item_exceeded)
|
||||||
|
|
||||||
def test_process_request_wrapper(self):
|
|
||||||
"""测试 process_request 封装方法"""
|
|
||||||
# 1. 模拟输入 Request 对象
|
|
||||||
request_obj = MagicMock()
|
|
||||||
request_dict = {
|
|
||||||
"prompt": "test prompt",
|
|
||||||
"multimodal_data": {"image": ["image1"]},
|
|
||||||
"metadata": {"generated_token_ids": []},
|
|
||||||
"request_id": "test-request",
|
|
||||||
}
|
|
||||||
request_obj.to_dict.return_value = request_dict
|
|
||||||
|
|
||||||
# 2. patch 'Request'
|
|
||||||
patch_target = "fastdeploy.input.paddleocr_vl_processor.paddleocr_vl_processor.Request"
|
|
||||||
with patch(patch_target) as MockRequestCls:
|
|
||||||
|
|
||||||
# 3. 模拟 Request.from_dict 返回一个 mock 对象
|
|
||||||
final_mock_request = MagicMock()
|
|
||||||
MockRequestCls.from_dict.return_value = final_mock_request
|
|
||||||
|
|
||||||
# 4. Call function
|
|
||||||
result_request = self.processor.process_request(request_obj, max_model_len=512)
|
|
||||||
|
|
||||||
# 5. 检查 *传递给* Request.from_dict 的字典
|
|
||||||
self.assertTrue(MockRequestCls.from_dict.called)
|
|
||||||
# 获取传递给 from_dict 的第一个位置参数
|
|
||||||
processed_task_dict = MockRequestCls.from_dict.call_args[0][0]
|
|
||||||
|
|
||||||
# 这个断言现在应该能通过了
|
|
||||||
self.assertEqual(processed_task_dict["prompt_token_ids"], [1, 2, 3])
|
|
||||||
|
|
||||||
# 6. 检查返回的是否是最终的 Request 对象
|
|
||||||
self.assertIs(result_request, final_mock_request)
|
|
||||||
|
|
||||||
def test_parse_processor_kwargs_invalid_type(self):
|
def test_parse_processor_kwargs_invalid_type(self):
|
||||||
"""测试 _parse_processor_kwargs 传入非字典类型"""
|
"""测试 _parse_processor_kwargs 传入非字典类型"""
|
||||||
invalid_input = ["video_max_frames", 10]
|
invalid_input = ["video_max_frames", 10]
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from unittest.mock import MagicMock, patch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from fastdeploy.engine.request import Request
|
|
||||||
from fastdeploy.input.qwen3_vl_processor import Qwen3VLProcessor
|
from fastdeploy.input.qwen3_vl_processor import Qwen3VLProcessor
|
||||||
from fastdeploy.input.qwen3_vl_processor.process import sample_frames
|
from fastdeploy.input.qwen3_vl_processor.process import sample_frames
|
||||||
|
|
||||||
@@ -127,9 +126,9 @@ class TestQwen3VLProcessor(unittest.TestCase):
|
|||||||
self.patcher_parse_image.stop()
|
self.patcher_parse_image.stop()
|
||||||
self.patcher_parse_video.stop()
|
self.patcher_parse_video.stop()
|
||||||
|
|
||||||
def test_process_request(self):
|
def test_process_request_dict_with_multimodal(self):
|
||||||
"""
|
"""
|
||||||
Test processing of Request object with multimodal input
|
Test processing of request dict with multimodal input
|
||||||
|
|
||||||
Validates:
|
Validates:
|
||||||
1. Token ID lengths match position_ids and token_type_ids shapes
|
1. Token ID lengths match position_ids and token_type_ids shapes
|
||||||
@@ -151,17 +150,16 @@ class TestQwen3VLProcessor(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = Request.from_dict(message)
|
result = self.processor.process_request_dict(message, 1024 * 100)
|
||||||
result = self.processor.process_request(request, 1024 * 100)
|
|
||||||
|
|
||||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
|
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["position_ids"].shape[0])
|
||||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
|
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["token_type_ids"].shape[0])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.multimodal_inputs["images"].shape[0],
|
result["multimodal_inputs"]["images"].shape[0],
|
||||||
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
|
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
|
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_process_request_dict(self):
|
def test_process_request_dict(self):
|
||||||
@@ -224,17 +222,16 @@ class TestQwen3VLProcessor(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
request = Request.from_dict(prompt)
|
result = self.processor.process_request_dict(prompt, 1024 * 100)
|
||||||
result = self.processor.process_request(request, 1024 * 100)
|
|
||||||
|
|
||||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
|
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["position_ids"].shape[0])
|
||||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
|
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["token_type_ids"].shape[0])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.multimodal_inputs["images"].shape[0],
|
result["multimodal_inputs"]["images"].shape[0],
|
||||||
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
|
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
|
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_message_and_prompt(self):
|
def test_message_and_prompt(self):
|
||||||
@@ -276,14 +273,15 @@ class TestQwen3VLProcessor(unittest.TestCase):
|
|||||||
"video": [{"video": b"123"}],
|
"video": [{"video": b"123"}],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
request2 = Request.from_dict(prompt)
|
result2 = self.processor.process_request_dict(prompt, 1024 * 100)
|
||||||
result2 = self.processor.process_request(request2, 1024 * 100)
|
|
||||||
|
|
||||||
# Verify both processing methods produce identical results
|
# Verify both processing methods produce identical results
|
||||||
self.assertEqual(result["prompt_token_ids"], result2.prompt_token_ids)
|
self.assertEqual(result["prompt_token_ids"], result2["prompt_token_ids"])
|
||||||
self.assertTrue(np.equal(result["multimodal_inputs"]["grid_thw"], result2.multimodal_inputs["grid_thw"]).all())
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.equal(result["multimodal_inputs"]["position_ids"], result2.multimodal_inputs["position_ids"]).all()
|
np.equal(result["multimodal_inputs"]["grid_thw"], result2["multimodal_inputs"]["grid_thw"]).all()
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
np.equal(result["multimodal_inputs"]["position_ids"], result2["multimodal_inputs"]["position_ids"]).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_apply_chat_template(self):
|
def test_apply_chat_template(self):
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from unittest.mock import MagicMock, patch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from fastdeploy.engine.request import Request
|
|
||||||
from fastdeploy.input.qwen_vl_processor import QwenVLProcessor
|
from fastdeploy.input.qwen_vl_processor import QwenVLProcessor
|
||||||
from fastdeploy.input.qwen_vl_processor.process_video import sample_frames
|
from fastdeploy.input.qwen_vl_processor.process_video import sample_frames
|
||||||
|
|
||||||
@@ -129,9 +128,9 @@ class TestQwenVLProcessor(unittest.TestCase):
|
|||||||
self.patcher_parse_image.stop()
|
self.patcher_parse_image.stop()
|
||||||
self.patcher_parse_video.stop()
|
self.patcher_parse_video.stop()
|
||||||
|
|
||||||
def test_process_request(self):
|
def test_process_request_dict_with_multimodal(self):
|
||||||
"""
|
"""
|
||||||
Test processing of Request object with multimodal input
|
Test processing of request dict with multimodal input
|
||||||
|
|
||||||
Validates:
|
Validates:
|
||||||
1. Token ID lengths match position_ids and token_type_ids shapes
|
1. Token ID lengths match position_ids and token_type_ids shapes
|
||||||
@@ -153,17 +152,16 @@ class TestQwenVLProcessor(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = Request.from_dict(message)
|
result = self.processor.process_request_dict(message, 1024 * 100)
|
||||||
result = self.processor.process_request(request, 1024 * 100)
|
|
||||||
|
|
||||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
|
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["position_ids"].shape[0])
|
||||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
|
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["token_type_ids"].shape[0])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.multimodal_inputs["images"].shape[0],
|
result["multimodal_inputs"]["images"].shape[0],
|
||||||
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
|
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
|
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_process_request_dict(self):
|
def test_process_request_dict(self):
|
||||||
@@ -246,17 +244,16 @@ class TestQwenVLProcessor(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
request = Request.from_dict(prompt)
|
result = self.processor.process_request_dict(prompt, 1024 * 100)
|
||||||
result = self.processor.process_request(request, 1024 * 100)
|
|
||||||
|
|
||||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["position_ids"].shape[0])
|
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["position_ids"].shape[0])
|
||||||
self.assertEqual(result.prompt_token_ids_len, result.multimodal_inputs["token_type_ids"].shape[0])
|
self.assertEqual(len(result["prompt_token_ids"]), result["multimodal_inputs"]["token_type_ids"].shape[0])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.multimodal_inputs["images"].shape[0],
|
result["multimodal_inputs"]["images"].shape[0],
|
||||||
sum(map(lambda x: x.prod(), result.multimodal_inputs["grid_thw"])),
|
sum(map(lambda x: x.prod(), result["multimodal_inputs"]["grid_thw"])),
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.multimodal_inputs["image_type_ids"].shape[0], result.multimodal_inputs["grid_thw"][:, 0].sum()
|
result["multimodal_inputs"]["image_type_ids"].shape[0], result["multimodal_inputs"]["grid_thw"][:, 0].sum()
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_message_and_prompt(self):
|
def test_message_and_prompt(self):
|
||||||
@@ -298,14 +295,15 @@ class TestQwenVLProcessor(unittest.TestCase):
|
|||||||
"video": [{"video": b"123"}],
|
"video": [{"video": b"123"}],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
request2 = Request.from_dict(prompt)
|
result2 = self.processor.process_request_dict(prompt, 1024 * 100)
|
||||||
result2 = self.processor.process_request(request2, 1024 * 100)
|
|
||||||
|
|
||||||
# Verify both processing methods produce identical results
|
# Verify both processing methods produce identical results
|
||||||
self.assertEqual(result["prompt_token_ids"], result2.prompt_token_ids)
|
self.assertEqual(result["prompt_token_ids"], result2["prompt_token_ids"])
|
||||||
self.assertTrue(np.equal(result["multimodal_inputs"]["grid_thw"], result2.multimodal_inputs["grid_thw"]).all())
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.equal(result["multimodal_inputs"]["position_ids"], result2.multimodal_inputs["position_ids"]).all()
|
np.equal(result["multimodal_inputs"]["grid_thw"], result2["multimodal_inputs"]["grid_thw"]).all()
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
np.equal(result["multimodal_inputs"]["position_ids"], result2["multimodal_inputs"]["position_ids"]).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_apply_chat_template(self):
|
def test_apply_chat_template(self):
|
||||||
|
|||||||
@@ -347,19 +347,9 @@ class DataProcessorTestCase(unittest.TestCase):
|
|||||||
def _load_tokenizer(self):
|
def _load_tokenizer(self):
|
||||||
return DummyTokenizer()
|
return DummyTokenizer()
|
||||||
|
|
||||||
def process_request(self, request, **kwargs):
|
|
||||||
return super().process_request(request, **kwargs)
|
|
||||||
|
|
||||||
def process_response(self, response_dict):
|
|
||||||
return super().process_response(response_dict)
|
|
||||||
|
|
||||||
processor = MinimalProcessor()
|
processor = MinimalProcessor()
|
||||||
defaults = processor._apply_default_parameters({})
|
defaults = processor._apply_default_parameters({})
|
||||||
self.assertAlmostEqual(defaults["top_p"], 0.5)
|
self.assertAlmostEqual(defaults["top_p"], 0.5)
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
processor.process_request({}, max_model_len=None)
|
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
processor.process_response({})
|
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
processor.text2ids("text")
|
processor.text2ids("text")
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
@@ -392,28 +382,28 @@ class DataProcessorTestCase(unittest.TestCase):
|
|||||||
self.assertTrue(processed["enable_thinking"])
|
self.assertTrue(processed["enable_thinking"])
|
||||||
self.assertEqual(processed["prompt_tokens"], "system prompt hello")
|
self.assertEqual(processed["prompt_tokens"], "system prompt hello")
|
||||||
|
|
||||||
def test_process_request_object_handles_sequences(self):
|
def test_process_request_dict_handles_sequences(self):
|
||||||
request = DummyRequest(
|
request = {
|
||||||
prompt=[1, 2, 3, 4, 5, 6],
|
"prompt": [1, 2, 3, 4, 5, 6],
|
||||||
stop=["stop"],
|
"stop": ["stop"],
|
||||||
bad_words=["zz"],
|
"bad_words": ["zz"],
|
||||||
temperature=0,
|
"temperature": 0,
|
||||||
top_p=0,
|
"top_p": 0,
|
||||||
)
|
}
|
||||||
processed = self.processor.process_request(request, max_model_len=5)
|
processed = self.processor.process_request_dict(request, max_model_len=5)
|
||||||
|
|
||||||
self.assertEqual(processed.prompt_token_ids, [1, 2, 3, 4])
|
self.assertEqual(processed["prompt_token_ids"], [1, 2, 3, 4])
|
||||||
self.assertEqual(processed.sampling_params.max_tokens, 1)
|
self.assertEqual(processed["max_tokens"], 1)
|
||||||
self.assertEqual(processed.sampling_params.stop_token_ids, [[4]])
|
self.assertEqual(processed["stop_token_ids"], [[4]])
|
||||||
self.assertEqual(set(processed.sampling_params.bad_words_token_ids), {2, 3})
|
self.assertEqual(set(processed["bad_words_token_ids"]), {2, 3})
|
||||||
self.assertEqual(processed.sampling_params.temperature, 1)
|
self.assertEqual(processed["temperature"], 1)
|
||||||
self.assertEqual(processed.sampling_params.top_k, 1)
|
self.assertEqual(processed["top_k"], 1)
|
||||||
self.assertAlmostEqual(processed.sampling_params.top_p, 1e-5)
|
self.assertAlmostEqual(processed["top_p"], 1e-5)
|
||||||
|
|
||||||
def test_process_request_requires_prompt_or_messages(self):
|
def test_process_request_dict_requires_prompt_or_messages(self):
|
||||||
request = DummyRequest(prompt=None, messages=None, prompt_token_ids=None)
|
request = {"prompt": None, "messages": None, "prompt_token_ids": None}
|
||||||
with self.assertRaisesRegex(ValueError, "should have `input_ids`, `text` or `messages`"):
|
with self.assertRaisesRegex(ValueError, "Request must contain"):
|
||||||
self.processor.process_request(request, max_model_len=5)
|
self.processor.process_request_dict(request, max_model_len=5)
|
||||||
|
|
||||||
def test_process_request_dict_rejects_bad_kwargs(self):
|
def test_process_request_dict_rejects_bad_kwargs(self):
|
||||||
request = {
|
request = {
|
||||||
@@ -458,14 +448,15 @@ class DataProcessorTestCase(unittest.TestCase):
|
|||||||
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer)
|
processor.reasoning_parser = self.create_dummy_reasoning(processor.tokenizer)
|
||||||
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-only")
|
processor.tool_parser_obj = self.create_dummy_tool_parser(processor.tokenizer, content="tool-only")
|
||||||
|
|
||||||
response = SimpleNamespace(
|
response = {
|
||||||
request_id="resp",
|
"request_id": "resp",
|
||||||
outputs=SimpleNamespace(token_ids=[1, processor.tokenizer.eos_token_id]),
|
"finished": True,
|
||||||
)
|
"outputs": {"token_ids": [1, processor.tokenizer.eos_token_id]},
|
||||||
|
}
|
||||||
|
|
||||||
processed = processor.process_response(response)
|
processed = processor.process_response_dict(response, stream=False)
|
||||||
self.assertEqual(processed.outputs.reasoning_content, "think")
|
self.assertEqual(processed["outputs"]["reasoning_content"], "think")
|
||||||
self.assertEqual(processed.outputs.tool_calls, ["tool"])
|
self.assertEqual(processed["outputs"]["tool_calls"], ["tool"])
|
||||||
|
|
||||||
def test_process_response_streaming_clears_state(self):
|
def test_process_response_streaming_clears_state(self):
|
||||||
processor = self.processor
|
processor = self.processor
|
||||||
|
|||||||
@@ -897,37 +897,6 @@ class TestThinkingBudgetSupplemental(unittest.TestCase):
|
|||||||
# 命中 _get_think_token_ids 的缓存分支
|
# 命中 _get_think_token_ids 的缓存分支
|
||||||
self.assertEqual(processor._get_think_token_ids(), (THINKING_START_TOKEN_ID, THINKING_END_TOKEN_ID))
|
self.assertEqual(processor._get_think_token_ids(), (THINKING_START_TOKEN_ID, THINKING_END_TOKEN_ID))
|
||||||
|
|
||||||
def test_text_process_request_think_stop_sentence(self):
|
|
||||||
processor = TextDataProcessor.__new__(TextDataProcessor)
|
|
||||||
processor._apply_default_parameters = lambda request: request
|
|
||||||
processor.eos_token_ids = [1]
|
|
||||||
processor.update_stop_seq = lambda *args, **kwargs: None
|
|
||||||
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
|
|
||||||
processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [101, 102]
|
|
||||||
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
|
|
||||||
processor.reasoning_parser = None
|
|
||||||
|
|
||||||
request = DummyRequestV1(
|
|
||||||
request_id="req_text",
|
|
||||||
eos_token_ids=[1],
|
|
||||||
prompt_token_ids=[8],
|
|
||||||
prompt=None,
|
|
||||||
messages=None,
|
|
||||||
logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
|
|
||||||
bad_words=None,
|
|
||||||
bad_words_token_ids=None,
|
|
||||||
max_tokens=1,
|
|
||||||
temperature=1.0,
|
|
||||||
top_p=0.9,
|
|
||||||
)
|
|
||||||
with patch("fastdeploy.input.text_processor.process_stop_token_ids", lambda *args, **kwargs: None):
|
|
||||||
processed = processor.process_request(request, max_model_len=16)
|
|
||||||
self.assertEqual(
|
|
||||||
processed.logits_processors_args.get("think_stop_sentence_token_ids"),
|
|
||||||
[23, 101, 102],
|
|
||||||
)
|
|
||||||
self.assertNotIn("think_stop_sentence", processed.logits_processors_args)
|
|
||||||
|
|
||||||
def test_text_process_request_dict_think_stop_sentence(self):
|
def test_text_process_request_dict_think_stop_sentence(self):
|
||||||
processor = TextDataProcessor.__new__(TextDataProcessor)
|
processor = TextDataProcessor.__new__(TextDataProcessor)
|
||||||
processor._apply_default_parameters = lambda request: request
|
processor._apply_default_parameters = lambda request: request
|
||||||
|
|||||||
Reference in New Issue
Block a user