Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import warnings | |
| from typing import List | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from mmengine.structures import InstanceData, PixelData | |
| from mmengine.utils import is_list_of | |
| from .bbox.transforms import get_warp_matrix | |
| from .pose_data_sample import PoseDataSample | |
def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample:
    """Combine several data samples into one data sample.

    Intended for gathering top-down predictions made on different bboxes of
    the same image. The result carries the metainfo of the first input
    sample and the concatenation of the instances of all input samples.

    Which fields get merged is decided by looking at the first sample only:
    ``gt_instances`` and ``pred_instances`` are concatenated with
    ``InstanceData.cat``; ``pred_fields.heatmaps`` and ``gt_fields.heatmaps``
    are each passed through :func:`revert_heatmap` and reduced with a
    maximum over the samples.

    Args:
        data_samples (List[:obj:`PoseDataSample`]): The data samples to
            merge.

    Returns:
        PoseDataSample: The merged data sample. An empty
        :obj:`PoseDataSample` (plus a warning) when the input list is empty.

    Raises:
        ValueError: If ``data_samples`` is not a list of
            :obj:`PoseDataSample`.
    """

    if not is_list_of(data_samples, PoseDataSample):
        raise ValueError('Invalid input type, should be a list of '
                         ':obj:`PoseDataSample`')

    if not data_samples:
        warnings.warn('Try to merge an empty list of data samples.')
        return PoseDataSample()

    first = data_samples[0]
    merged = PoseDataSample(metainfo=first.metainfo)

    def _merged_heatmap_field(field_name):
        """Revert the heatmaps stored under ``field_name`` for every sample
        and fuse them into a new ``PixelData``."""
        reverted = []
        for sample in data_samples:
            heatmaps = getattr(sample, field_name).heatmaps
            # The bbox geometry always comes from ``gt_instances``, for
            # predicted and ground-truth heatmaps alike.
            instances = sample.gt_instances
            reverted.append(
                revert_heatmap(heatmaps, instances.bbox_centers,
                               instances.bbox_scales, sample.ori_shape))

        # Maximum over the sample axis of the stacked heatmaps.
        fused = np.max(reverted, axis=0)
        pixel_data = PixelData()
        pixel_data.set_data({'heatmaps': fused})
        return pixel_data

    if 'gt_instances' in first:
        merged.gt_instances = InstanceData.cat(
            [sample.gt_instances for sample in data_samples])

    if 'pred_instances' in first:
        merged.pred_instances = InstanceData.cat(
            [sample.pred_instances for sample in data_samples])

    if 'pred_fields' in first and 'heatmaps' in first.pred_fields:
        merged.pred_fields = _merged_heatmap_field('pred_fields')

    if 'gt_fields' in first and 'heatmaps' in first.gt_fields:
        merged.gt_fields = _merged_heatmap_field('gt_fields')

    return merged
def revert_heatmap(heatmap, bbox_center, bbox_scale, img_shape):
    """Revert predicted heatmap on the original image.

    The heatmap is warped with ``cv2.warpAffine`` onto a canvas of the
    original image size, using the matrix returned by ``get_warp_matrix``
    for the given bbox (``rot=0``, ``inv=True``).

    Args:
        heatmap (np.ndarray or torch.tensor): predicted heatmap, either
            2-D ``[H, W]`` or 3-D ``[K, H, W]``. Tensors are moved to CPU
            and converted to numpy first.
        bbox_center (np.ndarray): bounding box center coordinate; reshaped
            to ``(2, )``.
        bbox_scale (np.ndarray): bounding box scale; reshaped to ``(2, )``.
        img_shape (tuple or list): size of original image, unpacked as
            ``(img_h, img_w)``.

    Returns:
        np.ndarray: The warped heatmap, ``[K, img_h, img_w]`` for a 3-D
        input and ``[img_h, img_w]`` for a 2-D input.
    """
    if torch.is_tensor(heatmap):
        heatmap = heatmap.cpu().detach().numpy()

    ndim = heatmap.ndim
    # [K, H, W] -> [H, W, K]: OpenCV expects channels last.
    if ndim == 3:
        heatmap = heatmap.transpose(1, 2, 0)

    hm_h, hm_w = heatmap.shape[:2]
    img_h, img_w = img_shape
    # NOTE(review): ``inv=True`` presumably asks for the heatmap -> image
    # mapping (the inverse of the crop transform) - confirm against
    # ``get_warp_matrix``.
    warp_mat = get_warp_matrix(
        bbox_center.reshape((2, )),
        bbox_scale.reshape((2, )),
        rot=0,
        output_size=(hm_w, hm_h),
        inv=True)

    # ``dsize`` is given as (width, height).
    heatmap = cv2.warpAffine(
        heatmap, warp_mat, (img_w, img_h), flags=cv2.INTER_LINEAR)

    # [H, W, K] -> [K, H, W]
    if ndim == 3:
        # OpenCV's Python bindings hand back a 2-D array when the result
        # has a single channel, so a [1, H, W] input would lose its channel
        # axis here and the transpose below would fail. Put it back.
        if heatmap.ndim == 2:
            heatmap = heatmap[:, :, None]
        heatmap = heatmap.transpose(2, 0, 1)

    return heatmap
def split_instances(instances: InstanceData) -> List[InstanceData]:
    """Break ``instances`` up into a list with one plain dict per instance.

    Every dict holds ``keypoints`` and ``keypoint_scores`` (converted with
    ``tolist()``). When ``instances`` contains ``bboxes`` a ``bbox`` entry
    is added, and, if ``bbox_scores`` is present as well, a ``bbox_score``
    entry (stored as-is, without conversion).

    Args:
        instances (InstanceData): The instances to split, or ``None``.

    Returns:
        list: One dict per instance; empty when ``instances`` is ``None``.
    """
    # No instance detected by the model: nothing to split.
    if instances is None:
        return []

    keypoints = instances.keypoints
    # These membership checks do not depend on the instance index.
    has_bboxes = 'bboxes' in instances
    has_bbox_scores = has_bboxes and 'bbox_scores' in instances

    per_instance = []
    for idx, instance_keypoints in enumerate(keypoints):
        entry = {
            'keypoints': instance_keypoints.tolist(),
            'keypoint_scores': instances.keypoint_scores[idx].tolist(),
        }
        if has_bboxes:
            # NOTE(review): the bbox is stored as a 1-tuple wrapping the
            # coordinate list (the original statement ended with a trailing
            # comma). Kept as-is because callers may index ``['bbox'][0]``;
            # confirm whether the wrapping is intentional.
            entry['bbox'] = (instances.bboxes[idx].tolist(), )
        if has_bbox_scores:
            entry['bbox_score'] = instances.bbox_scores[idx]
        per_instance.append(entry)

    return per_instance