Spaces:
Runtime error
Runtime error
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from .. import logger | |
| from ..utils.base_model import BaseModel | |
| thirdparty_path = Path(__file__).parent / "../../third_party" | |
| sys.path.append(str(thirdparty_path)) | |
| from omniglue.src import omniglue | |
| omniglue_path = thirdparty_path / "omniglue" | |
| class OmniGlue(BaseModel): | |
| default_conf = { | |
| "match_threshold": 0.02, | |
| "max_keypoints": 2048, | |
| } | |
| required_inputs = ["image0", "image1"] | |
| dino_v2_link_dict = { | |
| "dinov2_vitb14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" | |
| } | |
| def _init(self, conf): | |
| logger.info("Loading OmniGlue model") | |
| og_model_path = omniglue_path / "models" / "omniglue.onnx" | |
| sp_model_path = omniglue_path / "models" / "sp_v6.onnx" | |
| dino_model_path = ( | |
| omniglue_path / "models" / "dinov2_vitb14_pretrain.pth" # ~330MB | |
| ) | |
| if not dino_model_path.exists(): | |
| link = self.dino_v2_link_dict.get(dino_model_path.name, None) | |
| if link is not None: | |
| cmd = ["wget", "--quiet", link, "-O", str(dino_model_path)] | |
| logger.info(f"Downloading the dinov2 model with `{cmd}`.") | |
| subprocess.run(cmd, check=True) | |
| else: | |
| logger.error(f"Invalid dinov2 model: {dino_model_path.name}") | |
| self.net = omniglue.OmniGlue( | |
| og_export=str(og_model_path), | |
| sp_export=str(sp_model_path), | |
| dino_export=str(dino_model_path), | |
| max_keypoints=self.conf["max_keypoints"], | |
| ) | |
| logger.info("Loaded OmniGlue model done!") | |
| def _forward(self, data): | |
| image0_rgb_np = data["image0"][0].permute(1, 2, 0).cpu().numpy() * 255 | |
| image1_rgb_np = data["image1"][0].permute(1, 2, 0).cpu().numpy() * 255 | |
| image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255 | |
| image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255 | |
| match_kp0, match_kp1, match_confidences = self.net.FindMatches( | |
| image0_rgb_np, image1_rgb_np, self.conf["max_keypoints"] | |
| ) | |
| # filter matches | |
| match_threshold = self.conf["match_threshold"] | |
| keep_idx = [] | |
| for i in range(match_kp0.shape[0]): | |
| if match_confidences[i] > match_threshold: | |
| keep_idx.append(i) | |
| scores = torch.from_numpy(match_confidences[keep_idx]).reshape(-1, 1) | |
| pred = { | |
| "keypoints0": torch.from_numpy(match_kp0[keep_idx]), | |
| "keypoints1": torch.from_numpy(match_kp1[keep_idx]), | |
| "mconf": scores, | |
| } | |
| top_k = self.conf["max_keypoints"] | |
| if top_k is not None and len(scores) > top_k: | |
| keep = torch.argsort(scores, descending=True)[:top_k] | |
| scores = scores[keep] | |
| pred["keypoints0"], pred["keypoints1"], pred["mconf"] = ( | |
| pred["keypoints0"][keep], | |
| pred["keypoints1"][keep], | |
| scores, | |
| ) | |
| return pred | |