Spaces:
Runtime error
Runtime error
add: ModelCache
Browse files- common/utils.py +73 -11
- hloc/matchers/omniglue.py +1 -0
- test_app_cli.py +36 -5
common/utils.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
| 1 |
import os
|
| 2 |
import cv2
|
|
|
|
| 3 |
import torch
|
| 4 |
import random
|
|
|
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import gradio as gr
|
| 7 |
from pathlib import Path
|
|
@@ -42,6 +45,66 @@ MATCHER_ZOO = None
|
|
| 42 |
models_already_loaded = {}
|
| 43 |
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def load_config(config_name: str) -> Dict[str, Any]:
|
| 46 |
"""
|
| 47 |
Load a YAML configuration file.
|
|
@@ -579,6 +642,7 @@ def run_matching(
|
|
| 579 |
ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
|
| 580 |
choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
|
| 581 |
matcher_zoo: Dict[str, Any] = None,
|
|
|
|
| 582 |
) -> Tuple[
|
| 583 |
np.ndarray,
|
| 584 |
np.ndarray,
|
|
@@ -639,15 +703,12 @@ def run_matching(
|
|
| 639 |
match_conf["model"]["max_keypoints"] = extract_max_keypoints
|
| 640 |
t0 = time.time()
|
| 641 |
cache_key = "{}_{}".format(key, match_conf["model"]["name"])
|
| 642 |
-
|
| 643 |
-
|
| 644 |
matcher.conf["max_keypoints"] = extract_max_keypoints
|
| 645 |
matcher.conf["match_threshold"] = match_threshold
|
| 646 |
logger.info(f"Loaded cached model {cache_key}")
|
| 647 |
-
|
| 648 |
-
matcher = get_model(match_conf)
|
| 649 |
-
models_already_loaded[cache_key] = matcher
|
| 650 |
-
# gr.Info(f"Loading model using: {time.time()-t0:.3f}s")
|
| 651 |
logger.info(f"Loading model using: {time.time()-t0:.3f}s")
|
| 652 |
t1 = time.time()
|
| 653 |
|
|
@@ -663,14 +724,15 @@ def run_matching(
|
|
| 663 |
extract_conf["model"]["max_keypoints"] = extract_max_keypoints
|
| 664 |
extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
|
| 665 |
cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
|
| 666 |
-
|
| 667 |
-
|
|
|
|
|
|
|
|
|
|
| 668 |
extractor.conf["max_keypoints"] = extract_max_keypoints
|
| 669 |
extractor.conf["keypoint_threshold"] = keypoint_threshold
|
| 670 |
logger.info(f"Loaded cached model {cache_key}")
|
| 671 |
-
|
| 672 |
-
extractor = get_feature_model(extract_conf)
|
| 673 |
-
models_already_loaded[cache_key] = extractor
|
| 674 |
pred0 = extract_features.extract(
|
| 675 |
extractor, image0, extract_conf["preprocessing"]
|
| 676 |
)
|
|
|
|
| 1 |
import os
|
| 2 |
import cv2
|
| 3 |
+
import sys
|
| 4 |
import torch
|
| 5 |
import random
|
| 6 |
+
import psutil
|
| 7 |
+
import shutil
|
| 8 |
import numpy as np
|
| 9 |
import gradio as gr
|
| 10 |
from pathlib import Path
|
|
|
|
| 45 |
models_already_loaded = {}
|
| 46 |
|
| 47 |
|
| 48 |
+
class ModelCache:
|
| 49 |
+
def __init__(self, max_memory_size: int = 8):
|
| 50 |
+
self.max_memory_size = max_memory_size
|
| 51 |
+
self.current_memory_size = 0
|
| 52 |
+
self.model_dict = {}
|
| 53 |
+
self.model_timestamps = []
|
| 54 |
+
|
| 55 |
+
def cache_model(self, model_key, model_loader_func, model_conf):
|
| 56 |
+
if model_key in self.model_dict:
|
| 57 |
+
self.model_timestamps.remove(model_key)
|
| 58 |
+
self.model_timestamps.append(model_key)
|
| 59 |
+
logger.info(f"Load cached {model_key}")
|
| 60 |
+
return self.model_dict[model_key]
|
| 61 |
+
|
| 62 |
+
model = self._load_model_from_disk(model_loader_func, model_conf)
|
| 63 |
+
while self._calculate_model_memory() > self.max_memory_size:
|
| 64 |
+
if len(self.model_timestamps) == 0:
|
| 65 |
+
logger.warn(
|
| 66 |
+
"RAM: {}GB, MAX RAM: {}GB".format(
|
| 67 |
+
self._calculate_model_memory(), self.max_memory_size
|
| 68 |
+
)
|
| 69 |
+
)
|
| 70 |
+
break
|
| 71 |
+
oldest_model_key = self.model_timestamps.pop(0)
|
| 72 |
+
self.current_memory_size = self._calculate_model_memory()
|
| 73 |
+
logger.info(f"Del cached {oldest_model_key}")
|
| 74 |
+
del self.model_dict[oldest_model_key]
|
| 75 |
+
|
| 76 |
+
self.model_dict[model_key] = model
|
| 77 |
+
self.model_timestamps.append(model_key)
|
| 78 |
+
|
| 79 |
+
self.print_memory_usage()
|
| 80 |
+
logger.info(f"Total cached {list(self.model_dict.keys())}")
|
| 81 |
+
|
| 82 |
+
return model
|
| 83 |
+
|
| 84 |
+
def _load_model_from_disk(self, model_loader_func, model_conf):
|
| 85 |
+
return model_loader_func(model_conf)
|
| 86 |
+
|
| 87 |
+
def _calculate_model_memory(self, verbose=False):
|
| 88 |
+
host_colocation = int(os.environ.get("HOST_COLOCATION", "1"))
|
| 89 |
+
vm = psutil.virtual_memory()
|
| 90 |
+
du = shutil.disk_usage(".")
|
| 91 |
+
vm_ratio = host_colocation * vm.used / vm.total
|
| 92 |
+
if verbose:
|
| 93 |
+
logger.info(
|
| 94 |
+
f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB"
|
| 95 |
+
)
|
| 96 |
+
# logger.info(
|
| 97 |
+
# f"DISK: {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}GB"
|
| 98 |
+
# )
|
| 99 |
+
return vm.used / 1e9
|
| 100 |
+
|
| 101 |
+
def print_memory_usage(self):
|
| 102 |
+
self._calculate_model_memory(verbose=True)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
model_cache = ModelCache()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
def load_config(config_name: str) -> Dict[str, Any]:
|
| 109 |
"""
|
| 110 |
Load a YAML configuration file.
|
|
|
|
| 642 |
ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
|
| 643 |
choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
|
| 644 |
matcher_zoo: Dict[str, Any] = None,
|
| 645 |
+
use_cached_model: bool = True,
|
| 646 |
) -> Tuple[
|
| 647 |
np.ndarray,
|
| 648 |
np.ndarray,
|
|
|
|
| 703 |
match_conf["model"]["max_keypoints"] = extract_max_keypoints
|
| 704 |
t0 = time.time()
|
| 705 |
cache_key = "{}_{}".format(key, match_conf["model"]["name"])
|
| 706 |
+
matcher = model_cache.cache_model(cache_key, get_model, match_conf)
|
| 707 |
+
if use_cached_model:
|
| 708 |
matcher.conf["max_keypoints"] = extract_max_keypoints
|
| 709 |
matcher.conf["match_threshold"] = match_threshold
|
| 710 |
logger.info(f"Loaded cached model {cache_key}")
|
| 711 |
+
|
|
|
|
|
|
|
|
|
|
| 712 |
logger.info(f"Loading model using: {time.time()-t0:.3f}s")
|
| 713 |
t1 = time.time()
|
| 714 |
|
|
|
|
| 724 |
extract_conf["model"]["max_keypoints"] = extract_max_keypoints
|
| 725 |
extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
|
| 726 |
cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
|
| 727 |
+
|
| 728 |
+
extractor = model_cache.cache_model(
|
| 729 |
+
cache_key, get_feature_model, extract_conf
|
| 730 |
+
)
|
| 731 |
+
if use_cached_model:
|
| 732 |
extractor.conf["max_keypoints"] = extract_max_keypoints
|
| 733 |
extractor.conf["keypoint_threshold"] = keypoint_threshold
|
| 734 |
logger.info(f"Loaded cached model {cache_key}")
|
| 735 |
+
|
|
|
|
|
|
|
| 736 |
pred0 = extract_features.extract(
|
| 737 |
extractor, image0, extract_conf["preprocessing"]
|
| 738 |
)
|
hloc/matchers/omniglue.py
CHANGED
|
@@ -10,6 +10,7 @@ from ..utils.base_model import BaseModel
|
|
| 10 |
thirdparty_path = Path(__file__).parent / "../../third_party"
|
| 11 |
sys.path.append(str(thirdparty_path))
|
| 12 |
from omniglue.src import omniglue
|
|
|
|
| 13 |
omniglue_path = thirdparty_path / "omniglue"
|
| 14 |
|
| 15 |
|
|
|
|
| 10 |
thirdparty_path = Path(__file__).parent / "../../third_party"
|
| 11 |
sys.path.append(str(thirdparty_path))
|
| 12 |
from omniglue.src import omniglue
|
| 13 |
+
|
| 14 |
omniglue_path = thirdparty_path / "omniglue"
|
| 15 |
|
| 16 |
|
test_app_cli.py
CHANGED
|
@@ -12,11 +12,11 @@ from common.utils import (
|
|
| 12 |
from common.api import ImageMatchingAPI
|
| 13 |
|
| 14 |
|
| 15 |
-
def
|
| 16 |
img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
|
| 17 |
img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
|
| 18 |
-
image0 = cv2.imread(str(img_path1))[:, :, ::-1]
|
| 19 |
-
image1 = cv2.imread(str(img_path2))[:, :, ::-1]
|
| 20 |
|
| 21 |
matcher_zoo_restored = get_matcher_zoo(config["matcher_zoo"])
|
| 22 |
for k, v in matcher_zoo_restored.items():
|
|
@@ -27,15 +27,46 @@ def test_api(config: dict = None):
|
|
| 27 |
logger.info(f"Testing {k} ...")
|
| 28 |
api = ImageMatchingAPI(conf=v, device=device)
|
| 29 |
api(image0, image1)
|
| 30 |
-
log_path = ROOT / "
|
| 31 |
log_path.mkdir(exist_ok=True, parents=True)
|
| 32 |
api.visualize(log_path=log_path)
|
| 33 |
else:
|
| 34 |
logger.info(f"Skipping {k} ...")
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
if __name__ == "__main__":
|
| 38 |
import argparse
|
| 39 |
|
| 40 |
config = load_config(ROOT / "common/config.yaml")
|
| 41 |
-
|
|
|
|
|
|
| 12 |
from common.api import ImageMatchingAPI
|
| 13 |
|
| 14 |
|
| 15 |
+
def test_all(config: dict = None):
|
| 16 |
img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
|
| 17 |
img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
|
| 18 |
+
image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB
|
| 19 |
+
image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB
|
| 20 |
|
| 21 |
matcher_zoo_restored = get_matcher_zoo(config["matcher_zoo"])
|
| 22 |
for k, v in matcher_zoo_restored.items():
|
|
|
|
| 27 |
logger.info(f"Testing {k} ...")
|
| 28 |
api = ImageMatchingAPI(conf=v, device=device)
|
| 29 |
api(image0, image1)
|
| 30 |
+
log_path = ROOT / "experiments" / "all"
|
| 31 |
log_path.mkdir(exist_ok=True, parents=True)
|
| 32 |
api.visualize(log_path=log_path)
|
| 33 |
else:
|
| 34 |
logger.info(f"Skipping {k} ...")
|
| 35 |
|
| 36 |
|
| 37 |
+
def test_one():
|
| 38 |
+
img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
|
| 39 |
+
img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
|
| 40 |
+
image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB
|
| 41 |
+
image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB
|
| 42 |
+
|
| 43 |
+
conf = {
|
| 44 |
+
"matcher": {
|
| 45 |
+
"output": "matches-omniglue",
|
| 46 |
+
"model": {
|
| 47 |
+
"name": "omniglue",
|
| 48 |
+
"match_threshold": 0.2,
|
| 49 |
+
"features": "null",
|
| 50 |
+
},
|
| 51 |
+
"preprocessing": {
|
| 52 |
+
"grayscale": False,
|
| 53 |
+
"resize_max": 1024,
|
| 54 |
+
"dfactor": 8,
|
| 55 |
+
"force_resize": False,
|
| 56 |
+
},
|
| 57 |
+
},
|
| 58 |
+
"dense": True,
|
| 59 |
+
}
|
| 60 |
+
api = ImageMatchingAPI(conf=conf, device=device)
|
| 61 |
+
api(image0, image1)
|
| 62 |
+
log_path = ROOT / "experiments" / "one"
|
| 63 |
+
log_path.mkdir(exist_ok=True, parents=True)
|
| 64 |
+
api.visualize(log_path=log_path)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
if __name__ == "__main__":
|
| 68 |
import argparse
|
| 69 |
|
| 70 |
config = load_config(ROOT / "common/config.yaml")
|
| 71 |
+
test_one()
|
| 72 |
+
test_all(config)
|