# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, List, Optional, Tuple, Union

import cv2
import mmcv
import numpy as np
from matplotlib import pyplot as plt
from mmengine.dist import master_only
from mmengine.structures import InstanceData

from mmpose.registry import VISUALIZERS
from mmpose.structures import PoseDataSample

from . import PoseLocalVisualizer


def _broadcast_colors(colors, num: int, color_name: str, target_name: str):
    """Expand a color specification into one color per item.

    ``None`` or a single color name is repeated ``num`` times. Any other
    value must already contain exactly ``num`` entries and is returned
    unchanged.

    Args:
        colors: ``None``, a color name, or a sequence of per-item colors.
        num (int): Number of items that need a color.
        color_name (str): Name used in the error message (e.g.
            ``'kpt_color'``).
        target_name (str): Name of the items used in the error message
            (e.g. ``'keypoints'``).

    Returns:
        A sequence of ``num`` colors.

    Raises:
        ValueError: If ``colors`` is a sequence whose length is not ``num``.
    """
    if colors is None or isinstance(colors, str):
        return [colors] * num
    if len(colors) != num:
        raise ValueError(
            f'the length of {color_name} '
            f'({len(colors)}) does not matches '
            f'that of {target_name} ({num})')
    return colors


def _to_mpl_color(color):
    """Convert one visualizer color into a color matplotlib accepts.

    ``None`` and color names are passed through unchanged. Numeric colors
    get their channel order reversed and are scaled from [0, 255] to
    [0, 1]. This is the same conversion the 3D drawers have always applied
    before handing colors to matplotlib.
    """
    if color is None or isinstance(color, str):
        return color
    return tuple(np.asarray(color, dtype=float)[::-1] / 255.)


def _hide_ticks(ax) -> None:
    """Remove all tick marks and tick labels from a 3D axes."""
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])


def _set_cube_limits(ax, points: np.ndarray,
                     fallback_size: Tuple[float, float, float]) -> None:
    """Fit equal-length x/y/z limits around ``points`` of shape ``(M, 3)``.

    The cube is centered on the bounding box of ``points``, and its side
    is the largest bounding-box extent. If ``points`` is empty, the limits
    fall back to ``x: [-sx/2, sx/2]``, ``y: [-sy/2, sy/2]`` and
    ``z: [0, sz]``, where ``(sx, sy, sz) = fallback_size``. If all points
    coincide (zero extent), the side falls back to ``max(fallback_size)``.
    """
    size_x, size_y, size_z = fallback_size
    if len(points) == 0:
        ax.set_xlim(-size_x / 2, size_x / 2)
        ax.set_ylim(-size_y / 2, size_y / 2)
        ax.set_zlim(0, size_z)
        return

    low = points.min(axis=0)
    high = points.max(axis=0)
    mid = (high + low) * 0.5
    side = (high - low).max()
    if side <= 0:
        # Identical low/high limits are rejected/expanded by matplotlib.
        side = max(fallback_size)
    half = side * 0.5
    ax.set_xlim(mid[0] - half, mid[0] + half)
    ax.set_ylim(mid[1] - half, mid[1] + half)
    ax.set_zlim(mid[2] - half, mid[2] + half)


def _render_figure_rgb(fig, height: int, width: int) -> np.ndarray:
    """Rasterize ``fig`` into a ``(height, width, 3)`` uint8 RGB array.

    The returned array is a copy, so it stays valid after the figure is
    closed. If the rendered canvas is entirely black (or empty), a white
    image is returned instead. This mirrors the fallback of the former
    ``tostring_rgb`` code path.
    """
    fig.tight_layout()
    fig.canvas.draw()
    # ``buffer_rgba`` replaces the deprecated ``tostring_rgb``; drop alpha.
    rgb = np.array(
        np.asarray(fig.canvas.buffer_rgba())[..., :3], dtype=np.uint8)
    if not rgb.any():
        return np.full((height, width, 3), 255, dtype=np.uint8)
    if rgb.shape[:2] != (height, width):
        # The canvas can be off by a pixel (float rounding of figsize * dpi)
        # or scaled by a device pixel ratio; resize so callers can safely
        # concatenate this image with other panels.
        rgb = cv2.resize(rgb, (width, height))
    return rgb


@VISUALIZERS.register_module()
class Pose3dLocalVisualizer(PoseLocalVisualizer):
    """Visualizer for 3D pose results with an optional 2D detection panel.

    All arguments except the ``det_*`` ones are forwarded unchanged to
    :class:`PoseLocalVisualizer`.

    Args:
        det_kpt_color: If not ``None``, used instead of ``kpt_color`` when
            drawing 2D detection keypoints.
        det_dataset_skeleton: If not ``None``, used instead of ``skeleton``
            when drawing 2D detection links.
        det_dataset_link_color: If not ``None``, used instead of
            ``link_color`` when drawing 2D detection links.
    """

    def __init__(
            self,
            name: str = 'visualizer',
            image: Optional[np.ndarray] = None,
            vis_backends: Optional[Dict] = None,
            save_dir: Optional[str] = None,
            bbox_color: Optional[Union[str, Tuple[int]]] = 'green',
            kpt_color: Optional[Union[str, Tuple[Tuple[int]]]] = 'red',
            link_color: Optional[Union[str, Tuple[Tuple[int]]]] = None,
            text_color: Optional[Union[str, Tuple[int]]] = (255, 255, 255),
            skeleton: Optional[Union[List, Tuple]] = None,
            line_width: Union[int, float] = 1,
            radius: Union[int, float] = 3,
            show_keypoint_weight: bool = False,
            backend: str = 'opencv',
            alpha: float = 0.8,
            det_kpt_color: Optional[Union[str, Tuple[Tuple[int]]]] = None,
            det_dataset_skeleton: Optional[Union[str,
                                                 Tuple[Tuple[int]]]] = None,
            det_dataset_link_color: Optional[np.ndarray] = None):
        super().__init__(name, image, vis_backends, save_dir, bbox_color,
                         kpt_color, link_color, text_color, skeleton,
                         line_width, radius, show_keypoint_weight, backend,
                         alpha)
        self.det_kpt_color = det_kpt_color
        self.det_dataset_skeleton = det_dataset_skeleton
        self.det_dataset_link_color = det_dataset_link_color

    def _draw_3d_keypoints(
        self,
        image: np.ndarray,
        keypoints_3d: np.ndarray,
        keypoints_3d_scores: Optional[np.ndarray],
        kpt_thr: float = 0.3,
        num_instances=-1,
        axis_elev: float = 15.0,
        axis_azimuth: float = 70,
        axis_roll: float = 0,
        x_axis_limit: float = 1.7,
        y_axis_limit: float = 1.7,
        z_axis_limit: float = 1.7,
        axis_dist: float = 10.0,
        radius: int = 10,
        thickness: int = 10,
    ) -> np.ndarray:
        """Render the 3D keypoints of every instance into one RGB image.

        Each instance gets its own 3D subplot, laid out left to right. The
        limits of each subplot form an equal-sided cube fitted around that
        instance's valid keypoints.

        Args:
            image (np.ndarray): Only its height and width are used; they
                set the pixel size of one subplot.
            keypoints_3d (np.ndarray): Keypoints indexed as
                ``[instance, keypoint, coord]``. The first three coords are
                used as x, y, z.
            keypoints_3d_scores (np.ndarray, optional): Per-keypoint
                scores indexed as ``[instance, keypoint]``. If ``None``,
                every keypoint gets a score of 1.
            kpt_thr (float): A keypoint is drawn only if its score is
                ``>= kpt_thr`` and its x, y, z are all finite. A link is
                drawn only if both of its endpoints are drawn.
            num_instances (int): Not used; the number of subplots always
                equals ``len(keypoints_3d)``. Kept for compatibility.
            axis_elev (float): Elevation passed to ``Axes3D.view_init``.
            axis_azimuth (float): Azimuth passed to ``Axes3D.view_init``.
            axis_roll (float): Roll passed to ``Axes3D.view_init``.
            x_axis_limit (float): Fallback x range, used when an instance
                has no valid keypoint (``[-x/2, x/2]``).
            y_axis_limit (float): Fallback y range (``[-y/2, y/2]``).
            z_axis_limit (float): Fallback z range (``[0, z]``).
            axis_dist (float): Value assigned to ``Axes3D.dist``.
            radius (int): Marker size passed to ``scatter`` as ``s``.
            thickness (int): Line width of the skeleton links.

        Returns:
            np.ndarray: A uint8 RGB image of shape
            ``(height, width * len(keypoints_3d), 3)``. If there are no
            instances, a white image of shape ``(height, width, 3)``.
        """
        vis_height, vis_width, _ = image.shape
        num_instances = len(keypoints_3d)
        if num_instances == 0:
            return np.full((vis_height, vis_width, 3), 255, dtype=np.uint8)

        if keypoints_3d_scores is not None:
            scores = keypoints_3d_scores
        else:
            scores = np.ones(np.shape(keypoints_3d)[:-1])

        kpt_colors = [
            _to_mpl_color(c) for c in _broadcast_colors(
                self.kpt_color, len(keypoints_3d[0]), 'kpt_color',
                'keypoints')
        ]
        has_color = np.array([c is not None for c in kpt_colors], dtype=bool)

        draw_links = self.skeleton is not None and self.link_color is not None
        if draw_links:
            link_colors = [
                _to_mpl_color(c) for c in _broadcast_colors(
                    self.link_color, len(self.skeleton), 'link_color',
                    'skeleton')
            ]

        plt.ioff()
        # dpi is pinned so that figsize * dpi equals the intended pixels.
        fig = plt.figure(
            figsize=(vis_width * num_instances * 0.01, vis_height * 0.01),
            dpi=100)
        try:
            for idx, (kpts, score) in enumerate(zip(keypoints_3d, scores)):
                xyz = np.asarray(kpts)[..., :3]
                score = np.asarray(score)
                # NaN/inf coordinates must be excluded: they cannot be drawn
                # and would poison the axis limits below.
                valid = np.logical_and(score >= kpt_thr,
                                       np.isfinite(xyz).all(axis=-1))

                ax = fig.add_subplot(
                    1, num_instances, idx + 1, projection='3d')
                ax.view_init(
                    elev=axis_elev, azim=axis_azimuth, roll=axis_roll)
                _hide_ticks(ax)
                ax.dist = axis_dist
                _set_cube_limits(ax, xyz[valid],
                                 (x_axis_limit, y_axis_limit, z_axis_limit))
                # Must follow the limits: 'equal' is derived from them.
                ax.set_aspect('equal')

                shown_ids = np.flatnonzero(valid & has_color)
                if len(shown_ids):
                    ax.scatter(
                        xyz[shown_ids, 0],
                        xyz[shown_ids, 1],
                        xyz[shown_ids, 2],
                        marker='o',
                        color=[kpt_colors[i] for i in shown_ids],
                        s=radius)

                if draw_links:
                    for sk_id, sk in enumerate(self.skeleton):
                        sk_indices = list(sk)
                        color = link_colors[sk_id]
                        if color is None or not valid[sk_indices].all():
                            continue
                        segment = xyz[sk_indices]
                        ax.plot(
                            segment[:, 0],
                            segment[:, 1],
                            segment[:, 2],
                            color=color,
                            zdir='z',
                            linewidth=thickness)

            return _render_figure_rgb(fig, vis_height,
                                      vis_width * num_instances)
        finally:
            plt.close(fig)

    def _draw_3d_data_samples(
        self,
        image: np.ndarray,
        pose_samples: PoseDataSample,
        draw_gt: bool = True,
        kpt_thr: float = 0.3,
        num_instances=-1,
        axis_azimuth: float = 70,
        axis_limit: float = 1.7,
        axis_dist: float = 10.0,
        axis_elev: float = 15.0,
    ) -> np.ndarray:
        """Draw 3D keypoints and skeletons of the prediction (and GT).

        One 3D subplot is drawn per instance. All prediction subplots come
        first. When ``draw_gt`` is set, an equal number of ground-truth
        subplots follows them, in the same instance order.

        Args:
            image (np.ndarray): Only its height and width are used; they
                set the pixel size of one subplot.
            pose_samples (:obj:`PoseDataSample`): Sample whose
                ``pred_instances.keypoints`` (and optionally
                ``keypoint_scores``) and, for GT,
                ``gt_instances.lifting_target`` are drawn.
            draw_gt (bool): Whether to add ground-truth subplots.
                Defaults to ``True``.
            kpt_thr (float, optional): A keypoint is drawn only if its
                score is ``>= kpt_thr`` and its x, y, z are all finite.
                A link is drawn only if both endpoints are drawn.
                Default: 0.3.
            num_instances (int): Number of instances to show. If smaller
                than 0, all predicted instances are shown. Otherwise, the
                predictions are truncated to ``num_instances`` and missing
                instances leave empty (padding) subplots.
            axis_azimuth (float): Azimuth angle of the 3D view.
            axis_limit (float): Extent of each axis. The xyz ranges are:

                - x: [x_c - axis_limit/2, x_c + axis_limit/2]
                - y: [y_c - axis_limit/2, y_c + axis_limit/2]
                - z: [0, axis_limit]

                Here x_c and y_c are the means of the valid keypoints'
                x and y coordinates (0 if none is valid).
            axis_dist (float): Value assigned to ``Axes3D.dist``.
            axis_elev (float): Elevation angle of the 3D view.

        Returns:
            np.ndarray: A uint8 RGB image of height ``image.shape[0]`` and
            width ``image.shape[1] * num_instances`` (doubled when
            ``draw_gt``). If ``num_instances`` is 0, a white image of
            width ``image.shape[1]`` (doubled when ``draw_gt``).
        """
        vis_height, vis_width, _ = image.shape

        if 'pred_instances' in pose_samples:
            pred_instances = pose_samples.pred_instances
        else:
            pred_instances = InstanceData()

        if num_instances < 0:
            if 'keypoints' in pred_instances:
                num_instances = len(pred_instances)
            else:
                num_instances = 0
        elif len(pred_instances) > num_instances:
            # Truncate every field to the first ``num_instances`` entries.
            truncated = InstanceData()
            for key in pred_instances.keys():
                truncated.set_field(pred_instances[key][:num_instances], key)
            pred_instances = truncated

        num_fig = num_instances
        if draw_gt:
            vis_width *= 2
            num_fig *= 2

        if num_instances == 0:
            return np.full((vis_height, vis_width, 3), 255, dtype=np.uint8)

        plt.ioff()
        # dpi is pinned so that figsize * dpi equals the intended pixels.
        fig = plt.figure(
            figsize=(vis_width * num_instances * 0.01, vis_height * 0.01),
            dpi=100)

        def _draw_group(keypoints, scores, group, title=None):
            """Draw one subplot per instance; ``group`` 0 = pred, 1 = GT."""
            offset = group * num_instances
            for idx, (kpts, score) in enumerate(zip(keypoints, scores)):
                if idx >= num_instances:
                    break
                kpts = np.asarray(kpts)
                score = np.asarray(score)
                xyz = kpts[..., :3]
                valid = np.logical_and(score >= kpt_thr,
                                       np.isfinite(xyz).all(axis=-1))

                ax = fig.add_subplot(
                    1, num_fig, offset + idx + 1, projection='3d')
                ax.view_init(elev=axis_elev, azim=axis_azimuth)
                ax.set_zlim3d([0, axis_limit])
                ax.set_aspect('auto')
                _hide_ticks(ax)
                # Marker at the coordinate origin.
                ax.scatter([0], [0], [0], marker='o', color='red')
                if title:
                    ax.set_title(f'{title} ({idx})')
                ax.dist = axis_dist
                x_c = np.mean(xyz[valid, 0]) if valid.any() else 0
                y_c = np.mean(xyz[valid, 1]) if valid.any() else 0
                ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2])
                ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2])

                kpt_colors = [
                    _to_mpl_color(c) for c in _broadcast_colors(
                        self.kpt_color, len(kpts), 'kpt_color', 'keypoints')
                ]
                has_color = np.array([c is not None for c in kpt_colors],
                                     dtype=bool)
                shown_ids = np.flatnonzero(valid & has_color)
                if len(shown_ids):
                    ax.scatter(
                        xyz[shown_ids, 0],
                        xyz[shown_ids, 1],
                        xyz[shown_ids, 2],
                        marker='o',
                        color=[kpt_colors[i] for i in shown_ids])
                    # Label each point with its keypoint id.
                    for kid in shown_ids:
                        ax.text(xyz[kid, 0], xyz[kid, 1], xyz[kid, 2],
                                str(kid))

                if self.skeleton is not None and self.link_color is not None:
                    link_colors = _broadcast_colors(self.link_color,
                                                    len(self.skeleton),
                                                    'link_color', 'skeleton')
                    for sk_id, sk in enumerate(self.skeleton):
                        sk_indices = list(sk)
                        color = _to_mpl_color(link_colors[sk_id])
                        # Index the unfiltered array so ids match skeleton.
                        if color is None or not valid[sk_indices].all():
                            continue
                        segment = xyz[sk_indices]
                        ax.plot(
                            segment[:, 0],
                            segment[:, 1],
                            segment[:, 2],
                            color=color,
                            zdir='z')

        try:
            if 'keypoints' in pred_instances:
                keypoints = pred_instances.keypoints
                if 'keypoint_scores' in pred_instances:
                    scores = pred_instances.keypoint_scores
                else:
                    scores = np.ones(np.shape(keypoints)[:-1])
                _draw_group(keypoints, scores, 0, 'Prediction')

            if draw_gt and 'gt_instances' in pose_samples:
                gt_instances = pose_samples.gt_instances
                if 'lifting_target' in gt_instances:
                    keypoints = gt_instances.lifting_target
                    scores = np.ones(np.shape(keypoints)[:-1])
                    _draw_group(keypoints, scores, 1, 'Ground Truth')

            return _render_figure_rgb(fig, vis_height,
                                      vis_width * num_instances)
        finally:
            plt.close(fig)

    def _draw_instances_kpts(self,
                             image: np.ndarray,
                             instances: InstanceData,
                             kpt_thr: float = 0.3,
                             show_kpt_idx: bool = False,
                             skeleton_style: str = 'mmpose'):
        """Draw 2D keypoints and skeletons (optional) of GT or prediction.

        ``det_kpt_color``, ``det_dataset_skeleton`` and
        ``det_dataset_link_color`` take precedence over ``kpt_color``,
        ``skeleton`` and ``link_color`` when they are not ``None``.

        Args:
            image (np.ndarray): The image to draw.
            instances (:obj:`InstanceData`): Data structure for
                instance-level annotations or predictions.
                ``transformed_keypoints`` are drawn if present, otherwise
                ``keypoints``.
            kpt_thr (float, optional): Minimum threshold of keypoints
                to be shown. Default: 0.3.
            show_kpt_idx (bool): Whether to show the index of keypoints.
                Defaults to ``False``
            skeleton_style (str): Skeleton style selection. Defaults to
                ``'mmpose'``

        Returns:
            np.ndarray: the drawn image which channel is RGB.
        """
        self.set_image(image)
        img_h, img_w, _ = image.shape

        if 'keypoints' in instances:
            keypoints = instances.get('transformed_keypoints',
                                      instances.keypoints)

            if 'keypoint_scores' in instances:
                scores = instances.keypoint_scores
            else:
                scores = np.ones(keypoints.shape[:-1])

            if 'keypoints_visible' in instances:
                keypoints_visible = instances.keypoints_visible
            else:
                keypoints_visible = np.ones(keypoints.shape[:-1])

            if skeleton_style == 'openpose':
                keypoints_info = np.concatenate(
                    (keypoints, scores[..., None], keypoints_visible[...,
                                                                     None]),
                    axis=-1)
                # compute neck joint
                neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
                # neck score when visualizing pred
                neck[:, 2:4] = np.logical_and(
                    keypoints_info[:, 5, 2:4] > kpt_thr,
                    keypoints_info[:, 6, 2:4] > kpt_thr).astype(int)
                new_keypoints_info = np.insert(
                    keypoints_info, 17, neck, axis=1)

                mmpose_idx = [
                    17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
                ]
                openpose_idx = [
                    1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
                ]
                new_keypoints_info[:, openpose_idx] = \
                    new_keypoints_info[:, mmpose_idx]
                keypoints_info = new_keypoints_info

                keypoints, scores, keypoints_visible = keypoints_info[
                    ..., :2], keypoints_info[..., 2], keypoints_info[..., 3]

            # Resolve the palette and skeleton once for all instances.
            base_kpt_color = self.kpt_color
            if self.det_kpt_color is not None:
                base_kpt_color = self.det_kpt_color
            skeleton = self.skeleton
            if self.det_dataset_skeleton is not None:
                skeleton = self.det_dataset_skeleton
            base_link_color = self.link_color
            if self.det_dataset_link_color is not None:
                base_link_color = self.det_dataset_link_color

            for kpts, score, visible in zip(keypoints, scores,
                                            keypoints_visible):
                kpts = np.asarray(kpts)
                kpt_color = _broadcast_colors(base_kpt_color, len(kpts),
                                              'kpt_color', 'keypoints')

                # draw each point on image
                for kid, kpt in enumerate(kpts):
                    if score[kid] < kpt_thr or not visible[
                            kid] or kpt_color[kid] is None:
                        # skip the point that should not be drawn
                        continue

                    color = kpt_color[kid]
                    if not isinstance(color, str):
                        color = tuple(int(c) for c in color)
                    transparency = self.alpha
                    if self.show_keypoint_weight:
                        transparency *= max(0, min(1, score[kid]))
                    self.draw_circles(
                        kpt,
                        radius=np.array([self.radius]),
                        face_colors=color,
                        edge_colors=color,
                        alpha=transparency,
                        line_widths=self.radius)
                    if show_kpt_idx:
                        self.draw_texts(
                            str(kid),
                            kpt,
                            colors=color,
                            font_sizes=self.radius * 3,
                            vertical_alignments='bottom',
                            horizontal_alignments='center')

                # draw links
                if skeleton is None or base_link_color is None:
                    continue
                link_color = _broadcast_colors(base_link_color,
                                               len(skeleton), 'link_color',
                                               'skeleton')

                for sk_id, sk in enumerate(skeleton):
                    # Check visibility/score/color before converting the
                    # endpoints with int(): dropped keypoints may hold NaN.
                    if (not (visible[sk[0]] and visible[sk[1]])
                            or score[sk[0]] < kpt_thr
                            or score[sk[1]] < kpt_thr
                            or link_color[sk_id] is None):
                        # skip the link that should not be drawn
                        continue
                    pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                    pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
                    if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0
                            or pos1[1] >= img_h or pos2[0] <= 0
                            or pos2[0] >= img_w or pos2[1] <= 0
                            or pos2[1] >= img_h):
                        # skip links with an endpoint outside the image
                        continue

                    X = np.array((pos1[0], pos2[0]))
                    Y = np.array((pos1[1], pos2[1]))
                    color = link_color[sk_id]
                    if not isinstance(color, str):
                        color = tuple(int(c) for c in color)
                    transparency = self.alpha
                    if self.show_keypoint_weight:
                        transparency *= max(
                            0, min(1, 0.5 * (score[sk[0]] + score[sk[1]])))

                    if skeleton_style == 'openpose':
                        mX = np.mean(X)
                        mY = np.mean(Y)
                        length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5
                        angle = math.degrees(
                            math.atan2(Y[0] - Y[1], X[0] - X[1]))
                        stickwidth = 2
                        polygons = cv2.ellipse2Poly(
                            (int(mX), int(mY)),
                            (int(length / 2), int(stickwidth)), int(angle),
                            0, 360, 1)

                        self.draw_polygons(
                            polygons,
                            edge_colors=color,
                            face_colors=color,
                            alpha=transparency)

                    else:
                        self.draw_lines(
                            X, Y, color, line_widths=self.line_width)

        return self.get_image()

    @master_only
    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       data_sample: PoseDataSample,
                       det_data_sample: Optional[PoseDataSample] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       draw_2d: bool = True,
                       draw_bbox: bool = False,
                       show_kpt_idx: bool = False,
                       skeleton_style: str = 'mmpose',
                       num_instances: int = -1,
                       show: bool = False,
                       wait_time: float = 0,
                       out_file: Optional[str] = None,
                       kpt_thr: float = 0.3,
                       step: int = 0) -> np.ndarray:
        """Draw datasample and save to all backends.

        - The output is the 2D panel (when ``draw_2d``) on the left,
          followed by the 3D prediction subplots and, when ``draw_gt``,
          the 3D ground-truth subplots.
        - If ``show`` is True, all storage backends are ignored, and
          the images will be displayed in a local window.
        - If ``out_file`` is specified, the drawn image will be
          saved to ``out_file``. It is usually used when the display
          is not available.

        Args:
            name (str): The image identifier
            image (np.ndarray): The image to draw
            data_sample (:obj:`PoseDataSample`): The 3d data sample
                to visualize
            det_data_sample (:obj:`PoseDataSample`, optional): The 2d
                detection data sample to visualize. If ``None``, the 2D
                panel shows the undecorated image.
            draw_gt (bool): Whether to draw GT PoseDataSample. Default to
                ``True``
            draw_pred (bool): Currently not used; the 3D prediction is
                always drawn. Defaults to ``True``
            draw_2d (bool): Whether to draw 2d detection results. Defaults to
                ``True``
            draw_bbox (bool): Whether to draw bounding boxes. Default to
                ``False``
            show_kpt_idx (bool): Whether to show the index of keypoints.
                Defaults to ``False``
            skeleton_style (str): Skeleton style selection. Defaults to
                ``'mmpose'``
            num_instances (int): Number of instances to be shown in 3D. If
                smaller than 0, all the instances in the pose_result will be
                shown. Otherwise, pad or truncate the pose_result to a length
                of num_instances. Defaults to -1
            show (bool): Whether to display the drawn image. Default to
                ``False``
            wait_time (float): The interval of show (s). Defaults to 0
            out_file (str): Path to output file. Defaults to ``None``
            kpt_thr (float, optional): Minimum threshold of keypoints
                to be shown, applied to both the 2D and the 3D panels.
                Default: 0.3.
            step (int): Global step value to record. Defaults to 0

        Returns:
            np.ndarray: The drawn RGB image.
        """
        det_img_data = None
        if draw_2d:
            det_img_data = image.copy()

            # draw bboxes & keypoints
            if (det_data_sample is not None
                    and 'pred_instances' in det_data_sample):
                det_img_data = self._draw_instances_kpts(
                    det_img_data, det_data_sample.pred_instances, kpt_thr,
                    show_kpt_idx, skeleton_style)
                if draw_bbox:
                    det_img_data = self._draw_instances_bbox(
                        det_img_data, det_data_sample.pred_instances)

        # Only the image size is read, so no copy is needed.
        pred_img_data = self._draw_3d_data_samples(
            image,
            data_sample,
            draw_gt=draw_gt,
            kpt_thr=kpt_thr,
            num_instances=num_instances)

        # merge visualization results (GT is part of the 3D panel)
        if det_img_data is None:
            drawn_img = pred_img_data
        else:
            drawn_img = np.concatenate((det_img_data, pred_img_data), axis=1)

        # It is convenient for users to obtain the drawn image.
        # For example, the user wants to obtain the drawn image and
        # save it as a video during video inference.
        self.set_image(drawn_img)

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
        else:
            # save drawn_img to backends
            self.add_image(name, drawn_img, step)

        return self.get_image()