Spaces:
Runtime error
Runtime error
| import pprint | |
| from abc import ABCMeta, abstractmethod | |
| import torch | |
| from itertools import chain | |
| from src.utils.plotting import make_matching_figure, error_colormap | |
| from src.utils.metrics import aggregate_metrics | |
def flatten_list(x):
    """Concatenate an iterable of iterables into a single flat list.

    Only one level of nesting is removed, e.g. ``[[1, 2], [3]] -> [1, 2, 3]``.
    """
    # from_iterable consumes ``x`` lazily instead of unpacking it into call args.
    return list(chain.from_iterable(x))
class Viz(metaclass=ABCMeta):
    """Base helper for drawing keypoint matches and aggregating evaluation stats.

    NOTE(review): this class never assigns ``self.name`` and ``match_and_draw``
    is a no-op here; presumably subclasses provide both — confirm.
    """

    def __init__(self):
        super().__init__()
        # First CUDA device when available, otherwise CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # NOTE: this switches autograd off for the calling thread as a whole,
        # not just for this object.
        torch.set_grad_enabled(False)
        # Accumulators for evaluation metrics (MegaDepth and ScanNet) and timings.
        # Each eval_stats entry is expected to be a dict with a "metrics" key
        # holding a dict of lists (see compute_eval_metrics).
        self.eval_stats = []
        self.time_stats = []

    def draw_matches(self, mkpts0, mkpts1, img0, img1, conf, path=None, **kwargs):
        """Render matched keypoints between two images.

        Args:
            mkpts0, mkpts1: Matched keypoints in ``img0`` / ``img1``;
                ``len(mkpts0)`` is reported as the match count (presumably
                (N, 2) numpy arrays — confirm against callers).
            img0, img1: Images forwarded to ``make_matching_figure``.
            conf: Per-match values colour-mapped by ``error_colormap``.
            path: If truthy, the figure is written there at dpi=150 and
                ``None`` is returned.
            **kwargs:
                conf_thr: Threshold passed to ``error_colormap`` (default 5e-4).
                R_errs, t_errs: Rotation / translation errors shown in the
                    caption in degrees; if ``R_errs`` is given, ``t_errs``
                    must be given too.

        Returns:
            The figure returned by ``make_matching_figure`` when ``path`` is
            falsy, otherwise ``None``.
        """
        thr = kwargs.get("conf_thr", 5e-4)
        color = error_colormap(conf, thr, alpha=0.1)
        # ``name`` is never set in this class; fall back to the class name so a
        # subclass that forgets it still gets a titled figure.
        title = getattr(self, "name", type(self).__name__)
        text = [
            f"{title}",
            f"#Matches: {len(mkpts0)}",
        ]
        if "R_errs" in kwargs:
            text.append(
                f"$\\Delta$R:{kwargs['R_errs']:.2f}°, $\\Delta$t:{kwargs['t_errs']:.2f}°"
            )
        if path:
            make_matching_figure(
                img0, img1, mkpts0, mkpts1, color, text=text, path=path, dpi=150
            )
            return None
        return make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text)

    def match_and_draw(self, data_dict, **kwargs):
        """Match a sample and draw the result; a no-op in this base class.

        NOTE(review): presumably overridden by concrete subclasses.
        """

    def compute_eval_metrics(self, epi_err_thr=5e-4):
        """Aggregate the accumulated ``eval_stats`` and print the summary.

        Every entry of ``self.eval_stats`` must hold a ``"metrics"`` dict of
        lists (numpy). For each metric key, taken from the first entry, the
        lists from all entries are concatenated and the result is handed to
        ``aggregate_metrics``.

        Args:
            epi_err_thr: Threshold forwarded to ``aggregate_metrics``.

        Returns:
            The aggregated metrics (also pretty-printed to stdout), or an
            empty dict when nothing has been accumulated.
        """
        if not self.eval_stats:
            # Nothing to aggregate; indexing the first entry below would fail.
            return {}
        per_sample = [stats["metrics"] for stats in self.eval_stats]
        metrics = {
            key: flatten_list([sample[key] for sample in per_sample])
            for key in per_sample[0]
        }
        val_metrics_4tb = aggregate_metrics(metrics, epi_err_thr)
        print("\n" + pprint.pformat(val_metrics_4tb))
        return val_metrics_4tb

    def measure_time(self):
        """Return the mean of ``self.time_stats``, or 0 when it is empty."""
        if not self.time_stats:
            return 0
        return sum(self.time_stats) / len(self.time_stats)