# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math

import torch


class BoxCoder(object):
    """
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.
    """

    def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
        """
        Arguments:
            weights (4-element tuple)
            bbox_xform_clip (float)
        """
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, reference_boxes, proposals):
        """
        Encode a set of proposals with respect to some
        reference boxes

        Arguments:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded
        """

        TO_REMOVE = 1  # TODO remove
        ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE
        ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE
        ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths
        ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights

        gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE
        gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE
        gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths
        gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights

        wx, wy, ww, wh = self.weights
        targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
        targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
        targets_dw = ww * torch.log(gt_widths / ex_widths)
        targets_dh = wh * torch.log(gt_heights / ex_heights)

        targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
        return targets

    def decode(self, rel_codes, boxes):
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Arguments:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.
        """

        boxes = boxes.to(rel_codes.dtype)

        TO_REMOVE = 1  # TODO remove
        widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
        heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.weights
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        dw = rel_codes[:, 2::4] / ww
        dh = rel_codes[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        pred_boxes = torch.zeros_like(rel_codes)
        # x1
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
        # y1
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
        # x2 (note: "- 1" is correct; don't be fooled by the asymmetry)
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1
        # y2 (note: "- 1" is correct; don't be fooled by the asymmetry)
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1

        return pred_boxes
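

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): a round trip through
# encode() and decode(). The weights (10., 10., 5., 5.) and the box values
# below are illustrative assumptions, not values taken from this repository.
if __name__ == "__main__":
    coder = BoxCoder(weights=(10., 10., 5., 5.))

    # Proposals and their matched ground-truth boxes, as (x1, y1, x2, y2).
    proposals = torch.tensor([[10., 10., 60., 80.],
                              [100., 120., 180., 200.]])
    gt_boxes = torch.tensor([[12., 8., 64., 85.],
                             [96., 118., 176., 204.]])

    # Regression targets for training: one (dx, dy, dw, dh) row per proposal.
    targets = coder.encode(gt_boxes, proposals)

    # Decoding those targets against the same proposals recovers the
    # ground-truth boxes, since encode() and decode() share the same
    # "+ TO_REMOVE" box-size convention.
    decoded = coder.decode(targets, proposals)
    print(targets)
    print(decoded)  # matches gt_boxes up to floating-point error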