| import torch | |
| from metrics.utils import haversine, reverse | |
| from torchmetrics import Metric | |
| class HaversineMetrics(Metric): | |
| """ | |
| Computes the average haversine distance between the predicted and ground truth points. | |
| Compute the accuracy given some radiuses. | |
| Compute the Geoguessr score given some radiuses. | |
| Args: | |
| acc_radiuses (list): list of radiuses to compute the accuracy from | |
| acc_area (list): list of areas to compute the accuracy from. | |
| acc_data (list): list of auxilliary data to compute the accuracy from. | |
| """ | |
| def __init__( | |
| self, | |
| acc_radiuses=[], | |
| acc_area=["country", "region", "sub-region", "city"], | |
| aux_data=[], | |
| ): | |
| super().__init__() | |
| self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") | |
| self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") | |
| for acc in acc_radiuses: | |
| self.add_state( | |
| f"close_enough_points_{acc}", | |
| default=torch.tensor(0.0), | |
| dist_reduce_fx="sum", | |
| ) | |
| for acc in acc_area: | |
| self.add_state( | |
| f"close_enough_points_{acc}", | |
| default=torch.tensor(0.0), | |
| dist_reduce_fx="sum", | |
| ) | |
| self.add_state( | |
| f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum" | |
| ) | |
| self.acc_radius = acc_radiuses | |
| self.acc_area = acc_area | |
| self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") | |
| self.aux = len(aux_data) > 0 | |
| self.aux_list = aux_data | |
| if self.aux: | |
| self.aux_count = {} | |
| for col in self.aux_list: | |
| self.add_state( | |
| f"aux_{col}", | |
| default=torch.tensor(0.0), | |
| dist_reduce_fx="sum", | |
| ) | |
| def update(self, pred, gt): | |
| haversine_distance = haversine(pred["gps"], gt["gps"]) | |
| for acc in self.acc_radius: | |
| self.__dict__[f"close_enough_points_{acc}"] += ( | |
| haversine_distance < acc | |
| ).sum() | |
| if len(self.acc_area) > 0: | |
| area_pred, area_gt = reverse(pred["gps"], gt, self.acc_area) | |
| for acc in self.acc_area: | |
| self.__dict__[f"close_enough_points_{acc}"] += ( | |
| area_pred[acc] == area_gt["_".join(["unique", acc])] | |
| ).sum() | |
| self.__dict__[f"count_{acc}"] += len(area_gt["_".join(["unique", acc])]) | |
| self.haversine_sum += haversine_distance.sum() | |
| self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum() | |
| if self.aux: | |
| if "land_cover" in self.aux_list: | |
| col = "land_cover" | |
| self.__dict__[f"aux_{col}"] += ( | |
| pred[col].argmax(dim=1) == gt[col].argmax(dim=1) | |
| ).sum() | |
| if "road_index" in self.aux_list: | |
| col = "road_index" | |
| self.__dict__[f"aux_{col}"] += ( | |
| pred[col].argmax(dim=1) == gt[col].argmax(dim=1) | |
| ).sum() | |
| if "drive_side" in self.aux_list: | |
| col = "drive_side" | |
| self.__dict__[f"aux_{col}"] += ( | |
| (pred[col] > 0.5).float() == gt[col] | |
| ).sum() | |
| if "climate" in self.aux_list: | |
| col = "climate" | |
| self.__dict__[f"aux_{col}"] += ( | |
| pred[col].argmax(dim=1) == gt[col].argmax(dim=1) | |
| ).sum() | |
| if "soil" in self.aux_list: | |
| col = "soil" | |
| self.__dict__[f"aux_{col}"] += ( | |
| pred[col].argmax(dim=1) == gt[col].argmax(dim=1) | |
| ).sum() | |
| if "dist_sea" in self.aux_list: | |
| col = "dist_sea" | |
| self.__dict__[f"aux_{col}"] += ( | |
| (pred[col] - gt[col]).pow(2).sum(dim=1).sum() | |
| ) | |
| self.count += pred["gps"].shape[0] | |
| def compute(self): | |
| output = { | |
| "Haversine": self.haversine_sum / self.count, | |
| "Geoguessr": self.geoguessr_sum / self.count, | |
| } | |
| for acc in self.acc_radius: | |
| output[f"Accuracy_{acc}_km_radius"] = ( | |
| self.__dict__[f"close_enough_points_{acc}"] / self.count | |
| ) | |
| for acc in self.acc_area: | |
| output[f"Accuracy_{acc}"] = ( | |
| self.__dict__[f"close_enough_points_{acc}"] | |
| / self.__dict__[f"count_{acc}"] | |
| ) | |
| if self.aux: | |
| for col in self.aux_list: | |
| output["_".join(["Accuracy", col])] = ( | |
| self.__dict__[f"aux_{col}"] / self.count | |
| ) | |
| return output | |