# Mirror of https://huggingface.co/spaces/H-Liu1997/TANGO
# Synced 2026-04-23 00:27:14 +08:00 (153 lines, 5.1 KiB, Python).
from typing import Literal, Union
|
|
|
|
def process_mmdet_results(mmdet_results: list,
                          cat_id: int = 0,
                          multi_person: bool = True) -> list:
    """Select one category's bboxes and order them by area, largest first.

    Note:
        ``mmdet_results`` is indexed directly with ``cat_id``; there is no
        loop over frames. The input is therefore handled as the
        per-category result of a single frame, i.e. a nested shape of
        (n_category, n_human, 5) -- presumably what
        mmdet.apis.inference_detector returns for one image; confirm
        against the caller.

    Args:
        mmdet_results (list):
            Detection result indexed by category. Only the entry at
            ``cat_id`` is read; every other category is dropped.
        cat_id (int, optional):
            Index of the category to keep.
            Defaults to 0 (documented upstream as the human category).
        multi_person (bool, optional):
            If True, all bboxes of the category are returned, sorted by
            area in descending order. If False, only the largest bbox
            is kept (as a slice of length at most one).
            Defaults to True.

    Returns:
        list:
            A one-element list wrapping the selected bboxes of the
            category. Each bbox is expected to be (x, y, x, y, score).
    """
    keep_only_largest = not multi_person
    # In single-person mode the sort only guarantees that the largest
    # bbox comes first, which is all that is needed below.
    ordered = qsort_bbox_list(mmdet_results[cat_id], keep_only_largest)

    if keep_only_largest:
        # Slice instead of indexing so an empty detection stays empty.
        return [ordered[0:1]]
    return [ordered]
|
|
|
|
|
|
def qsort_bbox_list(bbox_list: list,
                    only_max: bool = False,
                    bbox_convention: Literal['xyxy', 'xywh'] = 'xyxy'):
    """Sort a list of bboxes by their area in pixel (W*H), largest first.

    The list is partitioned around its middle element (quicksort style):
    bboxes whose area is greater than or equal to the pivot's go in front
    of it, the rest go behind it.

    Args:
        bbox_list (list):
            A list of bboxes. Only the first four values of each bbox are
            used for the area, interpreted according to
            ``bbox_convention``; extra trailing values are carried along
            untouched.
        only_max (bool, optional):
            If True, only assure the max element at first place,
            others may not be well sorted.
            If False, return a well sorted descending list.
            Defaults to False.
        bbox_convention (str, optional):
            Bbox type, xyxy or xywh. Defaults to 'xyxy'.

    Returns:
        list:
            The bboxes in descending order of area (only the first
            element is guaranteed when ``only_max`` is True). An input
            with at most one element is returned as-is.

    Raises:
        TypeError: Propagated from ``get_area_of_bbox`` when
            ``bbox_convention`` is neither 'xyxy' nor 'xywh'.
    """
    if len(bbox_list) <= 1:
        return bbox_list

    pivot_index = len(bbox_list) // 2
    pivot_bbox = bbox_list[pivot_index]
    pivot_area = get_area_of_bbox(pivot_bbox, bbox_convention)

    bigger_list = []
    less_list = []
    for index, bbox in enumerate(bbox_list):
        if index == pivot_index:
            continue
        if get_area_of_bbox(bbox, bbox_convention) >= pivot_area:
            bigger_list.append(bbox)
        else:
            less_list.append(bbox)

    # The recursive calls must use the caller's convention; falling back
    # to the default 'xyxy' would order 'xywh' sub-lists by the wrong area.
    # The front partition is always fully sorted (even when only_max is
    # True), which is what puts the overall maximum in first place.
    sorted_bigger = qsort_bbox_list(
        bigger_list, bbox_convention=bbox_convention)

    if only_max:
        # Everything behind the pivot is smaller than it, so it can be
        # left unsorted when only the maximum matters.
        return sorted_bigger + [pivot_bbox, ] + less_list
    return sorted_bigger + [pivot_bbox, ] + qsort_bbox_list(
        less_list, bbox_convention=bbox_convention)
|
|
|
|
def get_area_of_bbox(
        bbox: Union[list, tuple],
        bbox_convention: Literal['xyxy', 'xywh'] = 'xyxy') -> float:
    """Compute the area of a single bounding box.

    Args:
        bbox (Union[list, tuple]):
            The bbox as [x1, y1, x2, y2] for the 'xyxy' convention or
            [x, y, w, h] for 'xywh'. Only the first four values are read.
        bbox_convention (str, optional):
            Bbox type, xyxy or xywh. Defaults to 'xyxy'.

    Returns:
        float:
            Area of the bbox: |x2-x1|*|y2-y1| for 'xyxy',
            |w*h| for 'xywh'. Always non-negative.

    Raises:
        TypeError: If ``bbox_convention`` is not one of the two
            supported conventions.
    """
    if bbox_convention == 'xywh':
        return abs(bbox[2] * bbox[3])
    if bbox_convention != 'xyxy':
        raise TypeError(f'Wrong bbox convention: {bbox_convention}')

    # Absolute values make the result independent of corner order.
    width = abs(bbox[2] - bbox[0])
    height = abs(bbox[3] - bbox[1])
    return width * height
|
|
|
|
def calculate_iou(bbox1, bbox2):
    """Calculate the Intersection over Union (IoU) of two bounding boxes.

    Both boxes are read as (x1, y1, x2, y2, ...) with inclusive corner
    coordinates: widths and heights are computed as ``x2 - x1 + 1`` and
    ``y2 - y1 + 1``. Values after the first four are ignored.

    Args:
        bbox1: First bbox, indexable, at least four values.
        bbox2: Second bbox, indexable, at least four values.

    Returns:
        float:
            ``intersection_area / union_area``. Returns 0.0 when the
            union area is not positive (degenerate boxes), instead of
            dividing by zero.
    """
    # Corners of the overlapping rectangle (may be inverted if disjoint).
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])

    # Clamp each side at 0 so disjoint boxes yield an empty intersection.
    intersection_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)

    bbox1_area = (bbox1[2] - bbox1[0] + 1) * (bbox1[3] - bbox1[1] + 1)
    bbox2_area = (bbox2[2] - bbox2[0] + 1) * (bbox2[3] - bbox2[1] + 1)

    union_area = bbox1_area + bbox2_area - intersection_area

    # A positive intersection forces both areas (and thus the union) to be
    # positive, so a non-positive union always means "no overlap".
    if union_area <= 0:
        return 0.0

    return intersection_area / union_area
|
|
|
|
|
|
def non_max_suppression(bboxes, iou_threshold):
    """Greedy non-maximum suppression over scored bounding boxes.

    Boxes are visited from highest to lowest score (the value at index 4).
    Each visited box is kept, and every remaining box whose IoU with it is
    not strictly below ``iou_threshold`` is discarded.

    Args:
        bboxes: Iterable of bboxes; index 4 of each bbox is its score and
            the bbox is passed unchanged to ``calculate_iou``.
        iou_threshold: Boxes with an IoU greater than or equal to this
            value against an already kept box are suppressed.

    Returns:
        list: The kept bboxes, ordered by descending score.
    """
    candidates = sorted(bboxes, key=lambda box: box[4], reverse=True)
    kept = []

    while candidates:
        # The best-scoring survivor is always kept ...
        top = candidates[0]
        kept.append(top)
        # ... and only boxes that overlap it weakly enough stay in play.
        candidates = [
            box for box in candidates[1:]
            if calculate_iou(top, box) < iou_threshold
        ]

    return kept