teemosliang's picture
realease: publish the huggingface space demo of SDPose
d5532b9 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence, Union
import numpy as np
import torch
from mmcv.transforms import BaseTransform
from mmengine.structures import InstanceData, PixelData
from mmengine.utils import is_seq_of
from mmpose.registry import TRANSFORMS
from mmpose.structures import MultilevelPixelData, PoseDataSample
def image_to_tensor(img: Union[np.ndarray,
Sequence[np.ndarray]]) -> torch.torch.Tensor:
"""Translate image or sequence of images to tensor. Multiple image tensors
will be stacked.
Args:
value (np.ndarray | Sequence[np.ndarray]): The original image or
image sequence
Returns:
torch.Tensor: The output tensor.
"""
if isinstance(img, np.ndarray):
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img)
tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()
else:
assert is_seq_of(img, np.ndarray)
tensor = torch.stack([image_to_tensor(_img) for _img in img])
return tensor
def keypoints_to_tensor(keypoints: Union[np.ndarray, Sequence[np.ndarray]]
) -> torch.torch.Tensor:
"""Translate keypoints or sequence of keypoints to tensor. Multiple
keypoints tensors will be stacked.
Args:
keypoints (np.ndarray | Sequence[np.ndarray]): The keypoints or
keypoints sequence.
Returns:
torch.Tensor: The output tensor.
"""
if isinstance(keypoints, np.ndarray):
keypoints = np.ascontiguousarray(keypoints)
N = keypoints.shape[0]
keypoints = keypoints.transpose(1, 2, 0).reshape(-1, N)
tensor = torch.from_numpy(keypoints).contiguous()
else:
assert is_seq_of(keypoints, np.ndarray)
tensor = torch.stack(
[keypoints_to_tensor(_keypoints) for _keypoints in keypoints])
return tensor
@TRANSFORMS.register_module()
class PackPoseInputs(BaseTransform):
"""Pack the inputs data for pose estimation.
The ``img_meta`` item is always populated. The contents of the
``img_meta`` dictionary depends on ``meta_keys``. By default it includes:
- ``id``: id of the data sample
- ``img_id``: id of the image
- ``'category_id'``: the id of the instance category
- ``img_path``: path to the image file
- ``crowd_index`` (optional): measure the crowding level of an image,
defined in CrowdPose dataset
- ``ori_shape``: original shape of the image as a tuple (h, w, c)
- ``img_shape``: shape of the image input to the network as a tuple \
(h, w). Note that images may be zero padded on the \
bottom/right if the batch tensor is larger than this shape.
- ``input_size``: the input size to the network
- ``flip``: a boolean indicating if image flip transform was used
- ``flip_direction``: the flipping direction
- ``flip_indices``: the indices of each keypoint's symmetric keypoint
- ``raw_ann_info`` (optional): raw annotation of the instance(s)
Args:
meta_keys (Sequence[str], optional): Meta keys which will be stored in
:obj: `PoseDataSample` as meta info. Defaults to ``('id',
'img_id', 'img_path', 'category_id', 'crowd_index, 'ori_shape',
'img_shape',, 'input_size', 'input_center', 'input_scale', 'flip',
'flip_direction', 'flip_indices', 'raw_ann_info')``
"""
# items in `instance_mapping_table` will be directly packed into
# PoseDataSample.gt_instances without converting to Tensor
instance_mapping_table = {
'bbox': 'bboxes',
'head_size': 'head_size',
'bbox_center': 'bbox_centers',
'bbox_scale': 'bbox_scales',
'bbox_score': 'bbox_scores',
'keypoints': 'keypoints',
'keypoints_visible': 'keypoints_visible',
'lifting_target': 'lifting_target',
'lifting_target_visible': 'lifting_target_visible',
}
# items in `label_mapping_table` will be packed into
# PoseDataSample.gt_instance_labels and converted to Tensor. These items
# will be used for computing losses
label_mapping_table = {
'keypoint_labels': 'keypoint_labels',
'lifting_target_label': 'lifting_target_label',
'lifting_target_weights': 'lifting_target_weights',
'trajectory_weights': 'trajectory_weights',
'keypoint_x_labels': 'keypoint_x_labels',
'keypoint_y_labels': 'keypoint_y_labels',
'keypoint_weights': 'keypoint_weights',
'instance_coords': 'instance_coords',
'transformed_keypoints_visible': 'keypoints_visible',
}
# items in `field_mapping_table` will be packed into
# PoseDataSample.gt_fields and converted to Tensor. These items will be
# used for computing losses
field_mapping_table = {
'heatmaps': 'heatmaps',
'instance_heatmaps': 'instance_heatmaps',
'heatmap_mask': 'heatmap_mask',
'heatmap_weights': 'heatmap_weights',
'displacements': 'displacements',
'displacement_weights': 'displacement_weights',
}
def __init__(self,
meta_keys=('id', 'img_id', 'img_path', 'category_id',
'crowd_index', 'ori_shape', 'img_shape',
'input_size', 'input_center', 'input_scale',
'flip', 'flip_direction', 'flip_indices',
'raw_ann_info'),
pack_transformed=False):
self.meta_keys = meta_keys
self.pack_transformed = pack_transformed
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
- 'data_samples' (obj:`PoseDataSample`): The annotation info of the
sample.
"""
# Pack image(s) for 2d pose estimation
if 'img' in results:
img = results['img']
inputs_tensor = image_to_tensor(img)
# Pack keypoints for 3d pose-lifting
elif 'lifting_target' in results and 'keypoints' in results:
if 'keypoint_labels' in results:
keypoints = results['keypoint_labels']
else:
keypoints = results['keypoints']
inputs_tensor = keypoints_to_tensor(keypoints)
rgb_gt_tensor = image_to_tensor(results['rgb_gt'])
data_sample = PoseDataSample()
# pack instance data
gt_instances = InstanceData()
for key, packed_key in self.instance_mapping_table.items():
if key in results:
if 'lifting_target' in results and key in {
'keypoints', 'keypoints_visible'
}:
continue
gt_instances.set_field(results[key], packed_key)
# pack `transformed_keypoints` for visualizing data transform
# and augmentation results
if self.pack_transformed and 'transformed_keypoints' in results:
gt_instances.set_field(results['transformed_keypoints'],
'transformed_keypoints')
if self.pack_transformed and \
'transformed_keypoints_visible' in results:
gt_instances.set_field(results['transformed_keypoints_visible'],
'transformed_keypoints_visible')
data_sample.gt_instances = gt_instances
# pack instance labels
gt_instance_labels = InstanceData()
for key, packed_key in self.label_mapping_table.items():
if key in results:
# For pose-lifting, store only target-related fields
if 'lifting_target_label' in results and key in {
'keypoint_labels', 'keypoint_weights',
'transformed_keypoints_visible'
}:
continue
if isinstance(results[key], list):
# A list of labels is usually generated by combined
# multiple encoders (See ``GenerateTarget`` in
# mmpose/datasets/transforms/common_transforms.py)
# In this case, labels in list should have the same
# shape and will be stacked.
_labels = np.stack(results[key])
gt_instance_labels.set_field(_labels, packed_key)
else:
gt_instance_labels.set_field(results[key], packed_key)
data_sample.gt_instance_labels = gt_instance_labels.to_tensor()
# pack fields
gt_fields = None
for key, packed_key in self.field_mapping_table.items():
if key in results:
if isinstance(results[key], list):
if gt_fields is None:
gt_fields = MultilevelPixelData()
else:
assert isinstance(
gt_fields, MultilevelPixelData
), 'Got mixed single-level and multi-level pixel data.'
else:
if gt_fields is None:
gt_fields = PixelData()
else:
assert isinstance(
gt_fields, PixelData
), 'Got mixed single-level and multi-level pixel data.'
gt_fields.set_field(results[key], packed_key)
if gt_fields:
data_sample.gt_fields = gt_fields.to_tensor()
img_meta = {k: results[k] for k in self.meta_keys if k in results}
data_sample.set_metainfo(img_meta)
packed_results = dict()
packed_results['inputs'] = inputs_tensor
packed_results['data_samples'] = data_sample
packed_results['rgb_gt'] = rgb_gt_tensor
return packed_results
def __repr__(self) -> str:
"""print the basic information of the transform.
Returns:
str: Formatted string.
"""
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str