Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import tarfile | |
| import sys | |
| sys.path.append('.') | |
| from imaginaire.utils.io import download_file_from_google_drive # noqa: E402 | |
def parse_args(argv=None):
    """Parse the command-line options of the dataset download script.

    Args:
        argv (list of str, optional): Arguments to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as the
            script did before this parameter existed; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace: with ``dataset`` (one of the supported dataset
        names) and ``data_dir`` (root directory to save datasets into,
        ``'./dataset'`` by default).
    """
    parser = argparse.ArgumentParser(description='Download and process dataset')
    parser.add_argument('--dataset', help='Name of the dataset.', required=True,
                        choices=['afhq_dog2cat',
                                 'animal_faces'])
    parser.add_argument('--data_dir', default='./dataset',
                        help='Directory to save all datasets.')
    return parser.parse_args(argv)
def _safe_extractall(tar, dest):
    """Extract every member of the open tar archive ``tar`` into ``dest``.

    Members that would be written outside ``dest`` are rejected, because the
    archive is downloaded from the network and must not be trusted.

    Raises:
        ValueError: (fallback path only) if a member's normalized path lies
            outside ``dest``. On interpreters with extraction filters,
            ``tarfile`` raises its own ``FilterError`` subclasses instead.
    """
    if hasattr(tarfile, 'data_filter'):
        # Python 3.12+ and security backports: the 'data' filter refuses
        # absolute paths, '..' components, links escaping dest and special
        # files.
        tar.extractall(dest, filter='data')
        return
    # Older interpreters: refuse members whose resolved path would land
    # outside dest. NOTE(review): this does not vet symlink/hardlink targets.
    root = os.path.realpath(dest)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(root, member.name))
        if os.path.commonpath([root, target]) != root:
            raise ValueError('Unsafe path in archive: {}'.format(member.name))
    tar.extractall(dest)


def main():
    """Download and extract the dataset selected on the command line.

    The compressed archive is fetched from Google Drive into
    ``<data_dir>/<dataset>_raw.tar.gz`` and extracted into
    ``<data_dir>/<dataset>_raw``. The download is skipped when either the
    archive or the extracted folder already exists. Extraction is skipped
    when the folder already exists.

    Raises:
        ValueError: if ``args.dataset`` is not a known dataset. This is
            normally unreachable because ``parse_args`` restricts the
            choices.
    """
    import shutil  # only needed to clear a leftover partial extraction

    args = parse_args()

    # Google Drive file IDs of the compressed archives, keyed by dataset name.
    drive_file_ids = {
        'afhq_dog2cat': '1XaiwS0eRctqm-JEDezOBy4TXriAQgc4_',
        'animal_faces': '1ftr1xWm0VakGlLUWi7-hdAt9W37luQOA',
    }
    try:
        file_id = drive_file_ids[args.dataset]
    except KeyError:
        raise ValueError('Invalid dataset {}.'.format(args.dataset)) from None

    # Create the dataset directory (exist_ok avoids a check-then-create race).
    os.makedirs(args.data_dir, exist_ok=True)

    # Download the compressed dataset.
    folder_path = os.path.join(args.data_dir, args.dataset + '_raw')
    compressed_path = folder_path + '.tar.gz'
    if not os.path.exists(compressed_path) and not os.path.exists(folder_path):
        print("Downloading the dataset {}.".format(args.dataset))
        # NOTE(review): an interrupted download may leave a truncated
        # archive that later runs will try to extract; delete it to retry.
        download_file_from_google_drive(file_id, compressed_path)

    # Extract the dataset.
    if not os.path.exists(folder_path):
        print("Extracting the dataset {}.".format(args.dataset))
        # Extract into a scratch directory and rename it only on success, so
        # an interrupted extraction never leaves a half-filled folder_path
        # that later runs would mistake for a finished one.
        partial_path = folder_path + '.partial'
        if os.path.exists(partial_path):
            shutil.rmtree(partial_path)
        with tarfile.open(compressed_path) as tar:
            _safe_extractall(tar, partial_path)
        os.rename(partial_path, folder_path)
if __name__ == "__main__":
    # Run the download/extract pipeline only when executed as a script,
    # not when this module is imported.
    main()