Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Dict, List, Optional, Union | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from mmocr.models.textrecog.encoders import BaseEncoder | |
| from mmocr.registry import MODELS | |
| from mmocr.structures import TextSpottingDataSample | |
| class SPTSEncoder(BaseEncoder): | |
| """SPTS Encoder. | |
| Args: | |
| d_backbone (int): Backbone output dimension. | |
| d_model (int): Model output dimension. | |
| init_cfg (dict or list[dict], optional): Initialization configs. | |
| Defaults to None. | |
| """ | |
| def __init__(self, | |
| d_backbone: int = 2048, | |
| d_model: int = 256, | |
| init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None: | |
| super().__init__(init_cfg=init_cfg) | |
| self.input_proj = nn.Conv2d(d_backbone, d_model, kernel_size=1) | |
| def forward(self, | |
| feat: Tensor, | |
| data_samples: List[TextSpottingDataSample] = None) -> Tensor: | |
| """Forward propagation of encoder. | |
| Args: | |
| feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`. | |
| data_samples (list[TextSpottingDataSample]): Batch of | |
| TextSpottingDataSample. | |
| Defaults to None. | |
| Returns: | |
| Tensor: A tensor of shape :math:`(N, T, D_m)`. | |
| """ | |
| return self.input_proj(feat[0]) | |