update: omniglue
Files changed:
- common/utils.py +12 -7
- hloc/match_dense.py +1 -0
- hloc/matchers/duster.py +1 -2
- hloc/matchers/omniglue.py +1 -3
- hloc/utils/viz.py +5 -3
- third_party/omniglue/src/omniglue/omniglue_extract.py +8 -3
common/utils.py
CHANGED
@@ -642,7 +642,7 @@ def run_matching(
     ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
     choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
     matcher_zoo: Dict[str, Any] = None,
-    use_cached_model: bool =
+    use_cached_model: bool = False,
 ) -> Tuple[
     np.ndarray,
     np.ndarray,
@@ -696,19 +696,21 @@ def run_matching(
         f"Success! Please be patient and allow for about 2-3 minutes."
         f" Due to CPU inference, {key} is quiet slow."
     )
+    t0 = time.time()
     model = matcher_zoo[key]
     match_conf = model["matcher"]
     # update match config
     match_conf["model"]["match_threshold"] = match_threshold
     match_conf["model"]["max_keypoints"] = extract_max_keypoints
-    t0 = time.time()
     cache_key = "{}_{}".format(key, match_conf["model"]["name"])
-    matcher = model_cache.cache_model(cache_key, get_model, match_conf)
     if use_cached_model:
+        # because of the model cache, we need to update the config
+        matcher = model_cache.cache_model(cache_key, get_model, match_conf)
        matcher.conf["max_keypoints"] = extract_max_keypoints
        matcher.conf["match_threshold"] = match_threshold
        logger.info(f"Loaded cached model {cache_key}")
-
+    else:
+        matcher = get_model(match_conf)
    logger.info(f"Loading model using: {time.time()-t0:.3f}s")
    t1 = time.time()
 
@@ -725,13 +727,16 @@ def run_matching(
     extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
     cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
 
-    extractor = model_cache.cache_model(
-        cache_key, get_feature_model, extract_conf
-    )
     if use_cached_model:
+        extractor = model_cache.cache_model(
+            cache_key, get_feature_model, extract_conf
+        )
+        # because of the model cache, we need to update the config
        extractor.conf["max_keypoints"] = extract_max_keypoints
        extractor.conf["keypoint_threshold"] = keypoint_threshold
        logger.info(f"Loaded cached model {cache_key}")
+    else:
+        extractor = get_feature_model(extract_conf)
 
    pred0 = extract_features.extract(
        extractor, image0, extract_conf["preprocessing"]
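For readers skimming the hunks above: the new use_cached_model flag decides whether run_matching reuses a matcher/extractor held by model_cache (and then refreshes its thresholds, since a cached instance keeps its old settings) or builds a fresh one. Below is a minimal, self-contained sketch of that pattern; ModelCache, get_model, and the conf layout are simplified stand-ins, not the project's actual implementations.

    # Sketch only: simplified stand-ins for ModelCache, get_model, and the conf dict.
    import time
    from typing import Any, Callable, Dict


    class ModelCache:
        def __init__(self) -> None:
            self._models: Dict[str, Any] = {}

        def cache_model(self, key: str, factory: Callable[[Dict[str, Any]], Any], conf: Dict[str, Any]) -> Any:
            # Build the model once per cache key, then reuse the same instance.
            if key not in self._models:
                self._models[key] = factory(conf)
            return self._models[key]


    def get_model(conf: Dict[str, Any]) -> Any:
        # Stand-in for the project's get_model(): returns an object with a mutable .conf.
        class _Matcher:
            def __init__(self, model_conf: Dict[str, Any]) -> None:
                self.conf = dict(model_conf)
        return _Matcher(conf["model"])


    def load_matcher(key: str, match_conf: Dict[str, Any], model_cache: ModelCache, use_cached_model: bool) -> Any:
        t0 = time.time()
        cache_key = "{}_{}".format(key, match_conf["model"]["name"])
        if use_cached_model:
            # A cached instance may carry stale settings, so refresh them here,
            # mirroring the "update the config" comment in the hunk above.
            matcher = model_cache.cache_model(cache_key, get_model, match_conf)
            matcher.conf["max_keypoints"] = match_conf["model"]["max_keypoints"]
            matcher.conf["match_threshold"] = match_conf["model"]["match_threshold"]
        else:
            matcher = get_model(match_conf)
        print(f"Loading model using: {time.time() - t0:.3f}s")
        return matcher


    cache = ModelCache()
    conf = {"model": {"name": "omniglue", "max_keypoints": 2000, "match_threshold": 0.2}}
    matcher = load_matcher("omniglue", conf, cache, use_cached_model=True)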
hloc/match_dense.py
CHANGED
@@ -216,6 +216,7 @@ confs = {
         "model": {
             "name": "omniglue",
             "match_threshold": 0.2,
+            "max_keypoints": 2000,
             "features": "null",
         },
         "preprocessing": {
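For reference, the omniglue entry in confs would look roughly like this after the change (only keys visible in the hunk are shown; the rest of the entry is elided):

    # Approximate shape of the omniglue block in hloc/match_dense.py after this commit.
    # Keys outside the hunk are intentionally omitted.
    omniglue_conf = {
        "model": {
            "name": "omniglue",
            "match_threshold": 0.2,
            "max_keypoints": 2000,  # new: forwarded to the matcher as self.conf["max_keypoints"]
            "features": "null",
        },
        "preprocessing": {
            # unchanged preprocessing options elided
        },
    }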
hloc/matchers/duster.py
CHANGED
@@ -105,7 +105,7 @@ class Duster(BaseModel):
         reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
             *pts3d_list
         )
-
+        logger.info(f"Found {num_matches} matches")
         mkpts1 = pts2d_list[1][reciprocal_in_P2]
         mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
 
@@ -114,7 +114,6 @@ class Duster(BaseModel):
         keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
         mkpts0 = mkpts0[keep]
         mkpts1 = mkpts1[keep]
-        breakpoint()
         pred = {
             "keypoints0": torch.from_numpy(mkpts0),
             "keypoints1": torch.from_numpy(mkpts1),
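The context lines in the second hunk subsample the reciprocal matches to top_k evenly spaced indices; as a standalone illustration of that step (the function name and guard are mine, the indexing is the same):

    import numpy as np


    def subsample_matches(mkpts0: np.ndarray, mkpts1: np.ndarray, top_k: int):
        # Evenly spaced indices over the match list, as in the Duster hunk above.
        if top_k is not None and len(mkpts0) > top_k:
            keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
            mkpts0, mkpts1 = mkpts0[keep], mkpts1[keep]
        return mkpts0, mkpts1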
hloc/matchers/omniglue.py
CHANGED
@@ -39,7 +39,6 @@ class OmniGlue(BaseModel):
             subprocess.run(cmd, check=True)
         else:
             logger.error(f"Invalid dinov2 model: {dino_model_path.name}")
-
         self.net = omniglue.OmniGlue(
             og_export=str(og_model_path),
             sp_export=str(sp_model_path),
@@ -54,9 +53,8 @@ class OmniGlue(BaseModel):
         image0_rgb_np = image0_rgb_np.astype(np.uint8)  # RGB, 0-255
         image1_rgb_np = image1_rgb_np.astype(np.uint8)  # RGB, 0-255
         match_kp0, match_kp1, match_confidences = self.net.FindMatches(
-            image0_rgb_np, image1_rgb_np
+            image0_rgb_np, image1_rgb_np, self.conf["max_keypoints"]
         )
-
         # filter matches
         match_threshold = self.conf["match_threshold"]
         keep_idx = []
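The hunk ends just before the confidence filtering ("# filter matches" / "keep_idx = []"). A minimal NumPy sketch of that idea, not the file's exact loop:

    import numpy as np


    def filter_matches(match_kp0: np.ndarray, match_kp1: np.ndarray,
                       match_confidences: np.ndarray, match_threshold: float):
        # Keep only correspondences whose confidence clears the threshold,
        # e.g. the 0.2 default set in hloc/match_dense.py.
        keep = match_confidences > match_threshold
        return match_kp0[keep], match_kp1[keep], match_confidences[keep]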
hloc/utils/viz.py
CHANGED
@@ -65,9 +65,11 @@ def plot_keypoints(kpts, colors="lime", ps=4):
     if not isinstance(colors, list):
         colors = [colors] * len(kpts)
     axes = plt.gcf().axes
-    for a, k, c in zip(axes, kpts, colors):
-        a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
-
+    try:
+        for a, k, c in zip(axes, kpts, colors):
+            a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
+    except IndexError as e:
+        pass
 
 def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
     """Plot matches for a pair of existing images.
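As a usage sketch of plot_keypoints-style plotting (the figure setup below is illustrative; the project normally prepares the axes via its own plot_images helper):

    import matplotlib.pyplot as plt
    import numpy as np

    # Two images and two keypoint sets; one axis per image, which is what the patched loop expects.
    imgs = [np.random.rand(240, 320), np.random.rand(240, 320)]
    kpts = [np.random.rand(50, 2) * [320, 240], np.random.rand(50, 2) * [320, 240]]

    fig, axes = plt.subplots(1, 2)
    for ax, img in zip(axes, imgs):
        ax.imshow(img, cmap="gray")
        ax.axis("off")

    # Same scatter call as in the hunk; the try/except there guards against
    # having fewer axes than keypoint sets.
    for ax, k in zip(plt.gcf().axes, kpts):
        ax.scatter(k[:, 0], k[:, 1], c="lime", s=4, linewidths=0)
    plt.show()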
third_party/omniglue/src/omniglue/omniglue_extract.py
CHANGED
@@ -46,13 +46,18 @@ class OmniGlue:
         dino_export, feature_layer=1
     )
 
-  def FindMatches(
+  def FindMatches(
+      self,
+      image0: np.ndarray,
+      image1: np.ndarray,
+      max_keypoints: int = 2048,
+  ):
     """TODO(omniglue): docstring."""
     height0, width0 = image0.shape[:2]
     height1, width1 = image1.shape[:2]
     # TODO: numpy to torch inputs
-    sp_features0 = self.sp_extract(image0, num_features=
-    sp_features1 = self.sp_extract(image1, num_features=
+    sp_features0 = self.sp_extract(image0, num_features=max_keypoints)
+    sp_features1 = self.sp_extract(image1, num_features=max_keypoints)
     dino_features0 = self.dino_extract(image0)
     dino_features1 = self.dino_extract(image1)
     dino_descriptors0 = dino_extract.get_dino_descriptors(
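A hedged usage sketch of the updated FindMatches signature; the og object is assumed to be an already constructed omniglue.OmniGlue (model paths omitted here), and the helper name is mine:

    import numpy as np


    def match_pair(og, image0: np.ndarray, image1: np.ndarray, max_keypoints: int = 2000):
        # Images are RGB uint8 arrays, matching how hloc/matchers/omniglue.py prepares them.
        image0 = image0.astype(np.uint8)
        image1 = image1.astype(np.uint8)
        # max_keypoints is the argument added in this commit; it caps the SuperPoint keypoints.
        return og.FindMatches(image0, image1, max_keypoints)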