Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| from typing import Any, List | |
| from torch import nn | |
| from detectron2.config import CfgNode | |
| from detectron2.structures import Instances | |
| from .cycle_pix2shape import PixToShapeCycleLoss | |
| from .cycle_shape2shape import ShapeToShapeCycleLoss | |
| from .embed import EmbeddingLoss | |
| from .embed_utils import CseAnnotationsAccumulator | |
| from .mask_or_segm import MaskOrSegmentationLoss | |
| from .registry import DENSEPOSE_LOSS_REGISTRY | |
| from .soft_embed import SoftEmbeddingLoss | |
| from .utils import BilinearInterpolationHelper, LossDict, extract_packed_annotations_from_matches | |
class DensePoseCseLoss:
    """
    Loss for DensePose continuous surface embeddings (CSE).

    Combines a weighted segmentation loss, weighted per-mesh embedding losses and,
    when enabled in the config, weighted shape-to-shape and pixel-to-shape cycle
    losses into a single dictionary of loss terms.
    """

    # Embedding loss classes selectable through
    # cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME, keyed by class name
    _EMBED_LOSS_REGISTRY = {
        EmbeddingLoss.__name__: EmbeddingLoss,
        SoftEmbeddingLoss.__name__: SoftEmbeddingLoss,
    }

    def __init__(self, cfg: CfgNode):
        """
        Initialize CSE loss from configuration options

        Args:
            cfg (CfgNode): configuration options
        """
        # weight applied to the segmentation term ("loss_densepose_S")
        self.w_segm = cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS
        # weight applied to every per-mesh embedding term ("loss_densepose_E<mesh>")
        self.w_embed = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT
        self.segm_loss = MaskOrSegmentationLoss(cfg)
        self.embed_loss = DensePoseCseLoss.create_embed_loss(cfg)
        # optional cycle losses; their weight / loss attributes exist only when enabled
        self.do_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_shape2shape:
            self.w_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.shape2shape_loss = ShapeToShapeCycleLoss(cfg)
        self.do_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_pix2shape:
            self.w_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.pix2shape_loss = PixToShapeCycleLoss(cfg)

    @classmethod
    def create_embed_loss(cls, cfg: CfgNode):
        """
        Instantiate the embedding loss named by
        cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME.

        Must be a classmethod: ``__init__`` calls it as
        ``DensePoseCseLoss.create_embed_loss(cfg)``, which would otherwise bind
        ``cfg`` to ``cls`` and fail with a missing-argument TypeError.

        Args:
            cfg (CfgNode): configuration options, passed to the loss constructor
        Returns:
            An instance of the selected class from ``_EMBED_LOSS_REGISTRY``.
        Raises:
            KeyError: if the configured name is not in ``_EMBED_LOSS_REGISTRY``.
        """
        # registry not used here, since embedding losses are currently local
        # and are not used anywhere else
        return cls._EMBED_LOSS_REGISTRY[cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME](cfg)

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        embedder: nn.Module,
    ) -> LossDict:
        """
        Compute the weighted CSE losses.

        Args:
            proposals_with_gt (List[Instances]): proposals matched with ground truth
            densepose_predictor_outputs: predictor outputs; must expose an
                ``embedding`` tensor with exactly 4 dims, the last two of which
                are used as the spatial size (h, w)
            embedder (nn.Module): module passed through to the embedding and
                cycle losses
        Returns:
            Dict with "loss_densepose_S", one "loss_densepose_E<mesh>" entry per
            mesh reported by the embedding loss, and "loss_shape2shape" /
            "loss_pix2shape" when those losses are enabled. Falls back to
            ``produce_fake_losses`` when there are no proposals or no packed
            annotations could be extracted.
        """
        if not proposals_with_gt:
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)
        accumulator = CseAnnotationsAccumulator()
        packed_annotations = extract_packed_annotations_from_matches(proposals_with_gt, accumulator)
        if packed_annotations is None:
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)
        h, w = densepose_predictor_outputs.embedding.shape[2:]
        interpolator = BilinearInterpolationHelper.from_matches(
            packed_annotations,
            (h, w),
        )
        meshid_to_embed_losses = self.embed_loss(
            proposals_with_gt,
            densepose_predictor_outputs,
            packed_annotations,
            interpolator,
            embedder,
        )
        embed_loss_dict = {
            f"loss_densepose_E{meshid}": self.w_embed * embed_loss
            for meshid, embed_loss in meshid_to_embed_losses.items()
        }
        all_loss_dict = {
            "loss_densepose_S": self.w_segm
            * self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.w_shape2shape * self.shape2shape_loss(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.w_pix2shape * self.pix2shape_loss(
                proposals_with_gt, densepose_predictor_outputs, packed_annotations, embedder
            )
        return all_loss_dict

    def produce_fake_losses(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> LossDict:
        """
        Build placeholder losses with the same keys as ``__call__`` returns.

        Used when there is no ground truth to compute real losses against.
        Values come from each loss's ``fake_value`` / ``fake_values`` method and
        are not multiplied by the configured weights.

        Args:
            densepose_predictor_outputs: predictor outputs
            embedder (nn.Module): module passed through to the embedding and
                cycle losses
        Returns:
            Dict with the same keys as ``__call__`` produces.
        """
        meshname_to_embed_losses = self.embed_loss.fake_values(
            densepose_predictor_outputs, embedder=embedder
        )
        embed_loss_dict = {
            f"loss_densepose_E{mesh_name}": embed_loss
            for mesh_name, embed_loss in meshname_to_embed_losses.items()
        }
        all_loss_dict = {
            "loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.shape2shape_loss.fake_value(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.pix2shape_loss.fake_value(
                densepose_predictor_outputs, embedder
            )
        return all_loss_dict