| import torch | |
| from torch import nn | |
| from hydra.utils import instantiate | |
| from omegaconf import OmegaConf | |
| from huggingface_hub import PyTorchModelHubMixin | |
class Geolocalizer(nn.Module, PyTorchModelHubMixin):
    """Wrap a Hydra-configured model and return its ``"gps"`` output.

    The configuration must provide ``transform`` and ``model`` entries; both
    are built with ``hydra.utils.instantiate``. The instantiated model's
    ``head``, ``mid`` and ``backbone`` attributes are additionally exposed
    directly on this module.
    """

    def __init__(self, config):
        """Instantiate the transform and the model described by ``config``.

        Args:
            config: Any object accepted by ``OmegaConf.create`` that yields a
                config with ``transform`` and ``model`` entries.
        """
        super().__init__()
        cfg = OmegaConf.create(config)
        self.config = cfg
        self.transform = instantiate(cfg.transform)

        network = instantiate(cfg.model)
        self.model = network
        # The stages are re-exposed on this module in addition to living
        # under ``self.model``; the assignment order (head, mid, backbone)
        # is kept as-is because attribute registration order is observable.
        self.head = network.head
        self.mid = network.mid
        self.backbone = network.backbone

    def _gps_from(self, backbone_input):
        """Run backbone -> mid -> head and return the head's ``"gps"`` entry."""
        features = self.backbone(backbone_input)
        features = self.mid(features)
        # The head is called with ``None`` as its second positional argument.
        prediction = self.head(features, None)
        return prediction["gps"]

    def forward(self, img: torch.Tensor):
        """Return the ``"gps"`` output for ``img``.

        ``img`` is wrapped in a dict under the key ``"img"`` before being
        handed to the backbone.
        """
        return self._gps_from({"img": img})

    def forward_tensor(self, img: torch.Tensor):
        """Return the ``"gps"`` output for ``img``, passed to the backbone as-is.

        Unlike ``forward``, the input is not wrapped in a dict.
        NOTE(review): presumably the backbone accepts both a raw tensor and an
        ``{"img": ...}`` dict -- confirm against the backbone implementation.
        """
        return self._gps_from(img)