Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import sys | |
| import tarfile | |
| sys.path.append('.') | |
| from imaginaire.utils.io import download_file_from_google_drive # noqa: E402 | |
| URLS = { | |
| 'pix2pixhd': '1Xg9m184zkuG8H0LHdBtSzt2VbMi3SWwR', | |
| 'spade': '1ESm-gHWu_aMHnKF42qkGc8qf1SBECsgf', | |
| 'funit': '1a-EE_6RsYPUoKxEl5oXrpRmKYUltqaD-', | |
| 'coco_funit': '1JYVYB0Q1VStDLOb0SBJbN1vkaf6KrGDh', | |
| 'unit': '17BbwnCG7qF7FI-t9VkORv2XCKqlrY1CO', | |
| 'munit': '1VPgHGuQfmm1N1Vh56wr34wtAwaXzjXtH', | |
| 'vid2vid': '1SHvGPMq-55GDUQ0Ac2Ng0eyG5xCPeKhc', | |
| 'fs_vid2vid': '1fTj0HHjzcitgsSeG5O_aWMF8yvCQUQkN', | |
| 'wc_vid2vid/cityscapes': '1KKzrTHfbpBY9xtLqK8e3QvX8psSdrFcD', | |
| 'wc_vid2vid/mannequin': '1mafZf9KJrwUGGI1kBTvwgehHSqP5iaA0', | |
| 'gancraft': '1m6q7ZtYJjxFL0SQ_WzMbvoLZxXmI5_vJ', | |
| } | |
| def parse_args(): | |
| parser = argparse.ArgumentParser(description='Download test data.') | |
| parser.add_argument('--model_name', required=True, | |
| help='Name of the model.') | |
| args = parser.parse_args() | |
| return args | |
| def main(): | |
| args = parse_args() | |
| test_data_dir = 'projects/' + args.model_name + '/test_data' | |
| print(test_data_dir) | |
| assert args.model_name in URLS, 'No sample test data available' | |
| url = URLS[args.model_name] | |
| if os.path.exists(test_data_dir): | |
| print('Test data exists at', test_data_dir) | |
| compressed_path = test_data_dir + '.tar.gz' | |
| # Extract the dataset. | |
| print('Extracting test data to', test_data_dir) | |
| with tarfile.open(compressed_path) as tar: | |
| tar.extractall(path=test_data_dir) | |
| else: | |
| os.makedirs(test_data_dir, exist_ok=True) | |
| # Download the compressed dataset. | |
| compressed_path = test_data_dir + '.tar.gz' | |
| if not os.path.exists(compressed_path): | |
| print('Downloading test data to', compressed_path) | |
| download_file_from_google_drive(url, compressed_path) | |
| # Extract the dataset. | |
| print('Extracting test data to', test_data_dir) | |
| with tarfile.open(compressed_path) as tar: | |
| tar.extractall(path=test_data_dir) | |
| if __name__ == "__main__": | |
| main() | |