| 
							 | 
						 | 
					
					
						
						| 
							 | 
						import copy | 
					
					
						
						| 
							 | 
						import numpy as np | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						from fvcore.transforms import HFlipTransform, TransformList | 
					
					
						
						| 
							 | 
						from torch.nn import functional as F | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from detectron2.data.transforms import RandomRotation, RotationTransform, apply_transform_gens | 
					
					
						
						| 
							 | 
						from detectron2.modeling.postprocessing import detector_postprocess | 
					
					
						
						| 
							 | 
						from detectron2.modeling.test_time_augmentation import DatasetMapperTTA, GeneralizedRCNNWithTTA | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from ..converters import HFlipConverter | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class DensePoseDatasetMapperTTA(DatasetMapperTTA):
						    """
						    TTA dataset mapper that, on top of the augmentations produced by
						    `DatasetMapperTTA`, emits one rotated copy of the image per configured angle.
						    """

						    def __init__(self, cfg):
						        """
						        Args:
						            cfg (CfgNode): config; `cfg.TEST.AUG.ROTATION_ANGLES` provides the
						                angles iterated over in `__call__`.
						        """
						        super().__init__(cfg=cfg)
						        self.angles = cfg.TEST.AUG.ROTATION_ANGLES

						    def __call__(self, dataset_dict):
						        """
						        Args:
						            dataset_dict (dict): dataset dict whose "image" field is a 3-dim tensor
						                (presumably CHW, since it is permuted with (1, 2, 0) before rotating).

						        Returns:
						            list[dict]: the dicts returned by `DatasetMapperTTA.__call__`, followed by
						                one deep copy of `dataset_dict` per angle with its "transforms" and
						                "image" entries replaced.
						        """
						        augmented = super().__call__(dataset_dict=dataset_dict)
						        hwc_image = dataset_dict["image"].permute(1, 2, 0).numpy()
						        for angle in self.angles:
						            # Rotate a copy so the array backing the input tensor is never modified.
						            rotated_hwc, rotation_tfms = apply_transform_gens(
						                [RandomRotation(angle=angle, expand=True)], np.copy(hwc_image)
						            )
						            rotated_chw = np.ascontiguousarray(rotated_hwc.transpose(2, 0, 1))
						            # The first transform of the most recently appended entry is kept in
						            # front of the rotation (presumably the parent's pre-transform such as
						            # a resize -- TODO confirm against DatasetMapperTTA).
						            leading_tfm = augmented[-1]["transforms"].transforms[0]
						            entry = copy.deepcopy(dataset_dict)
						            entry["transforms"] = TransformList([leading_tfm] + rotation_tfms.transforms)
						            entry["image"] = torch.from_numpy(rotated_chw)
						            augmented.append(entry)
						        return augmented
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class DensePoseGeneralizedRCNNWithTTA(GeneralizedRCNNWithTTA):
						    """
						    `GeneralizedRCNNWithTTA` variant that also merges DensePose predictions
						    across the augmented inputs.
						    """

						    def __init__(self, cfg, model, transform_data, tta_mapper=None, batch_size=1):
						        """
						        Args:
						            cfg (CfgNode):
						            model (GeneralizedRCNN): a GeneralizedRCNN to apply TTA on.
						            transform_data (DensePoseTransformData): symmetry label transforms
						                used when undoing a horizontal flip; moved to `model.device`.
						            tta_mapper (callable): maps a dataset dict to a list of augmented
						                versions of it. Forwarded unchanged to the parent class, which
						                defaults it to `DatasetMapperTTA(cfg)`.
						            batch_size (int): number of augmented images batched together for inference.
						        """
						        # Stored before the parent initializer runs, same as the statement order
						        # this class has always used.
						        self._transform_data = transform_data.to(model.device)
						        super().__init__(cfg=cfg, model=model, tta_mapper=tta_mapper, batch_size=batch_size)

						    def _inference_one_image(self, input):
						        """
						        Args:
						            input (dict): one dataset dict with "image" field being a CHW tensor

						        Returns:
						            dict: one output dict, with the merged predictions under "instances"
						        """
						        orig_shape = (input["height"], input["width"])
						        # The image is cast to uint8 before the augmented inputs are generated.
						        input["image"] = input["image"].to(torch.uint8)
						        augmented_inputs, tfms = self._get_augmented_inputs(input)

						        # Boxes are collected inside the parent's `_turn_off_roi_heads` context
						        # for the mask, keypoint and densepose heads.
						        with self._turn_off_roi_heads(["mask_on", "keypoint_on", "densepose_on"]):
						            boxes, scores, classes = self._get_augmented_boxes(augmented_inputs, tfms)
						        merged_instances = self._merge_detections(boxes, scores, classes, orig_shape)

						        model_cfg = self.cfg.MODEL
						        if not (model_cfg.MASK_ON or model_cfg.DENSEPOSE_ON):
						            return {"instances": merged_instances}

						        # Map the merged boxes into each augmented image, then run the model
						        # again with those boxes supplied as detected instances.
						        augmented_instances = self._rescale_detected_boxes(
						            augmented_inputs, merged_instances, tfms
						        )
						        outputs = self._batch_inference(augmented_inputs, augmented_instances)
						        # Drop the references that are no longer needed before reducing.
						        del augmented_inputs, augmented_instances

						        if model_cfg.MASK_ON:
						            merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms)
						        if model_cfg.DENSEPOSE_ON:
						            merged_instances.pred_densepose = self._reduce_pred_densepose(outputs, tfms)

						        merged_instances = detector_postprocess(merged_instances, *orig_shape)
						        return {"instances": merged_instances}

						    def _get_augmented_boxes(self, augmented_inputs, tfms):
						        """
						        Run inference on every augmented input and gather the predictions of the
						        non-rotated ones, with boxes mapped back through the inverse transform.

						        Args:
						            augmented_inputs (list[dict]): augmented dataset dicts
						            tfms (list): one transform list per augmented input

						        Returns:
						            tuple: (all_boxes, all_scores, all_classes); the boxes are concatenated
						                into one tensor, scores and classes are flat Python lists.
						        """
						        outputs = self._batch_inference(augmented_inputs)
						        box_chunks, all_scores, all_classes = [], [], []
						        for output, tfm in zip(outputs, tfms):
						            # Outputs of inputs whose transform list contains a rotation are skipped.
						            if any(isinstance(t, RotationTransform) for t in tfm.transforms):
						                continue
						            predicted = output.pred_boxes.tensor
						            restored = tfm.inverse().apply_box(predicted.cpu().numpy())
						            box_chunks.append(torch.from_numpy(restored).to(predicted.device))
						            all_scores.extend(output.scores)
						            all_classes.extend(output.pred_classes)
						        all_boxes = torch.cat(box_chunks, dim=0)
						        return all_boxes, all_scores, all_classes

						    def _reduce_pred_densepose(self, outputs, tfms):
						        """
						        Undo rotations and horizontal flips on each output's DensePose prediction
						        and fold all of them into a running average held by `outputs[0]`.

						        Args:
						            outputs (list): per-augmentation outputs exposing `pred_densepose`
						                and `pred_boxes`
						            tfms (list): one transform list per output

						        Returns:
						            the averaged `pred_densepose` object of `outputs[0]`
						        """
						        dp_fields = ("coarse_segm", "fine_segm", "u", "v")
						        for idx, (output, tfm) in enumerate(zip(outputs, tfms)):
						            for single_tfm in tfm.transforms:
						                # `_inverse_rotation` returns its input untouched for anything
						                # that is not a RotationTransform.
						                for field in dp_fields:
						                    restored = _inverse_rotation(
						                        getattr(output.pred_densepose, field),
						                        output.pred_boxes.tensor,
						                        single_tfm,
						                    )
						                    setattr(output.pred_densepose, field, restored)
						            flipped = any(isinstance(t, HFlipTransform) for t in tfm.transforms)
						            if flipped:
						                output.pred_densepose = HFlipConverter.convert(
						                    output.pred_densepose, self._transform_data
						                )
						            # `outputs[0].pred_densepose` is looked up on every pass because the
						            # flip conversion above may have rebound it.
						            self._incremental_avg_dp(outputs[0].pred_densepose, output.pred_densepose, idx)
						        return outputs[0].pred_densepose

						    def _incremental_avg_dp(self, avg, new_el, idx):
						        """
						        Update the running mean `avg` in place with `new_el`, the `idx`-th element.

						        Args:
						            avg: object holding the running mean of the DensePose fields
						            new_el: object holding the fields of the element being added
						            idx (int): zero-based index of `new_el` in the averaged sequence

						        Returns:
						            `avg`, after the update
						        """
						        for field in ("coarse_segm", "fine_segm", "u", "v"):
						            updated = (getattr(avg, field) * idx + getattr(new_el, field)) / (idx + 1)
						            setattr(avg, field, updated)
						            if idx:
						                # Any element past the first has its field cleared once it has
						                # been folded into the average.
						                setattr(new_el, field, None)
						        return avg
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def _inverse_rotation(densepose_attrs, boxes, transform): | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    if len(boxes) == 0 or not isinstance(transform, RotationTransform): | 
					
					
						
						| 
							 | 
						        return densepose_attrs | 
					
					
						
						| 
							 | 
						    boxes = boxes.int().cpu().numpy() | 
					
					
						
						| 
							 | 
						    wh_boxes = boxes[:, 2:] - boxes[:, :2]   | 
					
					
						
						| 
							 | 
						    inv_boxes = rotate_box_inverse(transform, boxes).astype(int)   | 
					
					
						
						| 
							 | 
						    wh_diff = (inv_boxes[:, 2:] - inv_boxes[:, :2] - wh_boxes) // 2   | 
					
					
						
						| 
							 | 
						    rotation_matrix = torch.tensor([transform.rm_image]).to(device=densepose_attrs.device).float() | 
					
					
						
						| 
							 | 
						    rotation_matrix[:, :, -1] = 0 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    l_bds = np.maximum(0, -wh_diff) | 
					
					
						
						| 
							 | 
						    for i in range(len(densepose_attrs)): | 
					
					
						
						| 
							 | 
						        if min(wh_boxes[i]) <= 0: | 
					
					
						
						| 
							 | 
						            continue | 
					
					
						
						| 
							 | 
						        densepose_attr = densepose_attrs[[i]].clone() | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        densepose_attr = F.interpolate(densepose_attr, wh_boxes[i].tolist()[::-1], mode="bilinear") | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        densepose_attr = F.pad(densepose_attr, tuple(np.repeat(np.maximum(0, wh_diff[i]), 2))) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        grid = F.affine_grid(rotation_matrix, size=densepose_attr.shape) | 
					
					
						
						| 
							 | 
						        densepose_attr = F.grid_sample(densepose_attr, grid) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        r_bds = densepose_attr.shape[2:][::-1] - l_bds[i] | 
					
					
						
						| 
							 | 
						        densepose_attr = densepose_attr[:, :, l_bds[i][1] : r_bds[1], l_bds[i][0] : r_bds[0]] | 
					
					
						
						| 
							 | 
						        if min(densepose_attr.shape) > 0: | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            densepose_attr = F.interpolate( | 
					
					
						
						| 
							 | 
						                densepose_attr, densepose_attrs.shape[-2:], mode="bilinear" | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            densepose_attr[:, 0] += 1e-10 | 
					
					
						
						| 
							 | 
						            densepose_attrs[i] = densepose_attr | 
					
					
						
						| 
							 | 
						    return densepose_attrs | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def rotate_box_inverse(rot_tfm, rotated_box):
						    """
						    Invert a rotation on boxes and shrink them back to their original size.

						    rotated_box is a N * 4 array of [x0, y0, x1, y1] boxes.
						    When a bbox is rotated, it gets bigger, because the axis-aligned box has to
						    surround the tilted one. So a bbox that is rotated and then inverse-rotated
						    is much bigger than the original. This function inverts the rotation on the
						    box and also resizes it to its original size, keeping its center.

						    Args:
						        rot_tfm: rotation transform exposing `abs_sin`, `abs_cos` and
						            `inverse().apply_box(...)`.
						        rotated_box (ndarray): N * 4 array of boxes in the rotated image.

						    Returns:
						        ndarray: the array returned by `rot_tfm.inverse().apply_box`, adjusted
						            in place to the recovered width and height.

						    Raises:
						        AssertionError: if the rotation is (numerically) 45 degrees off an axis,
						            where the size equations cannot be solved.
						    """
						    abs_sin, abs_cos = rot_tfm.abs_sin, rot_tfm.abs_cos
						    # The rotated size satisfies w = w0 * cos + h0 * sin and h = h0 * cos + w0 * sin.
						    # Solving for (w0, h0) divides by cos^2 - sin^2 = 1 - 2 * sin^2, which
						    # vanishes at 45 degrees. An exact `!= 1` comparison misses that case
						    # because of floating point rounding, so compare with a tolerance. Raised
						    # explicitly (same type and message as the former assert) so the check
						    # also runs under `python -O`.
						    denominator = 1 - 2 * abs_sin**2
						    if abs(denominator) < 1e-9:
						        raise AssertionError("45 degrees angle can't be inverted")

						    invrot_box = rot_tfm.inverse().apply_box(rotated_box)
						    h, w = rotated_box[:, 3] - rotated_box[:, 1], rotated_box[:, 2] - rotated_box[:, 0]
						    ih, iw = invrot_box[:, 3] - invrot_box[:, 1], invrot_box[:, 2] - invrot_box[:, 0]

						    # Height and width of the box before it was rotated.
						    orig_h = (h * abs_cos - w * abs_sin) / denominator
						    orig_w = (w * abs_cos - h * abs_sin) / denominator

						    # Shrink the inverse-rotated box symmetrically around its center.
						    half_dw = (iw - orig_w) / 2
						    half_dh = (ih - orig_h) / 2
						    invrot_box[:, 0] += half_dw
						    invrot_box[:, 1] += half_dh
						    invrot_box[:, 2] -= half_dw
						    invrot_box[:, 3] -= half_dh

						    return invrot_box
					
					
						
						| 
							 | 
						
 |