Spaces:
Build error
Build error
| import sys | |
| from pathlib import Path | |
| import torch | |
| from ..utils.base_model import BaseModel | |
| sys.path.append(str(Path(__file__).parent / "../../third_party")) | |
| from TopicFM.src import get_model_cfg | |
| from TopicFM.src.models.topic_fm import TopicFM as _TopicFM | |
# Location of the vendored TopicFM checkout, expressed relative to this
# module's directory (the "../.." components are kept, not resolved).
topicfm_path = Path(__file__).parent.joinpath("../../third_party/TopicFM")
class TopicFM(BaseModel):
    """Expose the third-party TopicFM matcher through the BaseModel interface.

    Configuration keys (``default_conf``):
        weights: not read by this class; the checkpoint is always loaded from
            ``pretrained/model_best.ckpt`` under ``topicfm_path``.
        match_threshold: written into the TopicFM config as
            ``match_coarse.thr``.
        n_sampling_topics: written into the TopicFM config as
            ``coarse.n_samples``.
        max_keypoints: maximum number of matches returned; when the limit
            applies, the matches with the highest ``mconf`` are kept. A
            negative value (the default, -1) or ``None`` disables the limit.
    """

    default_conf = {
        "weights": "outdoor",
        "match_threshold": 0.2,
        "n_sampling_topics": 4,
        "max_keypoints": -1,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the TopicFM network and load its pretrained checkpoint.

        Args:
            conf: configuration dict; ``match_threshold`` and
                ``n_sampling_topics`` are read from it.
        """
        _conf = dict(get_model_cfg())
        # NOTE(review): dict() is only a shallow copy, so the two nested writes
        # below modify the sub-config objects returned by get_model_cfg().
        # Presumably it builds a fresh config on each call; confirm.
        _conf["match_coarse"]["thr"] = conf["match_threshold"]
        _conf["coarse"]["n_samples"] = conf["n_sampling_topics"]
        weight_path = topicfm_path / "pretrained/model_best.ckpt"
        self.net = _TopicFM(config=_conf)
        # torch.load unpickles the file; only point this at trusted checkpoints.
        ckpt_dict = torch.load(weight_path, map_location="cpu")
        self.net.load_state_dict(ckpt_dict["state_dict"])

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Args:
            data: dict holding at least ``"image0"`` and ``"image1"``.

        Returns:
            dict with ``"keypoints0"``, ``"keypoints1"`` (matched keypoints in
            each image) and ``"mconf"`` (per-match confidence). When the
            ``max_keypoints`` limit applies, only the top-scoring matches are
            kept, ordered by descending confidence; otherwise the network's
            order is preserved.
        """
        data_ = {
            "image0": data["image0"],
            "image1": data["image1"],
        }
        # The call's return value is unused: the outputs read below
        # (mkpts0_f, mkpts1_f, mconf) are expected to be added to data_.
        self.net(data_)

        keypoints0 = data_["mkpts0_f"]
        keypoints1 = data_["mkpts1_f"]
        scores = data_["mconf"]

        top_k = self.conf["max_keypoints"]
        # A negative limit means "keep everything". Without the `0 <=` guard,
        # the default of -1 would make `[:top_k]` drop the lowest-scoring match.
        if top_k is not None and 0 <= top_k < len(scores):
            keep = torch.argsort(scores, descending=True)[:top_k]
            keypoints0 = keypoints0[keep]
            keypoints1 = keypoints1[keep]
            scores = scores[keep]

        return {
            "keypoints0": keypoints0,
            "keypoints1": keypoints1,
            "mconf": scores,
        }