Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from abc import ABCMeta, abstractmethod | |
| from typing import Tuple | |
| from mmengine.model import BaseModule | |
| from torch import Tensor | |
| from mmocr.utils import DetSampleList | |
class BaseRoIHead(BaseModule, metaclass=ABCMeta):
    """Base class for RoI heads.

    Subclasses must implement :meth:`loss` and :meth:`predict`. They may set
    ``rec_head`` and/or ``roi_extractor`` attributes. The ``with_*``
    properties report whether each component is present, meaning the
    attribute exists and is not ``None``.
    """

    # The ``with_*`` flags are properties so that callers can write
    # ``if self.with_rec_head:``. As plain methods, the bound-method object
    # would always be truthy and the check would silently always pass.
    @property
    def with_rec_head(self) -> bool:
        """bool: Whether the RoI head contains a ``rec_head``."""
        # ``getattr`` with a default treats "attribute missing" and
        # "attribute set to None" the same way, as "not present".
        return getattr(self, 'rec_head', None) is not None

    @property
    def with_extractor(self) -> bool:
        """bool: Whether the RoI head contains a ``roi_extractor``."""
        return getattr(self, 'roi_extractor', None) is not None

    @abstractmethod
    def loss(self, x: Tuple[Tensor], data_samples: DetSampleList):
        """Perform forward propagation and loss calculation of the RoI head
        on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network.
            data_samples (List[:obj:`DetDataSample`]): The data samples.

        Returns:
            Defined by subclasses. This base class does not annotate the
            return type (presumably a dict of losses; confirm against the
            concrete implementations).
        """

    @abstractmethod
    def predict(self, x: Tuple[Tensor],
                data_samples: DetSampleList) -> DetSampleList:
        """Perform forward propagation of the RoI head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network. Each
                has shape (N, C, H, W).
            data_samples (List[:obj:`DetDataSample`]): The data samples.
                It usually includes ``gt_instance``.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of each image.
            Each item usually contains the following keys in
            ``pred_instance``:

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4); the last
              dimension is arranged as (x1, y1, x2, y2).
            - polygon (List[Tensor]): Has a shape (num_instances, H, W).
        """