#!/usr/bin/env python
# coding: utf-8
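"""Visualize image matches produced by the matching methods registered in the
``viz`` package, either on the standard test sets (MegaDepth, ScanNet,
Aachen v1.1, InLoc) or in a simple demo mode that writes the visualized
matches to a video file."""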
import os
import glob
import argparse
from argparse import Namespace

import cv2
import yaml
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler

from src.datasets.custom_dataloader import TestDataLoader
from src.utils.dataset import read_img_gray
from configs.data.base import cfg as data_cfg
import viz
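
# Per-method model settings live in "<root_dir>/configs/<method_name>.yml" and are
# keyed by dataset name, so each matching method can be configured per benchmark.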
def get_model_config(method_name, dataset_name, root_dir="viz"):
    config_file = f"{root_dir}/configs/{method_name}.yml"
    with open(config_file, "r") as f:
        model_conf = yaml.load(f, Loader=yaml.FullLoader)[dataset_name]
    return model_conf
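
# Simple dataset for the demo mode: reads grayscale images either from every file in
# ``dataset_dir`` or from an explicit list file, resizing them via ``read_img_gray``
# (which presumably also rounds the dimensions to a multiple of ``down_factor``).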
class DemoDataset(Dataset):
    def __init__(self, dataset_dir, img_file=None, resize=0, down_factor=16):
        self.dataset_dir = dataset_dir
        if img_file is None:
            self.list_img_files = glob.glob(os.path.join(dataset_dir, "*.*"))
            self.list_img_files.sort()
        else:
            with open(img_file) as f:
                self.list_img_files = [
                    os.path.join(dataset_dir, name.strip()) for name in f.readlines()
                ]
        self.resize = resize
        self.down_factor = down_factor

    def __len__(self):
        return len(self.list_img_files)

    def __getitem__(self, idx):
        img_path = self.list_img_files[idx]
        img, scale = read_img_gray(
            img_path, resize=self.resize, down_factor=self.down_factor
        )
        return {"img": img, "id": idx, "img_path": img_path}
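
# Command-line entry point: pick a matching method (looked up in the ``viz`` registry)
# and a dataset, then run either the benchmark/visualization pipeline or the live demo.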
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Visualize matches")
    parser.add_argument("--gpu", "-gpu", type=str, default="0")
    parser.add_argument("--method", type=str, default=None)
    parser.add_argument("--dataset_dir", type=str, default="data/aachen-day-night")
    parser.add_argument("--pair_dir", type=str, default=None)
    parser.add_argument(
        "--dataset_name",
        type=str,
        choices=["megadepth", "scannet", "aachen_v1.1", "inloc"],
        default="megadepth",
    )
    parser.add_argument("--measure_time", action="store_true")
    parser.add_argument("--no_viz", action="store_true")
    parser.add_argument("--compute_eval_metrics", action="store_true")
    parser.add_argument("--run_demo", action="store_true")
    args = parser.parse_args()

    # Restrict CUDA to the device requested via --gpu (the flag was parsed but
    # previously never used).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Load the per-dataset model configuration and instantiate the requested method
    # by its class name from the ``viz`` package.
    model_cfg = get_model_config(args.method, args.dataset_name)
    class_name = model_cfg["class"]
    model = viz.__dict__[class_name](model_cfg)
    # all_args = Namespace(**vars(args), **model_cfg)
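    # Benchmark mode: merge the dataset-specific test settings into the base data
    # config, then run the model over the test loader and draw/evaluate matches.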
    if not args.run_demo:
        if args.dataset_name == "megadepth":
            from configs.data.megadepth_test_1500 import cfg

            data_cfg.merge_from_other_cfg(cfg)
        elif args.dataset_name == "scannet":
            from configs.data.scannet_test_1500 import cfg

            data_cfg.merge_from_other_cfg(cfg)
        elif args.dataset_name == "aachen_v1.1":
            data_cfg.merge_from_list(
                [
                    "DATASET.TEST_DATA_SOURCE",
                    "aachen_v1.1",
                    "DATASET.TEST_DATA_ROOT",
                    os.path.join(args.dataset_dir, "images/images_upright"),
                    "DATASET.TEST_LIST_PATH",
                    args.pair_dir,
                    "DATASET.TEST_IMGSIZE",
                    model_cfg["imsize"],
                ]
            )
        elif args.dataset_name == "inloc":
            data_cfg.merge_from_list(
                [
                    "DATASET.TEST_DATA_SOURCE",
                    "inloc",
                    "DATASET.TEST_DATA_ROOT",
                    args.dataset_dir,
                    "DATASET.TEST_LIST_PATH",
                    args.pair_dir,
                    "DATASET.TEST_IMGSIZE",
                    model_cfg["imsize"],
                ]
            )
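        # Only the MegaDepth and ScanNet test sets come with ground truth here, so
        # evaluation metrics are computed for those two benchmarks only.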
        has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in [
            "megadepth",
            "scannet",
        ]

        dataloader = TestDataLoader(data_cfg)
        with torch.no_grad():
            for data_dict in tqdm(dataloader):
                # Move every tensor in the batch to the GPU when one is available.
                for k, v in data_dict.items():
                    if isinstance(v, torch.Tensor):
                        data_dict[k] = v.cuda() if torch.cuda.is_available() else v
                img_root_dir = data_cfg.DATASET.TEST_DATA_ROOT
                model.match_and_draw(
                    data_dict,
                    root_dir=img_root_dir,
                    ground_truth=has_ground_truth,
                    measure_time=args.measure_time,
                    viz_matches=(not args.no_viz),
                )
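        # Report timing and (when ground truth is available) evaluation metrics that
        # the model is expected to have accumulated during the loop above.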
        if args.measure_time:
            print(
                "Running time for each image is {} milliseconds".format(
                    model.measure_time()
                )
            )
        if args.compute_eval_metrics and has_ground_truth:
            model.compute_eval_metrics()
    else:
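        # Demo mode: stream the images one by one in order and let the model write
        # its match visualizations into an MP4 video.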
        demo_dataset = DemoDataset(args.dataset_dir, img_file=args.pair_dir, resize=640)
        sampler = SequentialSampler(demo_dataset)
        dataloader = DataLoader(demo_dataset, batch_size=1, sampler=sampler)
        writer = cv2.VideoWriter(
            "topicfm_demo.mp4",
            cv2.VideoWriter_fourcc(*"mp4v"),
            15,
            (640 * 2 + 5, 480 * 2 + 10),
        )
        model.run_demo(iter(dataloader), writer)  # , output_dir="demo", no_display=True)