[Optimization] The pre- and post-processing pipeline do not perform dict conversion (#5494)

* to_request_for_infer initial commit

* refact to from_chat_completion_request

* preprocess use request initial commit

* bugfix

* processors refact to using request

* bug fix

* refact Request from_generic_request

* post process initial commit

* bugfix

* postprocess second commit

* bugfix

* serving_embedding initial commit

* serving_reward initial commit

* bugfix

* replace function name

* async_llm initial commit

* offline initial commit and fix bug

* bugfix

* fix async_llm

* remove add speculate_metrics into data

* fix logprobs bug

* fix echo bug

* fix bug

* fix reasoning_max_tokens

* bugfix

* bugfix and modify unittest

* bugfix and modify unit test

* bugfix

* bugfix

* bugfix

* modify unittest

* fix error when reasong_content is none for text_processor

* remove some unnessary logic

* revert removed logic

* implement add and set method for RequestOutput and refact code

* modify unit test

* modify unit test

* union process_request and process_request_obj

* remove a unit test

* union process_response and process_response_obj

* support qwen3_vl_processor

* modify unittest and remove comments

* fix prompt_logprobs

* fix codestyle

* add v1

* v1

* fix unit test

* fix unit test

* fix pre-commit

* fix

* add process request

* add process request

* fix

* fix

* fix unit test

* fix unit test

* fix unit test

* fix unit test

* fix unit test

* remove file

* add unit test

* add unit test

* add unit test

* fix unit test

* fix unit test

* fix

* fix

---------

Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com>
Co-authored-by: luukunn <981429396@qq.com>
Co-authored-by: luukunn <83932082+luukunn@users.noreply.github.com>
Co-authored-by: Zhang Yulong <35552275+ZhangYulongg@users.noreply.github.com>
This commit is contained in:
kxz2002
2026-01-22 00:50:52 +08:00
committed by GitHub
parent fe5ba4b509
commit 6e416c62dd
66 changed files with 16614 additions and 739 deletions
@@ -0,0 +1,20 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .paddleocr_vl_processor import PaddleOCRVLProcessor
from .process import DataProcessor
__all__ = ["DataProcessor", "PaddleOCRVLProcessor"]
@@ -0,0 +1,275 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""Image processor class for Keye."""
# TODO: Support videos
import json
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
from paddleformers.transformers.feature_extraction_utils import BatchFeature
from paddleformers.transformers.image_processing_utils import BaseImageProcessor
from paddleformers.transformers.image_utils import (
ImageInput,
is_valid_image,
make_list_of_images,
to_numpy_array,
)
_OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
_OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
def make_batched_images(images) -> List[List[ImageInput]]:
"""
Accepts images in list or nested list format, and makes a list of images for preprocessing.
Args:
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
The input image.
Returns:
list: A list of images.
"""
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
return [img for img_list in images for img in img_list]
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
return images
elif is_valid_image(images):
return [images]
raise ValueError(f"Could not make batched images from {images}")
def adjust_size(size, patch_size):
num_patches = size // patch_size
if num_patches % 2 != 0:
num_patches -= 1
return num_patches * patch_size
def smart_resize(
height: int,
width: int,
factor: int = 28,
min_pixels: int = 28 * 28 * 130,
max_pixels: int = 28 * 28 * 1280,
):
"""Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
# if height < factor or width < factor:
# raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
# if int(height < factor//4) + int(width < factor//4):
# raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor//4}")
if height < factor:
logging.debug(f"smart_resize: height={height} < factor={factor}, reset height=factor")
width = round((width * factor) / height)
height = factor
if width < factor:
logging.debug(f"smart_resize: width={width} < factor={factor}, reset width=factor")
height = round((height * factor) / width)
width = factor
if max(height, width) / min(height, width) > 200:
raise ValueError(
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
)
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor
w_bar = math.floor(width / beta / factor) * factor
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
class ImageProcessor(BaseImageProcessor):
model_input_names = [
"pixel_values",
"image_grid_thw",
"pixel_values_videos",
"video_grid_thw",
]
def __init__(
self,
do_resize: bool = True,
resample: int = 3,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
min_pixels: int = 28 * 28 * 130,
max_pixels: int = 28 * 28 * 1280,
patch_size: int = 14,
temporal_patch_size: int = 1,
merge_size: int = 2,
**kwargs,
) -> None:
super().__init__()
self.do_resize = do_resize
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else _OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else _OPENAI_CLIP_STD
self.min_pixels = min_pixels
self.max_pixels = max_pixels
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.merge_size = merge_size
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} # not used
self.do_convert_rgb = do_convert_rgb
@classmethod
def from_pretrained(cls, pretrained_model_dir):
pretrained_model_dir = Path(pretrained_model_dir)
image_processor_config_path = pretrained_model_dir / "preprocessor_config.json"
with open(image_processor_config_path, "r", encoding="utf-8") as f:
image_processor_config = json.load(f)
return cls(**image_processor_config)
def _preprocess(
self,
images,
do_resize: Optional[bool] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: Optional[bool] = None,
):
images = make_list_of_images(images)
if do_convert_rgb:
images = [image.convert("RGB") for image in images]
width, height = images[0].size
resized_height, resized_width = height, width
processed_images = []
for image in images:
if do_resize:
resized_height, resized_width = smart_resize(
height,
width,
factor=self.patch_size * self.merge_size,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
image = image.resize((resized_width, resized_height), resample=self.resample)
image = to_numpy_array(image)
if do_rescale:
image = (image * rescale_factor).astype(np.float32)
if do_normalize:
image = image.astype(np.float32)
image -= np.array(image_mean, dtype=np.float32)
image /= np.array(image_std, dtype=np.float32)
processed_images.append(image)
patches = np.array(processed_images)
patches = patches.transpose(0, 3, 1, 2)
if patches.shape[0] == 1:
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
channel = patches.shape[1]
grid_t = patches.shape[0] // self.temporal_patch_size
grid_h, grid_w = (
resized_height // self.patch_size,
resized_width // self.patch_size,
)
patches = patches.reshape(
grid_t,
self.temporal_patch_size,
channel,
grid_h,
self.patch_size,
grid_w,
self.patch_size,
)
patches = patches.transpose(0, 3, 5, 2, 1, 4, 6)
assert self.temporal_patch_size == 1
flatten_patches = patches.reshape(grid_t * grid_h * grid_w, channel, self.patch_size, self.patch_size)
return flatten_patches, np.array([grid_t, grid_h, grid_w])
def preprocess(
self,
images,
videos=None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: Optional[bool] = None,
return_tensors=None,
):
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
if videos is not None:
raise NotImplementedError("Videos are not yet supported")
patches, image_grid_thw = self._preprocess(
images,
do_resize=do_resize,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_convert_rgb=do_convert_rgb,
)
pixel_values = np.array(patches)
data = {"pixel_values": pixel_values, "grid_thw": image_grid_thw}
return BatchFeature(data=data, tensor_type=return_tensors)
@@ -0,0 +1,320 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import numpy as np
from fastdeploy.engine.request import Request
from fastdeploy.input.v1.text_processor import DataProcessor as TextProcessor
from fastdeploy.utils import data_processor_logger
from .process import DataProcessor
_SAMPLING_EPS = 1e-5
from fastdeploy.input.utils import process_stop_token_ids
class PaddleOCRVLProcessor(TextProcessor):
"""
PaddleOCR Vision-Language processor for handling multimodal inputs.
This processor extends TextProcessor to support:
- Image processing
- Multimodal feature extraction
- Tokenization and position encoding
- Request processing and model input generation
Attributes:
processor (DataProcessor): Underlying data processor instance
tokenizer: Text tokenizer instance
limit_mm_per_prompt (dict): Limits for multimodal inputs per prompt
"""
def __init__(
self,
config,
model_name_or_path,
limit_mm_per_prompt=None,
mm_processor_kwargs=None,
reasoning_parser_obj=None,
tool_parser_obj=None,
enable_processor_cache=False,
):
"""
Initialize PaddleOCRVLProcessor instance.
Args:
config: Model configuration object
model_name_or_path (str): Pretrained model name or path
limit_mm_per_prompt (dict, optional): Limits for multimodal inputs
mm_processor_kwargs (dict, optional): Multimodal processor arguments
reasoning_parser_obj: Reasoning parser instance
tool_parser_obj: Tool parser instance
"""
super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj)
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)
self.processor = DataProcessor(
model_path=model_name_or_path,
enable_processor_cache=enable_processor_cache,
tokens_per_second=config.vision_config.tokens_per_second,
tokenizer=self.tokenizer,
**processor_kwargs,
)
self.image_patch_id = self.processor.image_patch_id
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
def process_request(self, request, max_model_len=None, **kwargs):
"""
Process incoming request and generate model inputs.
Args:
request: Input request object
max_model_len (int, optional): Maximum context length
**kwargs: Additional processing parameters
Returns:
Request: Processed request with model inputs
"""
task = request.to_dict()
task["enable_thinking"] = kwargs.get("enable_thinking", False)
self.process_request_dict(task, max_model_len)
request = Request.from_dict(task)
request = self._apply_default_parameters(request)
return request
def _parse_processor_kwargs(self, kwargs):
"""
Parse and validate multimodal processor arguments.
Args:
kwargs (dict): Processor configuration arguments
Returns:
dict: Validated processor arguments
Raises:
ValueError: If arguments format is invalid
"""
if not kwargs:
return {}
try:
if not isinstance(kwargs, dict):
raise ValueError("mm-processor-kwargs must be a dictionary")
# Validate kwargs types against expected schema
data_processor_logger.info(f"Processing kwargs: {kwargs}")
expected_types = {
"video_max_frames": int, # Maximum video frames parameter
"video_min_frames": int, # Minimum video frames parameter
}
for key, value in kwargs.items():
if key in expected_types and not isinstance(value, expected_types[key]):
raise ValueError(
f"Invalid type for {key}: expected {expected_types[key].__name__}, got {type(value).__name__}"
)
return kwargs
except Exception as e:
data_processor_logger.warning(f"Invalid mm-processor-kwargs format: {e}")
return {}
def _parse_limits(self, limits):
"""
Parse and validate multimodal input limits.
Args:
limits (dict): Input limits configuration
Returns:
dict: Validated limits with defaults
Raises:
ValueError: If limits format is invalid
"""
DEFAULT_LIMITS = {"image": 1, "video": 1, "audio": 1}
if not limits:
return DEFAULT_LIMITS
try:
if not isinstance(limits, dict):
raise ValueError("limit-mm-per-prompt must be a dictionary")
data_processor_logger.info(f"_parse_limits:{limits}")
return {**DEFAULT_LIMITS, **limits}
except Exception as e:
data_processor_logger.warning(f"Invalid limit-mm-per-prompt format: {e}, using default limits")
return DEFAULT_LIMITS
def _check_mm_limits(self, item):
"""
Validate multimodal inputs against configured limits.
Args:
item: Input request item to validate
Raises:
ValueError: If input exceeds configured limits
"""
if isinstance(item, dict):
# 请求包含prompt和multi_modal_data
mm_data = item
else:
# 请求包含messages
mm_data = {"image": [], "video": []}
for message in item:
if isinstance(message.get("content"), list):
for part in message["content"]:
if part.get("type") in ["image_url", "image"]:
mm_data["image"].append(part)
elif part.get("type") in ["video_url", "video"]:
mm_data["video"].append(part)
for modality, data in mm_data.items():
if modality in self.limit_mm_per_prompt:
limit = self.limit_mm_per_prompt[modality]
if len(data) > limit:
raise ValueError(f"Too many {modality} items in prompt, " f"got {len(data)} but limit is {limit}")
def process_request_dict(self, request, max_model_len=None, **kwargs):
"""
Process request dictionary into model inputs.
Args:
request (dict): Input request dictionary
max_model_len (int, optional): Maximum context length
Returns:
dict: Processed request with model inputs
Raises:
ValueError: If request format is invalid
"""
request = self._apply_default_parameters(request)
if not request.eos_token_ids:
request.eos_token_ids = self.eos_token_ids
# processing stop_sequences and stop_token_ids
process_stop_token_ids(request, self.update_stop_seq)
if request.prompt:
multimodal_data = request.multimodal_data
if multimodal_data is None:
multimodal_data = {}
self._check_mm_limits(multimodal_data)
images = multimodal_data.get("image", None)
videos = multimodal_data.get("video", None)
outputs = self.processor.text2ids(request.prompt, images, videos)
elif request.messages:
messages = request.messages
self._check_mm_limits(messages)
outputs = self.processor.request2ids(request)
else:
raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
metadata = request.metadata
# Handle continuation of previous generation by appending existing tokens
if metadata and metadata.get("generated_token_ids"):
self.append_generated_tokens(outputs, metadata["generated_token_ids"])
outputs = self.pack_outputs(outputs)
request.prompt_token_ids = outputs["input_ids"].tolist()
request.prompt_token_ids_len = len(request.prompt_token_ids)
request.multimodal_inputs = outputs
# Handle prompt truncation if exceeds model context length
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[
: max_model_len - 1
] # Leave space for at least 1 new token
# Set default max_tokens if not specified
if request.sampling_params.max_tokens is None:
request.sampling_params.max_tokens = max(
1, max_model_len - len(request.prompt_token_ids)
) # Ensure at least 1 token
if request.sampling_params.top_p is not None and request.sampling_params.top_p < _SAMPLING_EPS:
request.sampling_params.top_p = _SAMPLING_EPS
if self.reasoning_parser:
model_status = self.reasoning_parser.get_model_status(request.prompt_token_ids)
parts = request.request_id.split("_")
if len(parts) > 1:
real_req_id = parts[0]
index = int(parts[1])
n = request.get("n", 1)
for idx in range(index * n, (index + 1) * n):
self.model_status_dict[f"{real_req_id}_{idx}"] = model_status
else:
self.model_status_dict[request.request_id] = model_status
request.enable_thinking = model_status == "think_start"
return request
def append_generated_tokens(self, multimodal_inputs, generated_token_ids):
"""
Append generated tokens to existing outputs.
Args:
outputs: Current model outputs
generated_token_ids: Generated tokens to append
"""
num_tokens = len(generated_token_ids)
multimodal_inputs["input_ids"].extend(generated_token_ids)
multimodal_inputs["token_type_ids"].extend([0] * num_tokens)
pos_ids = self.processor._compute_text_positions(multimodal_inputs["cur_position"], num_tokens)
multimodal_inputs["position_ids"].append(pos_ids)
multimodal_inputs["cur_position"] += num_tokens
def pack_outputs(self, outputs):
"""
Prepare final output dictionary for model.
Args:
outputs: Intermediate processing outputs
Returns:
dict: Packed output dictionary with all required fields
"""
if not outputs["images"]:
outputs["images"] = None # No images case
outputs["grid_thw"] = None # No spatial dimensions
outputs["image_type_ids"] = None # No type IDs
else:
outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically
outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions
outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array
# Convert all outputs to numpy arrays with appropriate types
outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64
outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64
outputs["position_ids"] = np.concatenate(
outputs["position_ids"], axis=1, dtype=np.int64
) # Concatenate position ID
outputs["image_patch_id"] = self.processor.image_token_id
outputs["video_patch_id"] = self.processor.video_token_id
outputs["position_ids"] = outputs["position_ids"].transpose(1, 0)
outputs["mm_num_token_func"] = self.processor.mm_num_tokens
return outputs
@@ -0,0 +1,622 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import pickle
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import paddle
import zmq
from paddleformers.transformers import AutoTokenizer
from PIL import Image
from fastdeploy.engine.request import ImagePosition, Request
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie4_5_vl_processor import read_video_decord
from fastdeploy.input.mm_data_processor import MMBaseDataProcessor
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.utils import data_processor_logger
from .image_processor import ImageProcessor
from .process_video import sample_frames
class DataProcessor(MMBaseDataProcessor):
"""
Processes multimodal inputs (text, images, videos) into model-ready formats.
Handles:
- Tokenization of text with special tokens for visual content
- Image and video preprocessing
- Generation of 3D positional embeddings
- Conversion of chat messages to model inputs
Attributes:
tokenizer: Text tokenizer instance
image_processor: Image/video preprocessor
image_token: Special token for image placeholders
video_token: Special token for video placeholders
vision_start: Token marking start of visual content
"""
def __init__(
self,
model_path: str,
enable_processor_cache: bool = False,
video_min_frames: int = 4,
video_max_frames: int = 768,
video_target_frames: int = -1,
video_fps: int = -1,
tokens_per_second: int = 2,
tokenizer=None,
**kwargs,
) -> None:
"""
Initialize the data processor.
Args:
model_path: Path to pretrained model
video_min_frames: Minimum frames to sample from videos
video_max_frames: Maximum frames to sample from videos
tokens_per_second: Temporal resolution for positional embeddings
**kwargs: Additional configuration
"""
super().__init__()
self.min_frames = video_min_frames
self.max_frames = video_max_frames
self.target_frames = video_target_frames
self.fps = video_fps
# Initialize tokenizer with left padding and fast tokenizer
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left", use_fast=True)
self.tokenizer.ignored_index = -100 # Set ignored index for loss calculation
else:
self.tokenizer = tokenizer
self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor
self.enable_processor_cache = enable_processor_cache
# Convolution sizes for patch aggregation
self.spatial_conv_size = self.image_processor.merge_size
self.temporal_conv_size = self.image_processor.temporal_patch_size
# Special tokens and IDs
self.image_token = "<|IMAGE_PLACEHOLDER|>"
self.video_token = "<|video_pad|>"
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
self.image_patch_id = self.image_token_id
self.vision_start = "<|IMAGE_START|>"
self.vision_start_id = self.tokenizer.convert_tokens_to_ids(self.vision_start)
self.tokens_per_second = tokens_per_second
self.role_prefixes = {
"system": "",
"user": "User: ",
"bot": "Assistant: ",
"assistant": "Assistant: ",
}
@staticmethod
def mm_num_tokens(grid_thw: list | list[list[int]] | np.ndarray | paddle.Tensor) -> int | list[int]:
"""
Calculate the number of tokens in the multimodal input.
"""
if isinstance(grid_thw, paddle.Tensor):
grid_thw = grid_thw.numpy()
if len(grid_thw) == 0:
return 0
def calc_one(thw):
t, h, w = map(int, thw)
return t * h * w // 4
if isinstance(grid_thw[0], (list, tuple, np.ndarray)):
return [calc_one(x) for x in grid_thw]
return calc_one(grid_thw)
def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None):
"""
Convert text with image/video placeholders into model inputs.
Args:
text: Input text with <|image@placeholder|> and <|video@placeholder|> markers
images: List of PIL Images corresponding to image placeholders
videos: List of video data corresponding to video placeholders
image_uuid: List of unique identifiers for each image, used for caching or hashing.
video_uuid: List of unique identifiers for each video, used for caching or hashing.
Returns:
Dict containing:
- input_ids: Token IDs
- token_type_ids: Type identifiers (text/image/video)
- position_ids: 3D positional embeddings
- images: Preprocessed visual features
- grid_thw: Spatial/temporal dimensions
- image_type_ids: Visual content type (0=image, 1=video)
"""
outputs = {
"input_ids": [],
"token_type_ids": [],
"position_ids": [],
"images": [],
"grid_thw": [],
"image_type_ids": [],
"labels": [],
"cur_position": 0,
"video_cnt": 0,
"num_input_image_tokens": 0,
"num_input_video_tokens": 0,
"fps": [],
"mm_positions": [],
"mm_hashes": [],
"vit_seqlen": [],
"vit_position_ids": [],
}
# Define placeholders and their lengths
IMAGE_PLACEHOLDER = self.image_token
VIDEO_PLACEHOLDER = self.video_token
IMAGE_PLACEHOLDER_LEN = len(IMAGE_PLACEHOLDER)
VIDEO_PLACEHOLDER_LEN = len(VIDEO_PLACEHOLDER)
# Initialize tracking variables for text parsing
st, image_idx, video_idx = 0, 0, 0 # Start position, image counter, video counter
while st < len(text):
# Find next image or video placeholder in text
image_pos = text.find(IMAGE_PLACEHOLDER, st)
image_pos = len(text) if image_pos == -1 else image_pos # Set to end if not found
video_pos = text.find(VIDEO_PLACEHOLDER, st)
video_pos = len(text) if video_pos == -1 else video_pos # Set to end if not found
ed = min(image_pos, video_pos) # End position is first placeholder found
self._add_text(text[st:ed], outputs)
if ed == len(text):
break
if ed == image_pos:
image = images[image_idx]
uuid = image_uuid[image_idx] if image_uuid else None
if not isinstance(image, tuple):
self._add_image(image, outputs, uuid)
else:
self._add_processed_image(image, outputs, uuid)
image_idx += 1
st = ed + IMAGE_PLACEHOLDER_LEN
else:
item = videos[video_idx]
uuid = video_uuid[video_idx] if video_uuid else None
if not isinstance(item, tuple):
if isinstance(item, dict):
frames, meta = self._load_and_process_video(item["video"], item)
else:
frames, meta = self._load_and_process_video(item, {})
self._add_video(frames, meta, outputs, uuid)
else:
# cached frames are already processed
self._add_processed_video(item, outputs, uuid)
video_idx += 1
st = ed + VIDEO_PLACEHOLDER_LEN
return outputs
def request2ids(
self, request: Request, tgts: List[str] = None
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
"""
Convert chat request with multimodal messages into model inputs.
Args:
request: Dictionary containing:
- messages: List of chat messages with text/image/video content
- request_id: Unique identifier for logging
tgts: Optional target sequences
Returns:
Dict with same structure as text2ids() output
"""
# Parse and validate chat messages
messages = parse_chat_messages(request.messages)
mm_items = []
for msg in messages:
role = msg.get("role")
assert role in self.role_prefixes, f"Unsupported role: {role}"
# Normalize content to list format
content = msg.get("content")
if not isinstance(content, list):
content = [content]
# Collect all visual content items
for item in content:
if item.get("type") in ["image", "video"]:
mm_items.append(item)
missing_hashes, missing_idx = [], []
for idx, item in enumerate(mm_items):
if not item.get("data"):
# raw data not provided, should be retrieved from processor cache
missing_hashes.append(item.get("uuid"))
missing_idx.append(idx)
if len(missing_hashes) > 0 and not self.enable_processor_cache:
raise ValueError("Missing items cannot be retrieved without processor cache.")
if self.enable_processor_cache:
context = zmq.Context()
dealer = context.socket(zmq.DEALER)
dealer.connect("ipc:///dev/shm/processor_cache.ipc")
missing_items = self.get_processor_cache(dealer, missing_hashes)
for idx in range(len(missing_items)):
if not missing_items[idx]:
raise ValueError(f"Missing item {idx} not found in processor cache")
mm_items[missing_idx[idx]]["data"] = missing_items[idx]
images, videos = [], []
image_uuid, video_uuid = [], []
for item in mm_items:
if item.get("type") == "image":
images.append(item["data"])
image_uuid.append(item["uuid"])
elif item.get("type") == "video":
videos.append(item["data"])
video_uuid.append(item["uuid"])
else:
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat template.")
chat_template_kwargs = request.chat_template_kwargs if request.chat_template_kwargs else {}
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=request.add_generation_prompt if request.add_generation_prompt is not None else True,
**chat_template_kwargs,
)
request.prompt_tokens = prompt
outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid)
if self.enable_processor_cache:
missing_idx = set(missing_idx)
hashes_to_cache, items_to_cache = [], []
for idx in range(len(mm_items)):
if idx in missing_idx:
continue
meta = {}
t, h, w = outputs["grid_thw"][idx]
meta["thw"] = (t, h, w)
meta["fps"] = outputs["fps"][idx]
hashes_to_cache.append(outputs["mm_hashes"][idx])
items_to_cache.append((outputs["images"][idx], meta))
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
return outputs
def _add_text(self, tokens, outputs: Dict) -> None:
"""
Add text tokens to model inputs dictionary.
Args:
tokens: Text string or already tokenized IDs
outputs: Dictionary accumulating model inputs
Note:
- Handles both raw text and pre-tokenized inputs
- Updates position IDs for 3D embeddings
"""
if not tokens:
return None
if isinstance(tokens, str):
tokens_str = self.tokenizer.tokenize(tokens)
tokens = self.tokenizer.convert_tokens_to_ids(tokens_str)
num_tokens = len(tokens)
outputs["input_ids"].extend(tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
pos_ids = self._compute_text_positions(outputs["cur_position"], num_tokens)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray:
"""
Generate 3D positional embeddings for text tokens.
Args:
start_pos: Starting position index
num_tokens: Number of tokens to generate positions for
Returns:
numpy.ndarray: 3D position IDs shaped (3, num_tokens)
"""
text_array = np.arange(num_tokens).reshape(1, -1)
text_index = np.broadcast_to(text_array, (3, num_tokens))
position = text_index + start_pos
return position
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
"""
Add image data to model inputs dictionary.
Args:
img: PIL Image to process
outputs: Dictionary accumulating model inputs
Note:
- Preprocesses image and calculates spatial dimensions
- Adds image token IDs and type markers
- Generates appropriate position embeddings
"""
ret = self.image_processor.preprocess(images=[img.convert("RGB")])
num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["grid_thw"].tolist()
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
outputs["num_input_image_tokens"] += int(num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].append(0)
# position_ids
t, h, w = grid_thw
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["fps"].append(0)
numel = h * w
outputs["vit_seqlen"].append(numel)
outputs["vit_position_ids"].append(np.arange(numel) % numel)
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // self.image_processor.merge_size**2
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens)
_, h, w = meta["thw"]
pos_ids = self._compute_vision_positions(outputs["cur_position"], 1, h, w, 0)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["images"].append(img)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)
outputs["fps"].append(0)
def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) -> None:
"""
Add video data to model inputs dictionary.
Args:
frames: Video frames as numpy array
meta: Video metadata containing fps/duration
outputs: Dictionary accumulating model inputs
Note:
- Handles temporal dimension in position embeddings
- Uses video-specific token IDs and type markers
"""
ret = self.image_processor.preprocess(images=frames)
num_tokens = ret["image_grid_thw"].prod() // self.image_processor.merge_size**2
grid_thw = ret["image_grid_thw"].tolist()
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.video_token_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["num_input_video_tokens"] += int(num_tokens)
outputs["images"].append(ret["pixel_values"])
if not uuid:
outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"]))
else:
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(grid_thw)
outputs["image_type_ids"].extend([1] * grid_thw[0])
fps = meta["fps"]
second_per_grid_t = self.temporal_conv_size / fps
t, h, w = grid_thw
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["fps"].append(fps)
numel = h * w
outputs["vit_seqlen"].append(numel)
outputs["vit_position_ids"].append(np.arange(numel) % numel)
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // self.image_processor.merge_size**2
t, h, w = meta["thw"]
outputs["images"].append(frames)
outputs["mm_hashes"].append(uuid)
outputs["grid_thw"].append(np.array([[t, h, w]]))
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens)
outputs["image_type_ids"].extend([1] * t)
fps = meta["fps"]
second_per_grid_t = self.temporal_conv_size / fps
pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t)
outputs["position_ids"].append(pos_ids)
outputs["cur_position"] = pos_ids.max() + 1
outputs["fps"].append(fps)
def _compute_vision_positions(
self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float
) -> np.ndarray:
"""
Generate 3D position IDs for visual inputs.
Args:
start_pos: Base position in sequence
t: Temporal patches (1 for images)
h: Height in patches
w: Width in patches
second_per_grid_t: Time per temporal patch
Returns:
np.ndarray: Position IDs for [t,h,w] dimensions
"""
h //= self.spatial_conv_size
w //= self.spatial_conv_size
tn = np.arange(t).reshape(-1, 1)
tn = np.broadcast_to(tn, (t, h * w))
tn = tn * int(second_per_grid_t) * self.tokens_per_second
t_index = tn.flatten()
hn = np.arange(h).reshape(1, -1, 1)
h_index = np.broadcast_to(hn, (t, h, w)).flatten()
wn = np.arange(w).reshape(1, 1, -1)
w_index = np.broadcast_to(wn, (t, h, w)).flatten()
position = np.stack([t_index, h_index, w_index]) + start_pos
return position
def _load_and_process_video(self, url: str, item: Dict) -> Tuple[np.ndarray, Dict]:
"""
Load and preprocess video into frames.
Args:
url: Video file path or bytes
item: Dictionary containing processing parameters
Returns:
tuple: (frames, metadata) where:
- frames: Processed video frames as numpy array
- metadata: Updated video metadata dictionary
"""
reader, meta, _ = read_video_decord(url, save_to_disk=False)
# Apply frame sampling if fps or target_frames specified
fps = item.get("fps", self.fps)
num_frames = item.get("target_frames", self.target_frames)
frame_indices = list(range(meta["num_of_frame"]))
if fps > 0 or num_frames > 0:
# Get frame sampling constraints
min_frames = item.get("min_frames", self.min_frames)
max_frames = item.get("max_frames", self.max_frames)
# Sample frames according to specifications
frame_indices = sample_frames(
frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size
min_frames=min_frames,
max_frames=max_frames,
metadata=meta,
fps=fps,
num_frames=num_frames,
)
# Update metadata with new frame count and fps
meta["num_of_frame"] = len(frame_indices)
if fps is not None:
meta["fps"] = fps # Use specified fps
meta["duration"] = len(frame_indices) / fps
else:
meta["fps"] = len(frame_indices) / meta["duration"] # Calculate fps from sampled frames
frames = []
for idx in frame_indices:
frame = reader[idx].asnumpy()
image = Image.fromarray(frame, "RGB")
frames.append(image)
frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
return frames, meta
def get_processor_cache(self, socket, mm_hashes: list[str]) -> list:
"""
get cache correspond to given hash values
"""
req = pickle.dumps(mm_hashes)
socket.send_multipart([b"", req])
_, resp = socket.recv_multipart()
mm_items = pickle.loads(resp)
data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}")
return mm_items
def update_processor_cache(self, socket, mm_hashes: list[str], mm_items):
"""
update cache data
"""
req = pickle.dumps((mm_hashes, mm_items))
socket.send_multipart([b"", req])
data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}")
def apply_chat_template(self, request):
"""
Apply chat template to convert messages into token sequence.
Args:
request: Dictionary containing chat messages
Returns:
List of token IDs
Raises:
ValueError: If model doesn't support chat templates
"""
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
raw_prompt = self.tokenizer.apply_chat_template(
request["messages"],
tokenize=False,
add_generation_prompt=request.get("add_generation_prompt", True),
chat_template=request.get("chat_template", None),
)
prompt_token_str = raw_prompt.replace(self.image_token, "").replace(self.video_token, "")
request["text_after_process"] = raw_prompt
tokens = self.tokenizer.tokenize(prompt_token_str)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info(
f"req_id:{request.get('request_id', ''), } prompt: {raw_prompt} tokens: {tokens}, token_ids: {token_ids}"
)
return token_ids
@@ -0,0 +1,82 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import math
from typing import Optional, Union
import numpy as np
def sample_frames(
frame_factor: int,
min_frames: int,
max_frames: int,
metadata: Optional[dict] = None,
fps: Optional[Union[int, float]] = None,
num_frames: Optional[int] = None,
):
"""
Sample frames from video according to specified criteria.
Args:
frame_factor: Ensure sampled frames are multiples of this factor
min_frames: Minimum number of frames to sample
max_frames: Maximum number of frames to sample
metadata: Video metadata containing fps information
fps: Target frames per second for sampling
num_frames: Exact number of frames to sample
Returns:
np.ndarray: Sampled video frames
Raises:
ValueError: If both fps and num_frames are specified,
or if required metadata is missing,
or if requested frames exceed available frames
"""
if fps > 0 and num_frames > 0:
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
total_num_frames = metadata["num_of_frame"]
# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames > 0:
num_frames = round(num_frames / frame_factor) * frame_factor
elif fps > 0:
if metadata is None:
raise ValueError(
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
"Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
)
max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor
num_frames = total_num_frames / metadata["fps"] * fps
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
num_frames = math.floor(num_frames / frame_factor) * frame_factor
if num_frames > total_num_frames:
raise ValueError(
f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
"Decrease `num_frames` or `fps` for sampling."
)
# Calculate frame indices based on sampling strategy
if num_frames > 0:
# Evenly spaced sampling for target frame count
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32)
else:
# Keep all frames if no sampling requested
indices = np.arange(0, total_num_frames).astype(np.int32)
return indices