Spaces:
Runtime error
Runtime error
| import torch.nn.functional as F | |
| from omegaconf import OmegaConf | |
| from .. import get_model | |
| from ..base_model import BaseModel | |
to_ctr = OmegaConf.to_container  # convert an OmegaConf DictConfig sub-config into a plain dict before passing it to get_model(...) constructors
class MixedExtractor(BaseModel):
    """Combine a keypoint detector and a descriptor model into one extractor.

    Either sub-model may be left unconfigured (``name: None``). Its outputs
    (``keypoints`` for the detector, ``descriptors`` for the descriptor) are
    then declared as required entries of ``data["cache"]``. Optionally,
    ``pred["descriptors"]`` is replaced by values bilinearly sampled at the
    keypoints from the prediction field named by
    ``interpolate_descriptors_from``.
    """

    default_conf = {
        "detector": {"name": None},
        "descriptor": {"name": None},
        "interpolate_descriptors_from": None,  # field name
    }

    required_data_keys = ["image"]
    required_cache_keys = []

    def _init(self, conf):
        """Instantiate the configured sub-models and record cache requirements.

        Args:
            conf: merged config of this model; ``conf.detector`` and
                ``conf.descriptor`` are the sub-model configs.
        """
        cache_keys = []  # outputs that must be supplied through data["cache"]
        if conf.detector.name:
            self.detector = get_model(conf.detector.name)(to_ctr(conf.detector))
        else:
            cache_keys.append("keypoints")
        if conf.descriptor.name:
            self.descriptor = get_model(conf.descriptor.name)(
                to_ctr(conf.descriptor)
            )
        else:
            cache_keys.append("descriptors")

        if cache_keys:
            # Rebind fresh per-instance lists instead of using `+=`. An
            # in-place extend would mutate the class-level lists shared by
            # every instance, leaking one extractor's cache requirements into
            # all others. "cache" is added once even if both models are missing.
            data_keys = list(self.required_data_keys)
            if "cache" not in data_keys:
                data_keys.append("cache")
            self.required_data_keys = data_keys
            self.required_cache_keys = [*self.required_cache_keys, *cache_keys]

    def _forward(self, data):
        """Run the configured sub-models, falling back to cached outputs.

        Args:
            data: input dict. Reads ``data["image"]`` (only its last two
                dims, as height/width, for interpolation) and ``data["cache"]``
                when no detector is configured.

        Returns:
            dict of predictions. When ``interpolate_descriptors_from`` is set,
            ``pred["descriptors"]`` is (re)computed by sampling that field at
            ``pred["keypoints"]``.
        """
        if self.conf.detector.name:
            pred = self.detector(data)
        else:
            # Shallow copy: the interpolation below assigns into `pred`, which
            # must not silently modify the caller's cache dict.
            pred = {**data["cache"]}
        # Gate on the *descriptor* config: `self.descriptor` only exists when a
        # descriptor model was configured in `_init`.
        # NOTE(review): with a detector but no descriptor model, pred holds only
        # what the detector returns; cached descriptors are not merged in.
        if self.conf.descriptor.name:
            pred = {**pred, **self.descriptor({**pred, **data})}

        field = self.conf.interpolate_descriptors_from
        if field:
            h, w = data["image"].shape[-2:]
            # presumably (B, N, 2) pixel coordinates in (x, y) order — confirm
            kpts = pred["keypoints"]
            # Scale into grid_sample's [-1, 1] range; [:, None] inserts a
            # singleton dim so a (B, N, 2) input becomes a (B, 1, N, 2) grid.
            pts = (kpts / kpts.new_tensor([[w, h]]) * 2 - 1)[:, None]
            sampled = F.grid_sample(
                pred[field],
                pts,
                align_corners=False,
                mode="bilinear",
            )
            # assuming a 4-D dense field: (B, C, 1, N) -> (B, N, C)
            pred["descriptors"] = sampled.squeeze(-2).transpose(-2, -1).contiguous()

        return pred

    def loss(self, pred, data):
        """Sum the losses of the configured sub-models.

        A sub-model contributes only if it is configured and its optional
        ``apply_loss`` flag is truthy (absent counts as True). Sub-models whose
        ``loss`` raises NotImplementedError are skipped.

        Returns:
            (losses, metrics): the merged per-sub-model dicts, where
            ``losses["total"]`` is the sum of each contributing ``"total"``.
        """
        losses = {}
        metrics = {}
        total = 0
        for k in ["detector", "descriptor"]:
            sub_conf = self.conf[k]
            apply = sub_conf.apply_loss if "apply_loss" in sub_conf.keys() else True
            if not (sub_conf.name and apply):
                continue
            try:
                losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
            except NotImplementedError:
                continue  # this sub-model defines no training loss
            losses = {**losses, **losses_}
            metrics = {**metrics, **metrics_}
            total = losses_["total"] + total
        return {**losses, "total": total}, metrics