Spaces:
Runtime error
Runtime error
"""
A two-view sparse feature matching pipeline.
This model contains sub-models for each step:
feature extraction, feature matching, outlier filtering, pose estimation.
Each step is optional, and the features or matches can be provided as input.
Default: no sub-model is configured (every component's `name` is None), so
each step must be enabled explicitly in the configuration, e.g. SuperPoint
with nearest neighbor matching.
Convention for the matches: m0[i] is the index of the keypoint in image 1
that corresponds to keypoint i in image 0; m0[i] = -1 if i is unmatched.
"""
| from omegaconf import OmegaConf | |
| from . import get_model | |
| from .base_model import BaseModel | |
# Shorthand for OmegaConf.to_container. TwoViewPipeline._init uses it to turn
# each component's DictConfig sub-tree into a plain Python container before
# passing it to that component's model constructor.
to_ctr = OmegaConf.to_container  # convert DictConfig to dict
class TwoViewPipeline(BaseModel):
    """Chain optional sub-models over a pair of views ("view0" and "view1").

    Components are listed in ``components``. Each one is instantiated only when
    its ``name`` is set in the config. In ``_forward`` the extractor runs per
    view, then matcher, filter and solver each receive ``{**data, **pred}``,
    and their outputs are merged into ``pred``. Ground truth runs either in
    ``_forward`` or in ``loss`` depending on ``run_gt_in_forward``.
    """

    default_conf = {
        "extractor": {
            "name": None,
            "trainable": False,
        },
        "matcher": {"name": None},
        "filter": {"name": None},
        "solver": {"name": None},
        "ground_truth": {"name": None},
        # When True, a view whose "cache" dict is non-empty skips the extractor
        # and the cached predictions are used as-is.
        "allow_no_extract": False,
        # When True, ground truth is computed in _forward; otherwise in loss().
        "run_gt_in_forward": False,
    }
    required_data_keys = ["view0", "view1"]
    strict_conf = False  # need to pass new confs to children models

    # Order matters: components are instantiated in this order in _init and
    # their losses/metrics are merged in this order in loss().
    components = [
        "extractor",
        "matcher",
        "filter",
        "solver",
        "ground_truth",
    ]

    def _init(self, conf):
        """Instantiate every component whose ``name`` is set in ``conf``.

        Each component becomes an attribute named after it (``self.extractor``,
        ``self.matcher``, ...). It is built by the model class returned by
        ``get_model(name)``, using its config sub-tree as a plain container.
        """
        for component in self.components:
            sub_conf = conf[component]
            if sub_conf.name:
                model = get_model(sub_conf.name)(to_ctr(sub_conf))
                setattr(self, component, model)  # same as self.<component> = model

    def extract_view(self, data, i):
        """Return the per-view predictions for view ``i`` (``"0"`` or ``"1"``).

        The result starts from ``data[f"view{i}"]["cache"]`` (an empty dict if
        absent). If an extractor is configured, it is run on the view data and
        its outputs override cached entries with the same key. The exception is
        when ``allow_no_extract`` is True and the cache is non-empty: then the
        cache is returned unchanged and the extractor is not called.
        """
        data_i = data[f"view{i}"]
        pred_i = data_i.get("cache", {})
        skip_extract = len(pred_i) > 0 and self.conf.allow_no_extract
        if self.conf.extractor.name and not skip_extract:
            # The extractor only sees the raw view data; cached entries are
            # not forwarded to it, they are merged underneath its outputs.
            pred_i = {**pred_i, **self.extractor(data_i)}
        return pred_i

    def _forward(self, data):
        """Run the configured components on ``data`` and return predictions.

        A per-view key ``k`` is renamed to ``k + "0"`` / ``k + "1"``. Matcher,
        filter and solver (when configured) each receive ``{**data, **pred}``,
        and their outputs are merged into ``pred``. If ``run_gt_in_forward`` is
        set, the ground-truth outputs are added with a ``"gt_"`` prefix.
        """
        pred0 = self.extract_view(data, "0")
        pred1 = self.extract_view(data, "1")
        pred = {
            **{k + "0": v for k, v in pred0.items()},
            **{k + "1": v for k, v in pred1.items()},
        }
        if self.conf.matcher.name:
            pred = {**pred, **self.matcher({**data, **pred})}
        if self.conf.filter.name:
            pred = {**pred, **self.filter({**data, **pred})}
        if self.conf.solver.name:
            pred = {**pred, **self.solver({**data, **pred})}
        if self.conf.ground_truth.name and self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})
        return pred

    def loss(self, pred, data):
        """Accumulate the losses and metrics of every applicable component.

        If ground truth is configured but was not computed in ``_forward``, it
        is computed here. Its outputs are written into ``pred`` in place, with
        a ``"gt_"`` prefix.

        A component is skipped in any of these cases:
        - its ``name`` is unset;
        - its config sets ``apply_loss`` to a falsy value;
        - its ``loss`` raises NotImplementedError.

        Returns:
            ``(losses, metrics)``. Both are merged across components, with
            later components overriding earlier keys. ``losses["total"]`` is the
            sum of each contributing component's ``"total"``, or 0 if none
            contributed.
        """
        losses = {}
        metrics = {}
        total = 0
        # get labels
        if self.conf.ground_truth.name and not self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})
        for k in self.components:
            comp_conf = self.conf[k]
            apply = True
            if "apply_loss" in comp_conf.keys():
                apply = comp_conf.apply_loss
            if not (comp_conf.name and apply):
                continue
            try:
                losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
            except NotImplementedError:
                # This component deliberately defines no loss.
                continue
            losses.update(losses_)
            metrics.update(metrics_)
            total = losses_["total"] + total
        return {**losses, "total": total}, metrics