| import torch | |
| import torch.nn as nn | |
| import pandas as pd | |
| from models.networks.utils import UnormGPS | |
| class HybridHead(nn.Module): | |
| """Classification head followed by regression head for the network.""" | |
| def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh): | |
| super().__init__() | |
| self.final_dim = final_dim | |
| self.use_tanh = use_tanh | |
| self.scale_tanh = scale_tanh | |
| self.unorm = UnormGPS() | |
| if quadtree_path is not None: | |
| quadtree = pd.read_csv(quadtree_path) | |
| self.init_quadtree(quadtree) | |
| def init_quadtree(self, quadtree): | |
| quadtree[["min_lat", "max_lat"]] /= 90.0 | |
| quadtree[["min_lon", "max_lon"]] /= 180.0 | |
| self.register_buffer( | |
| "cell_center", | |
| 0.5 * torch.tensor(quadtree[["max_lat", "max_lon"]].values) | |
| + 0.5 * torch.tensor(quadtree[["min_lat", "min_lon"]].values), | |
| ) | |
| self.register_buffer( | |
| "cell_size", | |
| torch.tensor(quadtree[["max_lat", "max_lon"]].values) | |
| - torch.tensor(quadtree[["min_lat", "min_lon"]].values), | |
| ) | |
| def forward(self, x, gt_label): | |
| """Forward pass of the network. | |
| x : Union[torch.Tensor, dict] with the output of the backbone. | |
| """ | |
| classification_logits = x[..., : self.final_dim] | |
| classification = classification_logits.argmax(dim=-1) | |
| regression = x[..., self.final_dim :] | |
| if self.use_tanh: | |
| regression = self.scale_tanh * torch.tanh(regression) | |
| regression = regression.view(regression.shape[0], -1, 2) | |
| if self.training: | |
| regression = torch.gather( | |
| regression, | |
| 1, | |
| gt_label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2), | |
| )[:, 0, :] | |
| size = 2.0 / self.cell_size[gt_label] | |
| center = self.cell_center[gt_label] | |
| gps = ( | |
| self.cell_center[gt_label] + regression * self.cell_size[gt_label] / 2.0 | |
| ) | |
| else: | |
| regression = torch.gather( | |
| regression, | |
| 1, | |
| classification.unsqueeze(-1) | |
| .unsqueeze(-1) | |
| .expand(regression.shape[0], 1, 2), | |
| )[:, 0, :] | |
| size = 2.0 / self.cell_size[classification] | |
| center = self.cell_center[classification] | |
| gps = ( | |
| self.cell_center[classification] | |
| + regression * self.cell_size[classification] / 2.0 | |
| ) | |
| gps = self.unorm(gps) | |
| return { | |
| "label": classification_logits, | |
| "gps": gps, | |
| "size": size, | |
| "center": center, | |
| "reg": regression, | |
| } | |
| class HybridHeadCentroid(nn.Module): | |
| """Classification head followed by regression head for the network.""" | |
| def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh): | |
| super().__init__() | |
| self.final_dim = final_dim | |
| self.use_tanh = use_tanh | |
| self.scale_tanh = scale_tanh | |
| self.unorm = UnormGPS() | |
| if quadtree_path is not None: | |
| quadtree = pd.read_csv(quadtree_path) | |
| self.init_quadtree(quadtree) | |
| def init_quadtree(self, quadtree): | |
| quadtree[["min_lat", "max_lat", "mean_lat"]] /= 90.0 | |
| quadtree[["min_lon", "max_lon", "mean_lon"]] /= 180.0 | |
| self.cell_center = torch.tensor(quadtree[["mean_lat", "mean_lon"]].values) | |
| self.cell_size_up = torch.tensor(quadtree[["max_lat", "max_lon"]].values) - torch.tensor(quadtree[["mean_lat", "mean_lon"]].values) | |
| self.cell_size_down = torch.tensor(quadtree[["mean_lat", "mean_lon"]].values) - torch.tensor(quadtree[["min_lat", "min_lon"]].values) | |
| def forward(self, x, gt_label): | |
| """Forward pass of the network. | |
| x : Union[torch.Tensor, dict] with the output of the backbone. | |
| """ | |
| classification_logits = x[..., : self.final_dim] | |
| classification = classification_logits.argmax(dim=-1) | |
| self.cell_size_up = self.cell_size_up.to(classification.device) | |
| self.cell_center = self.cell_center.to(classification.device) | |
| self.cell_size_down = self.cell_size_down.to(classification.device) | |
| regression = x[..., self.final_dim :] | |
| if self.use_tanh: | |
| regression = self.scale_tanh * torch.tanh(regression) | |
| regression = regression.view(regression.shape[0], -1, 2) | |
| if self.training: | |
| regression = torch.gather( | |
| regression, | |
| 1, | |
| gt_label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2), | |
| )[:, 0, :] | |
| size = torch.where( | |
| regression > 0, | |
| self.cell_size_up[gt_label], | |
| self.cell_size_down[gt_label], | |
| ) | |
| center = self.cell_center[gt_label] | |
| gps = self.cell_center[gt_label] + regression * size | |
| else: | |
| regression = torch.gather( | |
| regression, | |
| 1, | |
| classification.unsqueeze(-1) | |
| .unsqueeze(-1) | |
| .expand(regression.shape[0], 1, 2), | |
| )[:, 0, :] | |
| size = torch.where( | |
| regression > 0, | |
| self.cell_size_up[classification], | |
| self.cell_size_down[classification], | |
| ) | |
| center = self.cell_center[classification] | |
| gps = self.cell_center[classification] + regression * size | |
| gps = self.unorm(gps) | |
| return { | |
| "label": classification_logits, | |
| "gps": gps, | |
| "size": 1.0 / size, | |
| "center": center, | |
| "reg": regression, | |
| } | |
| class SharedHybridHead(HybridHead): | |
| """Classification head followed by SHARED regression head for the network.""" | |
| def forward(self, x, gt_label): | |
| """Forward pass of the network. | |
| x : Union[torch.Tensor, dict] with the output of the backbone. | |
| """ | |
| classification_logits = x[..., : self.final_dim] | |
| classification = classification_logits.argmax(dim=-1) | |
| regression = x[..., self.final_dim :] | |
| if self.use_tanh: | |
| regression = self.scale_tanh * torch.tanh(regression) | |
| if self.training: | |
| gps = ( | |
| self.cell_center[gt_label] + regression * self.cell_size[gt_label] / 2.0 | |
| ) | |
| else: | |
| gps = ( | |
| self.cell_center[classification] | |
| + regression * self.cell_size[classification] / 2.0 | |
| ) | |
| gps = self.unorm(gps) | |
| return {"label": classification_logits, "gps": gps} | |