# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence, Union

import numpy as np
import torch
from mmcv.transforms import BaseTransform
from mmengine.structures import InstanceData, PixelData
from mmengine.utils import is_seq_of

from mmpose.registry import TRANSFORMS
from mmpose.structures import MultilevelPixelData, PoseDataSample


def image_to_tensor(img: Union[np.ndarray,
                               Sequence[np.ndarray]]) -> torch.Tensor:
    """Translate image or sequence of images to tensor. Multiple image
    tensors will be stacked.

    Args:
        img (np.ndarray | Sequence[np.ndarray]): The original image or
            image sequence.

    Returns:
        torch.Tensor: The output tensor.
    """
    if isinstance(img, np.ndarray):
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)

        img = np.ascontiguousarray(img)
        tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()
    else:
        assert is_seq_of(img, np.ndarray)
        tensor = torch.stack([image_to_tensor(_img) for _img in img])

    return tensor
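

# A minimal usage sketch (not part of the original module): a single
# (H, W, C) image becomes a (C, H, W) tensor, and a sequence of images
# is stacked along a new leading dimension.
#
#   >>> dummy = np.zeros((4, 6, 3), dtype=np.uint8)
#   >>> image_to_tensor(dummy).shape
#   torch.Size([3, 4, 6])
#   >>> image_to_tensor([dummy, dummy]).shape
#   torch.Size([2, 3, 4, 6])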


def keypoints_to_tensor(keypoints: Union[np.ndarray, Sequence[np.ndarray]]
                        ) -> torch.Tensor:
    """Translate keypoints or sequence of keypoints to tensor. Multiple
    keypoints tensors will be stacked.

    Args:
        keypoints (np.ndarray | Sequence[np.ndarray]): The keypoints or
            keypoints sequence.

    Returns:
        torch.Tensor: The output tensor.
    """
    if isinstance(keypoints, np.ndarray):
        keypoints = np.ascontiguousarray(keypoints)
        N = keypoints.shape[0]
        keypoints = keypoints.transpose(1, 2, 0).reshape(-1, N)
        tensor = torch.from_numpy(keypoints).contiguous()
    else:
        assert is_seq_of(keypoints, np.ndarray)
        tensor = torch.stack(
            [keypoints_to_tensor(_keypoints) for _keypoints in keypoints])

    return tensor
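

# Shape sketch (added here for illustration, not in the original module):
# an (N, K, D) keypoint array is flattened to (K * D, N) with the instance
# axis last, so a (1, 17, 3) input yields a (51, 1) tensor, and a sequence
# of T such arrays stacks to (T, 51, 1).
#
#   >>> kpts = np.zeros((1, 17, 3), dtype=np.float32)
#   >>> keypoints_to_tensor(kpts).shape
#   torch.Size([51, 1])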


@TRANSFORMS.register_module()
class PackPoseInputs(BaseTransform):
    """Pack the inputs data for pose estimation.

    The ``img_meta`` item is always populated. The contents of the
    ``img_meta`` dictionary depend on ``meta_keys``. By default it includes:

    - ``id``: id of the data sample
    - ``img_id``: id of the image
    - ``category_id``: the id of the instance category
    - ``img_path``: path to the image file
    - ``crowd_index`` (optional): measures the crowding level of an image, \
        defined in the CrowdPose dataset
    - ``ori_shape``: original shape of the image as a tuple (h, w, c)
    - ``img_shape``: shape of the image input to the network as a tuple \
        (h, w). Note that images may be zero padded on the \
        bottom/right if the batch tensor is larger than this shape.
    - ``input_size``: the input size to the network
    - ``flip``: a boolean indicating if the image flip transform was used
    - ``flip_direction``: the flipping direction
    - ``flip_indices``: the indices of each keypoint's symmetric keypoint
    - ``raw_ann_info`` (optional): raw annotation of the instance(s)

    Args:
        meta_keys (Sequence[str], optional): Meta keys which will be stored
            in :obj:`PoseDataSample` as meta info. Defaults to ``('id',
            'img_id', 'img_path', 'category_id', 'crowd_index', 'ori_shape',
            'img_shape', 'input_size', 'input_center', 'input_scale', 'flip',
            'flip_direction', 'flip_indices', 'raw_ann_info')``.
        pack_transformed (bool): Whether to pack ``transformed_keypoints``
            (and their visibility, if present) into ``gt_instances`` for
            visualizing data transform and augmentation results.
            Defaults to ``False``.
    """

    # items in `instance_mapping_table` will be directly packed into
    # PoseDataSample.gt_instances without converting to Tensor
    instance_mapping_table = {
        'bbox': 'bboxes',
        'head_size': 'head_size',
        'bbox_center': 'bbox_centers',
        'bbox_scale': 'bbox_scales',
        'bbox_score': 'bbox_scores',
        'keypoints': 'keypoints',
        'keypoints_visible': 'keypoints_visible',
        'lifting_target': 'lifting_target',
        'lifting_target_visible': 'lifting_target_visible',
    }

    # items in `label_mapping_table` will be packed into
    # PoseDataSample.gt_instance_labels and converted to Tensor. These items
    # will be used for computing losses
    label_mapping_table = {
        'keypoint_labels': 'keypoint_labels',
        'lifting_target_label': 'lifting_target_label',
        'lifting_target_weights': 'lifting_target_weights',
        'trajectory_weights': 'trajectory_weights',
        'keypoint_x_labels': 'keypoint_x_labels',
        'keypoint_y_labels': 'keypoint_y_labels',
        'keypoint_weights': 'keypoint_weights',
        'instance_coords': 'instance_coords',
        'transformed_keypoints_visible': 'keypoints_visible',
    }

    # items in `field_mapping_table` will be packed into
    # PoseDataSample.gt_fields and converted to Tensor. These items will be
    # used for computing losses
    field_mapping_table = {
        'heatmaps': 'heatmaps',
        'instance_heatmaps': 'instance_heatmaps',
        'heatmap_mask': 'heatmap_mask',
        'heatmap_weights': 'heatmap_weights',
        'displacements': 'displacements',
        'displacement_weights': 'displacement_weights',
    }

    def __init__(self,
                 meta_keys=('id', 'img_id', 'img_path', 'category_id',
                            'crowd_index', 'ori_shape', 'img_shape',
                            'input_size', 'input_center', 'input_scale',
                            'flip', 'flip_direction', 'flip_indices',
                            'raw_ann_info'),
                 pack_transformed=False):
        self.meta_keys = meta_keys
        self.pack_transformed = pack_transformed

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:

            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
            - 'data_samples' (obj:`PoseDataSample`): The annotation info of
              the sample.
            - 'rgb_gt' (obj:`torch.Tensor`): The packed ground-truth RGB
              image.
        """
        # Pack image(s) for 2d pose estimation
        if 'img' in results:
            img = results['img']
            inputs_tensor = image_to_tensor(img)
        # Pack keypoints for 3d pose-lifting
        elif 'lifting_target' in results and 'keypoints' in results:
            if 'keypoint_labels' in results:
                keypoints = results['keypoint_labels']
            else:
                keypoints = results['keypoints']
            inputs_tensor = keypoints_to_tensor(keypoints)
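
        # Pack the ground-truth RGB image. This is a field specific to this
        # pipeline (it is not part of upstream mmpose), and ``results`` is
        # assumed to always provide an ``rgb_gt`` entry.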
        rgb_gt_tensor = image_to_tensor(results['rgb_gt'])

        data_sample = PoseDataSample()

        # pack instance data
        gt_instances = InstanceData()
        for key, packed_key in self.instance_mapping_table.items():
            if key in results:
                if 'lifting_target' in results and key in {
                        'keypoints', 'keypoints_visible'
                }:
                    continue
                gt_instances.set_field(results[key], packed_key)

        # pack `transformed_keypoints` for visualizing data transform
        # and augmentation results
        if self.pack_transformed and 'transformed_keypoints' in results:
            gt_instances.set_field(results['transformed_keypoints'],
                                   'transformed_keypoints')
        if self.pack_transformed and \
                'transformed_keypoints_visible' in results:
            gt_instances.set_field(results['transformed_keypoints_visible'],
                                   'transformed_keypoints_visible')

        data_sample.gt_instances = gt_instances

        # pack instance labels
        gt_instance_labels = InstanceData()
        for key, packed_key in self.label_mapping_table.items():
            if key in results:
                # For pose-lifting, store only target-related fields
                if 'lifting_target_label' in results and key in {
                        'keypoint_labels', 'keypoint_weights',
                        'transformed_keypoints_visible'
                }:
                    continue
                if isinstance(results[key], list):
                    # A list of labels is usually generated by combining
                    # multiple encoders (see ``GenerateTarget`` in
                    # mmpose/datasets/transforms/common_transforms.py).
                    # In this case, labels in the list should have the same
                    # shape and will be stacked.
                    _labels = np.stack(results[key])
                    gt_instance_labels.set_field(_labels, packed_key)
                else:
                    gt_instance_labels.set_field(results[key], packed_key)
        data_sample.gt_instance_labels = gt_instance_labels.to_tensor()

        # pack fields
        gt_fields = None
        for key, packed_key in self.field_mapping_table.items():
            if key in results:
                if isinstance(results[key], list):
                    if gt_fields is None:
                        gt_fields = MultilevelPixelData()
                    else:
                        assert isinstance(
                            gt_fields, MultilevelPixelData
                        ), 'Got mixed single-level and multi-level pixel data.'
                else:
                    if gt_fields is None:
                        gt_fields = PixelData()
                    else:
                        assert isinstance(
                            gt_fields, PixelData
                        ), 'Got mixed single-level and multi-level pixel data.'

                gt_fields.set_field(results[key], packed_key)

        if gt_fields:
            data_sample.gt_fields = gt_fields.to_tensor()

        img_meta = {k: results[k] for k in self.meta_keys if k in results}
        data_sample.set_metainfo(img_meta)

        packed_results = dict()
        packed_results['inputs'] = inputs_tensor
        packed_results['data_samples'] = data_sample
        packed_results['rgb_gt'] = rgb_gt_tensor

        return packed_results

    def __repr__(self) -> str:
        """Print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(meta_keys={self.meta_keys}, '
        repr_str += f'pack_transformed={self.pack_transformed})'
        return repr_str
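

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the original module. It
    # packs a dummy image together with the custom ``rgb_gt`` field this
    # variant of the transform expects, then inspects the packed outputs.
    pack = PackPoseInputs(meta_keys=('id', 'img_id'))
    results = dict(
        id=0,
        img_id=0,
        img=np.zeros((256, 192, 3), dtype=np.uint8),
        rgb_gt=np.zeros((256, 192, 3), dtype=np.uint8),
        keypoints=np.zeros((1, 17, 2), dtype=np.float32),
        keypoints_visible=np.ones((1, 17), dtype=np.float32),
    )
    packed = pack(results)
    print(packed['inputs'].shape)  # torch.Size([3, 256, 192])
    print(packed['rgb_gt'].shape)  # torch.Size([3, 256, 192])
    print(packed['data_samples'].gt_instances)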