# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import glob
import os

import lmdb
from tqdm import tqdm

from imaginaire.utils import path


def construct_file_path(root, data_type, sequence, filename, ext):
    """Get file path for our dataset structure."""
    return '%s/%s/%s/%s.%s' % (root, data_type, sequence, filename, ext)


def check_and_add(filepath, key, filepaths, keys, remove_missing=False):
    r"""Add filepath and key to the lists of filepaths and keys.

    Args:
        filepath (str): Filepath to add.
        key (str): LMDB key for this filepath.
        filepaths (list): List of filepaths added so far.
        keys (list): List of keys added so far.
        remove_missing (bool): If ``True``, skips missing files, otherwise
            raises an error.

    Returns:
        (int): Size of the file at filepath, or -1 if the file is missing
        and ``remove_missing`` is ``True``.
    """
    if not os.path.exists(filepath):
        print(filepath + ' does not exist.')
        if remove_missing:
            return -1
        else:
            raise FileNotFoundError(filepath + ' does not exist.')
    filepaths.append(filepath)
    keys.append(key)
    return os.path.getsize(filepath)
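

# A minimal usage sketch, not part of the original module: shows how
# construct_file_path and check_and_add cooperate to build the parallel
# filepath/key lists that build_lmdb consumes. The dataset root and the
# sequence/file names below are hypothetical.
def _example_collect_files(root='datasets/demo'):
    filepaths, keys = [], []
    for filename in ['000000', '000001']:
        filepath = construct_file_path(root, 'images', 'seq0001',
                                       filename, 'jpg')
        # With remove_missing=True, files absent on disk are skipped
        # (check_and_add returns -1) instead of raising FileNotFoundError.
        check_and_add(filepath, 'images/seq0001/' + filename,
                      filepaths, keys, remove_missing=True)
    return filepaths, keys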


def write_entry(txn, key, filepath):
    r"""Dump the binary contents of the file associated with key to LMDB.

    Args:
        txn: Handle to an open LMDB write transaction.
        key (str): LMDB key for this filepath.
        filepath (str): Filepath whose contents to store.
    """
    with open(filepath, 'rb') as f:
        data = f.read()
    txn.put(key.encode('ascii'), data)
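

# A minimal read-back sketch, not part of the original module: retrieves the
# raw bytes that write_entry stored under `key`. Keys are ASCII-encoded on
# write, so they are ASCII-encoded on read as well.
def _example_read_entry(lmdb_path, key):
    env = lmdb.open(lmdb_path, readonly=True, lock=False)
    with env.begin() as txn:
        data = txn.get(key.encode('ascii'))
    env.close()
    return data  # Raw file contents as bytes, or None if the key is absent.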


def build_lmdb(filepaths, keys, output_filepath, map_size, large):
    r"""Write out an LMDB containing (key, contents of filepath) pairs.

    Args:
        filepaths (list): List of filepath strings.
        keys (list): List of key strings associated with filepaths.
        output_filepath (str): Location to write the LMDB to.
        map_size (int): Maximum size of the LMDB, in bytes.
        large (bool): Is the dataset large? If ``True``, open the LMDB with
            ``writemap=True``, which speeds up bulk writes.
    """
    if large:
        db = lmdb.open(output_filepath, map_size=map_size, writemap=True)
    else:
        db = lmdb.open(output_filepath, map_size=map_size)
    txn = db.begin(write=True)
    print('Writing LMDB to:', output_filepath)
    for filepath, key in tqdm(zip(filepaths, keys), total=len(keys)):
        write_entry(txn, key, filepath)
    txn.commit()
    db.close()  # Ensure everything is flushed to disk.
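

# A minimal usage sketch, not part of the original module: writes the lists
# gathered by the hypothetical _example_collect_files above into an LMDB.
# map_size is a hard upper bound on the database size, so it is derived from
# the actual file sizes plus generous headroom rather than guessed.
def _example_write_lmdb(output_filepath='/tmp/example_lmdb'):
    filepaths, keys = _example_collect_files()
    total_size = sum(os.path.getsize(f) for f in filepaths)
    build_lmdb(filepaths, keys, output_filepath,
               map_size=2 * total_size + 10 ** 7, large=False)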


def get_all_filenames_from_list(list_name):
    r"""Get all filenames from a list file.

    Args:
        list_name (str): Path to the filename list.

    Returns:
        all_filenames (dict): Folder names for keys, lists of filenames
        (without extension) for values.
    """
    with open(list_name, 'rt') as f:
        lines = f.readlines()
    lines = [line.strip() for line in lines]
    all_filenames = dict()
    for line in lines:
        if '/' in line:
            file_str = line.split('/')[0:-1]
            folder_name = os.path.join(*file_str)
            image_name = line.split('/')[-1].replace('.jpg', '')
        else:
            folder_name = '.'
            image_name = line.replace('.jpg', '')
        if folder_name in all_filenames:
            all_filenames[folder_name].append(image_name)
        else:
            all_filenames[folder_name] = [image_name]
    return all_filenames
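

# A minimal usage sketch, not part of the original module: writes a small
# list file and parses it. Lines containing '/' are split into
# (folder, basename); lines without one land under the '.' folder, and the
# '.jpg' extension is stripped in both cases. The temp path is hypothetical.
def _example_filename_list(tmp_list='/tmp/example_list.txt'):
    with open(tmp_list, 'wt') as f:
        f.write('seq0001/000000.jpg\nseq0001/000001.jpg\n000002.jpg\n')
    return get_all_filenames_from_list(tmp_list)
    # -> {'seq0001': ['000000', '000001'], '.': ['000002']}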


def get_lmdb_data_types(cfg):
    r"""Get the data types which should be put in LMDB.

    Args:
        cfg: Configuration object.

    Returns:
        cfg: Configuration object with ``data.data_types`` and
        ``data.extensions`` populated with the LMDB-backed types.
    """
    data_types, extensions = [], []
    for data_type in cfg.data.input_types:
        name = list(data_type.keys())
        assert len(name) == 1, 'Each input type must be a single-key dict.'
        name = name[0]
        info = data_type[name]
        if 'computed_on_the_fly' not in info:
            info['computed_on_the_fly'] = False
        is_lmdb = not info['computed_on_the_fly']
        if not is_lmdb:
            continue
        ext = info['ext']
        data_types.append(name)
        extensions.append(ext)
    cfg.data.data_types = data_types
    cfg.data.extensions = extensions
    return cfg
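

# A minimal usage sketch, not part of the original module: a SimpleNamespace
# stands in for the attribute-style config object this function expects.
# Each entry of data.input_types is a single-key dict mapping a data type
# name to its settings; types marked computed_on_the_fly are excluded.
def _example_lmdb_data_types():
    from types import SimpleNamespace
    cfg = SimpleNamespace(data=SimpleNamespace(input_types=[
        {'images': {'ext': 'jpg'}},
        {'seg_maps': {'ext': 'png'}},
        {'edge_maps': {'ext': 'png', 'computed_on_the_fly': True}},
    ]))
    cfg = get_lmdb_data_types(cfg)
    return cfg.data.data_types, cfg.data.extensions
    # -> (['images', 'seg_maps'], ['jpg', 'png'])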


def create_metadata(data_root=None, cfg=None, paired=None, input_list=''):
    r"""Enumerate the filenames and extensions for each data type.

    Args:
        data_root (str): Location of the dataset root.
        cfg (object): Loaded config object.
        paired (bool): Paired or unpaired data.
        input_list (str): Path to a file containing the list of inputs.

    Returns:
        (tuple):
          - all_filenames (dict): For paired data, maps sequence to filenames;
            for unpaired data, maps data type to {sequence: filenames}.
          - extensions (dict): Extension of each data type.
    """
    cfg = get_lmdb_data_types(cfg)

    # Get the list of all data types present in the dataset.
    available_data_types = path.get_immediate_subdirectories(data_root)
    print(available_data_types)
    required_data_types = cfg.data.data_types
    data_exts = cfg.data.extensions
    # Every required data type must be present in the dataset.
    missing_data_types = set(required_data_types) - set(available_data_types)
    assert not missing_data_types, \
        '%s missing from dataset root.' % missing_data_types
    # Find extensions for each data type.
    extensions = {}
    for data_type, data_ext in zip(required_data_types, data_exts):
        extensions[data_type] = data_ext
    print('Data file extensions:', extensions)

    if paired:
        if input_list != '':
            all_filenames = get_all_filenames_from_list(input_list)
        else:
            # Get the list of all sequences in the dataset.
            if 'data_keypoint' in required_data_types:
                search_dir = 'data_keypoint'
            elif 'data_segmaps' in required_data_types:
                search_dir = 'data_segmaps'
            else:
                search_dir = required_data_types[0]
            print('Searching in dir: %s' % search_dir)
            sequences = path.get_recursive_subdirectories(
                os.path.join(data_root, search_dir),
                extensions[search_dir])
            print('Found %d sequences' % (len(sequences)))

            # Get the filenames in each sequence.
            all_filenames = {}
            for sequence in sequences:
                folder = '%s/%s/%s/*.%s' % (
                    data_root, search_dir, sequence,
                    extensions[search_dir])
                filenames = sorted(glob.glob(folder))
                filenames = [
                    os.path.splitext(os.path.basename(filename))[0]
                    for filename in filenames]
                all_filenames[sequence] = filenames
        total_filenames = [len(filenames)
                           for _, filenames in all_filenames.items()]
        print('Found %d files' % (sum(total_filenames)))
    else:
        # Get the sequences in each data type.
        all_filenames = {}
        for data_type in required_data_types:
            all_filenames[data_type] = {}
            sequences = path.get_recursive_subdirectories(
                os.path.join(data_root, data_type), extensions[data_type])
            # Get the filenames in each sequence.
            total_filenames = 0
            for sequence in sequences:
                folder = '%s/%s/%s/*.%s' % (
                    data_root, data_type, sequence, extensions[data_type])
                filenames = sorted(glob.glob(folder))
                filenames = [
                    os.path.splitext(os.path.basename(filename))[0]
                    for filename in filenames]
                all_filenames[data_type][sequence] = filenames
                total_filenames += len(filenames)
            print('Data type: %s, Found %d sequences, Found %d files' %
                  (data_type, len(sequences), total_filenames))
    return all_filenames, extensions
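

# A hypothetical end-to-end sketch, not part of the original module: ties the
# helpers together for a paired dataset. create_metadata enumerates sequences
# and filenames (mutating cfg in place via get_lmdb_data_types, so
# cfg.data.data_types is available afterwards); check_and_add builds the
# key/path lists; build_lmdb packs each data type into its own LMDB under
# output_root/<data_type>. All paths and the cfg object are assumptions.
def _example_build_dataset_lmdbs(data_root, cfg, output_root):
    all_filenames, extensions = create_metadata(
        data_root=data_root, cfg=cfg, paired=True, input_list='')
    os.makedirs(output_root, exist_ok=True)
    for data_type in cfg.data.data_types:
        ext = extensions[data_type]
        filepaths, keys, total_size = [], [], 0
        for sequence, filenames in all_filenames.items():
            for filename in filenames:
                filepath = construct_file_path(
                    data_root, data_type, sequence, filename, ext)
                size = check_and_add(filepath, '%s/%s' % (sequence, filename),
                                     filepaths, keys, remove_missing=True)
                if size != -1:
                    total_size += size
        build_lmdb(filepaths, keys, os.path.join(output_root, data_type),
                   map_size=2 * total_size + 10 ** 7, large=False)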