Spaces:
Build error
Build error
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| import torch | |
| from hloc import logger | |
| from hloc.utils.base_model import BaseModel | |
| sys.path.append(str(Path(__file__).parent / "../../third_party")) | |
| from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer | |
| from ASpanFormer.src.config.default import get_cfg_defaults | |
| from ASpanFormer.src.utils.misc import lower_config | |
# Root of the vendored ASpanFormer checkout (relative to this file); the
# pretrained weights and config files are resolved beneath it.
aspanformer_path = Path(__file__).parent.joinpath("../../third_party/ASpanFormer")
class ASpanFormer(BaseModel):
    """hloc matcher wrapping the third-party ASpanFormer network.

    Configuration keys (``default_conf``):
        weights: checkpoint stem; the model is loaded from
            ``aspanformer_path / "weights" / "<weights>.ckpt"``.
        match_threshold: written to ``aspan.match_coarse.thr``.
        sinkhorn_iterations: written to ``aspan.match_coarse.skh_iters``.
        max_keypoints: keep at most this many highest-confidence matches;
            ``None`` disables the cap.
        config_path: config file merged over ``get_cfg_defaults()``.
        model_name: key into ``aspanformer_models``; the archive is downloaded
            to ``aspanformer_path / model_name`` and extracted with ``tar``
            when the checkpoint is not present yet.
    """

    default_conf = {
        "weights": "outdoor",
        "match_threshold": 0.2,
        "sinkhorn_iterations": 20,
        "max_keypoints": 2048,
        "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
        "model_name": "weights_aspanformer.tar",
    }
    required_inputs = ["image0", "image1"]
    # HTTP proxy used only as a fallback when the direct download fails.
    proxy = "http://localhost:1080"
    # Archive file name -> download URL (fetched with gdown).
    aspanformer_models = {
        "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
    }

    def _download_archive(self, link, tar_path):
        """Download ``link`` to ``tar_path`` with gdown, retrying via ``self.proxy``.

        Raises:
            subprocess.CalledProcessError: if both the direct and the proxied
                download fail. Any partially written archive is removed first.
        """
        cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
        cmd = cmd_wo_proxy + ["--proxy", self.proxy]
        logger.info(f"Downloading the Aspanformer model with `{cmd_wo_proxy}`.")
        try:
            subprocess.run(cmd_wo_proxy, check=True)
        except subprocess.CalledProcessError as e:
            logger.info(f"Downloading failed {e}.")
        else:
            return
        logger.info(f"Downloading the Aspanformer model with `{cmd}`.")
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            logger.error(f"Failed to download the Aspanformer model: {e}")
            # Drop any partial archive: `_init` skips the download whenever the
            # archive file exists, so a leftover would block future retries.
            if tar_path.exists():
                tar_path.unlink()
            # Re-raise instead of continuing to `tar`, which would otherwise
            # fail with a less informative error about the missing archive.
            raise

    def _init(self, conf):
        """Fetch the checkpoint if needed, build the network and load weights."""
        model_path = aspanformer_path / "weights" / f"{conf['weights']}.ckpt"

        if not model_path.exists():
            tar_path = aspanformer_path / conf["model_name"]
            if not tar_path.exists():
                link = self.aspanformer_models[conf["model_name"]]
                self._download_archive(link, tar_path)
            cmd = ["tar", "-xvf", str(tar_path), "-C", str(aspanformer_path)]
            logger.info(f"Unzip model file `{cmd}`.")
            subprocess.run(cmd, check=True)

        config = get_cfg_defaults()
        # Pass a plain string so the config loader does not depend on
        # path-like support.
        config.merge_from_file(str(conf["config_path"]))
        aspan_config = lower_config(config)["aspan"]
        # Override the coarse-matching settings from this wrapper's conf.
        aspan_config["match_coarse"]["thr"] = conf["match_threshold"]
        aspan_config["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
        self.net = _ASpanFormer(config=aspan_config)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # obtained from a trusted source.
        state_dict = torch.load(str(model_path), map_location="cpu")["state_dict"]
        incompatible = self.net.load_state_dict(state_dict, strict=False)
        # strict=False tolerates key mismatches; surface them instead of
        # silently running with partially initialised weights.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            logger.warning(
                f"Aspanformer checkpoint key mismatch: "
                f"missing={incompatible.missing_keys}, "
                f"unexpected={incompatible.unexpected_keys}"
            )
        logger.info("Loaded Aspanformer model")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Returns a dict with ``keypoints0``, ``keypoints1`` (matched point
        coordinates) and ``mconf`` (per-match confidence). If
        ``max_keypoints`` is set, only that many highest-confidence matches
        are kept.
        """
        data_ = {"image0": data["image0"], "image1": data["image1"]}
        # The network stores its outputs in `data_` (mkpts0_f, mkpts1_f,
        # mconf); its return value is not used.
        self.net(data_, online_resize=True)
        kpts0 = data_["mkpts0_f"]
        kpts1 = data_["mkpts1_f"]
        scores = data_["mconf"]

        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(scores) > top_k:
            keep = torch.argsort(scores, descending=True)[:top_k]
            kpts0, kpts1, scores = kpts0[keep], kpts1[keep], scores[keep]

        return {"keypoints0": kpts0, "keypoints1": kpts1, "mconf": scores}