mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[Feature] The 45VL supports prompt_token_ids + messages input. (#5148)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* support prompt_token_ids + messages * fix bug * refact code structure * support cache mm items * refact code structure * delete test cases * modify unit test * add unit test * add unit test * fix append * add check for messages
This commit is contained in:
@@ -136,7 +136,9 @@ class DataProcessor:
|
||||
self.video_end = self.VID_END
|
||||
self.image_patch_id = self.tokenizer.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>")
|
||||
self.image_start_id = self.tokenizer.convert_tokens_to_ids(self.image_start)
|
||||
self.image_end_id = self.tokenizer.convert_tokens_to_ids(self.image_end)
|
||||
self.video_start_id = self.tokenizer.convert_tokens_to_ids(self.video_start)
|
||||
self.video_end_id = self.tokenizer.convert_tokens_to_ids(self.video_end)
|
||||
self.sep_token_id = self.tokenizer.convert_tokens_to_ids(self.sep_token)
|
||||
self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)
|
||||
|
||||
@@ -243,14 +245,7 @@ class DataProcessor:
|
||||
|
||||
return outputs
|
||||
|
||||
def request2ids(
|
||||
self, request: Dict[str, Any], tgts: List[str] = None
|
||||
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
|
||||
"""
|
||||
Convert chat messages into model inputs.
|
||||
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
|
||||
"""
|
||||
|
||||
def extract_mm_items(self, request: Dict[str, Any]):
|
||||
messages = parse_chat_messages(request.get("messages"))
|
||||
mm_items = []
|
||||
for msg in messages:
|
||||
@@ -273,6 +268,7 @@ class DataProcessor:
|
||||
if len(missing_hashes) > 0 and not self.enable_processor_cache:
|
||||
raise ValueError("Missing items cannot be retrieved without processor cache.")
|
||||
|
||||
dealer = None
|
||||
if self.enable_processor_cache:
|
||||
context = zmq.Context()
|
||||
dealer = context.socket(zmq.DEALER)
|
||||
@@ -295,6 +291,16 @@ class DataProcessor:
|
||||
video_uuid.append(item["uuid"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported multimodal type: {item.get('type')}")
|
||||
return images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items
|
||||
|
||||
def request2ids(
|
||||
self, request: Dict[str, Any], tgts: List[str] = None
|
||||
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
|
||||
"""
|
||||
Convert chat messages into model inputs.
|
||||
Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels.
|
||||
"""
|
||||
images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
|
||||
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat template.")
|
||||
@@ -329,6 +335,115 @@ class DataProcessor:
|
||||
|
||||
return outputs
|
||||
|
||||
def prompt_token_ids2outputs(
|
||||
self, request: Dict[str, Any], tgts: List[str] = None
|
||||
) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]:
|
||||
outputs = {
|
||||
"input_ids": [],
|
||||
"token_type_ids": [],
|
||||
"position_ids": [],
|
||||
"images": [],
|
||||
"grid_thw": [],
|
||||
"image_type_ids": [],
|
||||
"labels": [],
|
||||
"cur_position": 0,
|
||||
"video_cnt": 0,
|
||||
"num_input_image_tokens": 0,
|
||||
"num_input_video_tokens": 0,
|
||||
"mm_positions": [],
|
||||
"mm_hashes": [],
|
||||
}
|
||||
prompt_token_ids = request.get("prompt_token_ids", [])
|
||||
prompt_token_ids_len = len(prompt_token_ids)
|
||||
if not request.get("messages"):
|
||||
outputs["input_ids"].extend(prompt_token_ids)
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * prompt_token_ids_len)
|
||||
for i in range(prompt_token_ids_len):
|
||||
outputs["position_ids"].append([i] * 3)
|
||||
outputs["cur_position"] += prompt_token_ids_len
|
||||
return outputs
|
||||
images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = self.extract_mm_items(request)
|
||||
st, image_idx, video_idx = 0, 0, 0
|
||||
while st < prompt_token_ids_len:
|
||||
cur_token_id = prompt_token_ids[st]
|
||||
if cur_token_id == self.image_start_id:
|
||||
if image_idx >= len(images):
|
||||
raise ValueError("prompt token ids has more image placeholder than in messages")
|
||||
# append image_start_id
|
||||
outputs["input_ids"].extend([cur_token_id])
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
|
||||
outputs["position_ids"].append([outputs["cur_position"]] * 3)
|
||||
outputs["cur_position"] += 1
|
||||
st += 1
|
||||
# process placeholder token ids
|
||||
cur_idx = st
|
||||
while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.image_end_id:
|
||||
cur_idx += 1
|
||||
if cur_idx >= prompt_token_ids_len:
|
||||
raise ValueError("image token ids not complete")
|
||||
image = images[image_idx]
|
||||
uuid = image_uuid[image_idx] if image_uuid else None
|
||||
token_len = cur_idx - st
|
||||
if not isinstance(image, tuple):
|
||||
self._add_image(image, outputs, uuid, token_len)
|
||||
else:
|
||||
self._add_processed_image(image, outputs, uuid, token_len)
|
||||
image_idx += 1
|
||||
st = cur_idx
|
||||
elif cur_token_id == self.video_start_id:
|
||||
if video_idx >= len(videos):
|
||||
raise ValueError("prompt token ids has more video placeholder than in messages")
|
||||
# append video_start_id
|
||||
outputs["input_ids"].extend([cur_token_id])
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
|
||||
outputs["position_ids"].append([outputs["cur_position"]] * 3)
|
||||
outputs["cur_position"] += 1
|
||||
st += 1
|
||||
# process placeholder token ids
|
||||
cur_idx = st
|
||||
while cur_idx < prompt_token_ids_len and prompt_token_ids[cur_idx] != self.video_end_id:
|
||||
cur_idx += 1
|
||||
if cur_idx >= prompt_token_ids_len:
|
||||
raise ValueError("video token ids not complete")
|
||||
video = videos[video_idx]
|
||||
uuid = video_uuid[video_idx] if video_uuid else None
|
||||
token_len = cur_idx - st
|
||||
if not isinstance(video, tuple):
|
||||
if isinstance(video, dict):
|
||||
frames = self._load_and_process_video(video["video"], video)
|
||||
else:
|
||||
frames = self._load_and_process_video(video, {})
|
||||
self._add_video(frames, outputs, uuid, token_len)
|
||||
else:
|
||||
self._add_processed_video(video, outputs, uuid, token_len)
|
||||
video_idx += 1
|
||||
st = cur_idx
|
||||
else:
|
||||
outputs["input_ids"].extend([cur_token_id])
|
||||
outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]])
|
||||
outputs["position_ids"].append([outputs["cur_position"]] * 3)
|
||||
outputs["cur_position"] += 1
|
||||
st += 1
|
||||
if image_idx != len(images):
|
||||
raise ValueError("number of images does not match")
|
||||
if video_idx != len(videos):
|
||||
raise ValueError("number of videos does not match")
|
||||
|
||||
if self.enable_processor_cache:
|
||||
missing_idx = set(missing_idx)
|
||||
hashes_to_cache, items_to_cache = [], []
|
||||
for idx in range(len(mm_items)):
|
||||
if idx in missing_idx:
|
||||
continue
|
||||
meta = {}
|
||||
t, h, w = outputs["grid_thw"][idx][0]
|
||||
meta["thw"] = (t, h, w)
|
||||
hashes_to_cache.append(outputs["mm_hashes"][idx])
|
||||
items_to_cache.append((outputs["images"][idx], meta))
|
||||
self.update_processor_cache(dealer, hashes_to_cache, items_to_cache)
|
||||
|
||||
return outputs
|
||||
|
||||
def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None:
|
||||
token_id = token if isinstance(token, int) else self.tokenizer.convert_tokens_to_ids(token)
|
||||
outputs["input_ids"].append(token_id)
|
||||
@@ -348,7 +463,7 @@ class DataProcessor:
|
||||
outputs["position_ids"].append([start + i] * 3)
|
||||
outputs["cur_position"] += len(tokens)
|
||||
|
||||
def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None:
|
||||
def _add_image(self, img, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
|
||||
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
|
||||
img.height,
|
||||
img.width,
|
||||
@@ -356,6 +471,8 @@ class DataProcessor:
|
||||
max_pixels=self.image_max_pixels,
|
||||
)[1]
|
||||
num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2)
|
||||
if token_len and token_len != num_tokens:
|
||||
raise ValueError("image tokens num not match the size")
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
|
||||
@@ -383,9 +500,13 @@ class DataProcessor:
|
||||
outputs["grid_thw"].append(ret["image_grid_thw"])
|
||||
outputs["image_type_ids"].append(0)
|
||||
|
||||
def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
|
||||
def _add_processed_image(
|
||||
self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
|
||||
) -> None:
|
||||
img, meta = img_cache
|
||||
num_tokens = img.shape[0] // (self.spatial_conv_size**2)
|
||||
if token_len and num_tokens != token_len:
|
||||
raise ValueError("image tokens num not match the size")
|
||||
|
||||
outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens))
|
||||
outputs["input_ids"].extend([self.image_patch_id] * num_tokens)
|
||||
@@ -401,7 +522,7 @@ class DataProcessor:
|
||||
outputs["grid_thw"].append(np.array([[1, h, w]]))
|
||||
outputs["image_type_ids"].append(0)
|
||||
|
||||
def _add_video(self, frames, outputs: Dict, uuid: Optional[str]) -> None:
|
||||
def _add_video(self, frames, outputs: Dict, uuid: Optional[str], token_len=None) -> None:
|
||||
patches_h, patches_w = self.image_preprocessor.get_smarted_resize(
|
||||
frames[0].height,
|
||||
frames[0].width,
|
||||
@@ -410,6 +531,8 @@ class DataProcessor:
|
||||
)[1]
|
||||
num_frames = len(frames)
|
||||
num_tokens = (num_frames * patches_h * patches_w) // (self.spatial_conv_size**2 * self.temporal_conv_size)
|
||||
if token_len and num_tokens != token_len:
|
||||
raise ValueError("video tokens num not match the size")
|
||||
|
||||
pixel_stack = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0)
|
||||
ret = self.image_preprocessor.preprocess(
|
||||
@@ -438,9 +561,13 @@ class DataProcessor:
|
||||
outputs["position_ids"].extend(pos_ids)
|
||||
outputs["cur_position"] = np.max(pos_ids) + 1
|
||||
|
||||
def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None:
|
||||
def _add_processed_video(
|
||||
self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str, token_len=None
|
||||
) -> None:
|
||||
frames, meta = frames_cache
|
||||
num_tokens = frames.shape[0] // (self.spatial_conv_size**2 * self.temporal_conv_size)
|
||||
if token_len and num_tokens != token_len:
|
||||
raise ValueError("video tokens num not match the size")
|
||||
|
||||
t, h, w = meta["thw"]
|
||||
outputs["images"].append(frames)
|
||||
|
||||
Reference in New Issue
Block a user