[Feature] The 45VL supports prompt_token_ids + messages input. (#5148)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* support prompt_token_ids + messages

* fix bug

* refactor code structure

* support cache mm items

* refactor code structure

* delete test cases

* modify unit test

* add unit test

* add unit test

* fix append

* add check for messages
This commit is contained in:
kxz2002
2025-11-25 23:11:44 +08:00
committed by GitHub
parent 66e096d509
commit 2d787590c4
4 changed files with 601 additions and 21 deletions
+139 -12
View File
@@ -136,7 +136,9 @@ class DataProcessor:
self.video_end = self.VID_END
self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
self.image_end_id = self.tokenizer.convert_tokens_to_ids(self.image_end)
self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
self.video_end_id = self.tokenizer.convert_tokens_to_ids(self.video_end)
self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)
@@ -243,14 +245,7 @@ class DataProcessor:
return outputs
def request2ids(
self, request: Dict[str, Any], tgts: List[str] = None
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
"""
Convert chat messages into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""
def extract_mm_items(self, request: Dict[str, Any]):
messages = parse_chat_messages(request.get("messages"))
mm_items = []
for msg in messages:
@@ -273,6 +268,7 @@ class DataProcessor:
if len(missing_hashes) > 0 and not self.enable_processor_cache:
raise ValueError("Missing items cannot be retrieved without processor cache.")
dealer = None
if self.enable_processor_cache:
context = zmq.Context()
dealer = context.socket(zmq.DEALER)
@@ -295,6 +291,16 @@ class DataProcessor:
video_uuid.append(item["uuid"])
else:
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
return images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items
def request2ids(
self, request: Dict[str, Any], tgts: List[str] = None
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
"""
Convert chat messages into model inputs.
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
"""
images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat template.")
@@ -329,6 +335,115 @@ class DataProcessor:
return outputs
def prompt_token_ids2outputs(
    self, request: Dict[str, Any], tgts: List[str] = None
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
    """
    Build model inputs from a pre-tokenized prompt (``prompt_token_ids``).

    Walks the supplied token ids and, at every image/video placeholder span
    (delimited by the start/end marker token ids), re-attaches the
    corresponding multimodal item extracted from ``request["messages"]``,
    validating that each span's length matches the processed item's token
    count. When ``messages`` is absent, the ids are passed through as plain
    text.

    Args:
        request: Request dict. Must contain ``prompt_token_ids``; may
            contain ``messages`` with multimodal items.
        tgts: Unused here; kept for signature parity with ``request2ids``.

    Returns:
        Dict with ``input_ids``, ``token_type_ids``, ``position_ids``,
        ``images``, ``grid_thw``, ``image_type_ids``, ``labels`` plus
        bookkeeping fields (``cur_position``, counters, ``mm_positions``,
        ``mm_hashes``).

    Raises:
        ValueError: if a placeholder span is not closed, or the number of
            placeholders does not match the items found in ``messages``.
    """
    outputs = {
        "input_ids": [],
        "token_type_ids": [],
        "position_ids": [],
        "images": [],
        "grid_thw": [],
        "image_type_ids": [],
        "labels": [],
        "cur_position": 0,
        "video_cnt": 0,
        "num_input_image_tokens": 0,
        "num_input_video_tokens": 0,
        "mm_positions": [],
        "mm_hashes": [],
    }
    prompt_token_ids = request.get("prompt_token_ids", [])
    prompt_token_ids_len = len(prompt_token_ids)

    # Text-only fast path: without messages there are no multimodal items
    # to splice in, so every id is a plain text token.
    if not request.get("messages"):
        outputs["input_ids"].extend(prompt_token_ids)
        outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * prompt_token_ids_len)
        # Positions are replicated 3-way (one per position axis).
        outputs["position_ids"].extend([i] * 3 for i in range(prompt_token_ids_len))
        outputs["cur_position"] += prompt_token_ids_len
        return outputs

    images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)

    def append_text_token(token_id: int) -> None:
        # Emit one text token and advance the position counter by one.
        outputs["input_ids"].append(token_id)
        outputs["token_type_ids"].append(IDS_TYPE_FLAG["text"])
        outputs["position_ids"].append([outputs["cur_position"]] * 3)
        outputs["cur_position"] += 1

    def scan_span_end(start: int, end_token_id: int, err_msg: str) -> int:
        # Advance from `start` to the matching end-marker token; the ids in
        # between form the placeholder span whose length must match the
        # processed item's token count.
        idx = start
        while idx < prompt_token_ids_len and prompt_token_ids[idx] != end_token_id:
            idx += 1
        if idx >= prompt_token_ids_len:
            raise ValueError(err_msg)
        return idx

    st, image_idx, video_idx = 0, 0, 0
    while st < prompt_token_ids_len:
        cur_token_id = prompt_token_ids[st]
        if cur_token_id == self.image_start_id:
            if image_idx >= len(images):
                raise ValueError("prompt token ids has more image placeholder than in messages")
            # The start marker itself is emitted as a text token.
            append_text_token(cur_token_id)
            st += 1
            cur_idx = scan_span_end(st, self.image_end_id, "image token ids not complete")
            image = images[image_idx]
            uuid = image_uuid[image_idx] if image_uuid else None
            token_len = cur_idx - st
            if isinstance(image, tuple):
                # Cached/preprocessed item: (ndarray, meta) pair.
                self._add_processed_image(image, outputs, uuid, token_len)
            else:
                self._add_image(image, outputs, uuid, token_len)
            image_idx += 1
            # The end marker at `cur_idx` is handled as a plain text token
            # on the next iteration.
            st = cur_idx
        elif cur_token_id == self.video_start_id:
            if video_idx >= len(videos):
                raise ValueError("prompt token ids has more video placeholder than in messages")
            append_text_token(cur_token_id)
            st += 1
            cur_idx = scan_span_end(st, self.video_end_id, "video token ids not complete")
            video = videos[video_idx]
            uuid = video_uuid[video_idx] if video_uuid else None
            token_len = cur_idx - st
            if isinstance(video, tuple):
                # Cached/preprocessed item: (ndarray, meta) pair.
                self._add_processed_video(video, outputs, uuid, token_len)
            else:
                # Raw video: a dict carries per-item options, otherwise use
                # defaults when decoding frames.
                if isinstance(video, dict):
                    frames = self._load_and_process_video(video["video"], video)
                else:
                    frames = self._load_and_process_video(video, {})
                self._add_video(frames, outputs, uuid, token_len)
            video_idx += 1
            st = cur_idx
        else:
            append_text_token(cur_token_id)
            st += 1

    # Every item extracted from messages must have been consumed by a
    # placeholder span, and vice versa.
    if image_idx != len(images):
        raise ValueError("number of images does not match")
    if video_idx != len(videos):
        raise ValueError("number of videos does not match")

    if self.enable_processor_cache:
        missing = set(missing_idx)
        hashes_to_cache, items_to_cache = [], []
        # NOTE(review): assumes outputs["grid_thw"]/["mm_hashes"]/["images"]
        # are index-aligned with mm_items — holds because placeholders are
        # consumed in message order above.
        for idx in range(len(mm_items)):
            if idx in missing:
                # Item was already retrieved from the processor cache;
                # nothing new to push.
                continue
            t, h, w = outputs["grid_thw"][idx][0]
            hashes_to_cache.append(outputs["mm_hashes"][idx])
            items_to_cache.append((outputs["images"][idx], {"thw": (t, h, w)}))
        self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)

    return outputs
def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
outputs["input_ids"].append(token_id)
@@ -348,7 +463,7 @@ class DataProcessor:
outputs["position_ids"].append([start + i] * 3)
outputs["cur_position"] += len(tokens)
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
def _add_image(self, img, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
img.height,
img.width,
@@ -356,6 +471,8 @@ class DataProcessor:
max_pixels=self.image_max_pixels,
)[1]
num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
if token_len and token_len != num_tokens:
raise ValueError("image tokens num not match the size")
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
@@ -383,9 +500,13 @@ class DataProcessor:
outputs["grid_thw"].append(ret["image_grid_thw"])
outputs["image_type_ids"].append(0)
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
def _add_processed_image(
self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
) -> None:
img, meta = img_cache
num_tokens = img.shape[0] // (self.spatial_conv_size**2)
if token_len and num_tokens != token_len:
raise ValueError("image tokens num not match the size")
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
@@ -401,7 +522,7 @@ class DataProcessor:
outputs["grid_thw"].append(np.array([[1, h, w]]))
outputs["image_type_ids"].append(0)
def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
def _add_video(self, frames, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
frames[0].height,
frames[0].width,
@@ -410,6 +531,8 @@ class DataProcessor:
)[1]
num_frames = len(frames)
num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)
if token_len and num_tokens != token_len:
raise ValueError("video tokens num not match the size")
pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
ret = self.image_preprocessor.preprocess(
@@ -438,9 +561,13 @@ class DataProcessor:
outputs["position_ids"].extend(pos_ids)
outputs["cur_position"] = np.max(pos_ids) + 1
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
def _add_processed_video(
self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
) -> None:
frames, meta = frames_cache
num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
if token_len and num_tokens != token_len:
raise ValueError("video tokens num not match the size")
t, h, w = meta["thw"]
outputs["images"].append(frames)