Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| """Config utilities for yml file.""" | |
| import collections | |
| import functools | |
| import os | |
| import re | |
| import yaml | |
| from imaginaire.utils.distributed import master_only_print as print | |
# Module-level boolean switches, both off by default.
# NOTE(review): their readers are not visible in this file -- confirm what
# each one toggles before documenting further.
DEBUG, USE_JIT = False, False
class AttrDict(dict):
    """Dict as attribute trick.

    Entries can be read and written either as ``d['key']`` or ``d.key``,
    because the instance's ``__dict__`` is the dict itself. On construction,
    nested ``dict`` values are converted to ``AttrDict``, and lists/tuples
    whose first element is a ``dict`` are converted to lists of ``AttrDict``.
    Empty lists/tuples are kept unchanged.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Share storage: attribute access reads/writes the dict entries.
        self.__dict__ = self
        # Iterate over a snapshot because values are replaced in the loop.
        for key, value in list(self.items()):
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                # ``value and ...`` guards against IndexError on empty
                # sequences (e.g. ``foo: []`` in a yaml file).
                if value and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]

    def yaml(self):
        """Convert object to yaml dict and return.

        Returns:
            dict: A plain ``dict`` in which nested ``AttrDict`` values, and
            lists whose first element is an ``AttrDict``, are recursively
            converted back to plain dicts. All other values are copied as-is.
        """
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                yaml_dict[key] = [item.yaml() for item in value]
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Print all variables.

        Produces one ``key: value`` line per entry. A nested ``AttrDict``
        value, or a list whose first element is an ``AttrDict``, is printed
        as a ``key:`` header followed by the child's own lines, each
        prefixed with a single space.
        """
        ret_str = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                children = [value]
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                children = value
            else:
                ret_str.append('{}: {}'.format(key, value))
                continue
            ret_str.append('{}:'.format(key))
            for child in children:
                for line in child.__repr__().split('\n'):
                    ret_str.append(' ' + line)
        return '\n'.join(ret_str)
class Config(AttrDict):
    r"""Configuration class. This should include every human specifiable
    hyperparameter values for your training.

    Every option first receives the built-in default set in ``__init__``.
    The contents of the yaml file ``filename``, when given, are then merged
    on top with ``recursive_update``. If that file has a top-level ``common``
    section, it is also attached as ``gen.common`` and ``dis.common``.

    Args:
        filename (str or None): Path to the yaml config file. ``None`` yields
            a config that contains only the defaults.
        verbose (bool): If ``True``, print the resulting configuration.

    Raises:
        AssertionError: If ``filename`` is given but does not exist.
        OSError: If the file exists but cannot be opened or read.
    """

    def __init__(self, filename=None, verbose=False):
        super(Config, self).__init__()
        self.source_filename = filename

        # Set default parameters.
        # Logging.
        large_number = 1000000000
        self.snapshot_save_iter = large_number
        self.snapshot_save_epoch = large_number
        self.metrics_iter = None
        self.metrics_epoch = None
        self.snapshot_save_start_iter = 0
        self.snapshot_save_start_epoch = 0
        self.image_save_iter = large_number
        self.image_display_iter = large_number
        self.max_epoch = large_number
        self.max_iter = large_number
        self.logging_iter = 100
        self.speed_benchmark = False

        # Trainer.
        self.trainer = AttrDict(
            model_average_config=AttrDict(enabled=False,
                                          beta=0.9999,
                                          start_iteration=1000,
                                          num_batch_norm_estimation_iterations=30,
                                          remove_sn=True),
            image_to_tensorboard=False,
            hparam_to_tensorboard=False,
            distributed_data_parallel='pytorch',
            distributed_data_parallel_params=AttrDict(
                find_unused_parameters=False),
            delay_allreduce=True,
            gan_relativistic=False,
            gen_step=1,
            dis_step=1,
            gan_decay_k=1.,
            gan_min_k=1.,
            gan_separate_topk=False,
            aug_policy='',
            channels_last=False,
            strict_resume=True,
            amp_gp=False,
            amp_config=AttrDict(init_scale=65536.0,
                                growth_factor=2.0,
                                backoff_factor=0.5,
                                growth_interval=2000,
                                enabled=False))

        # Networks.
        self.gen = AttrDict(type='imaginaire.generators.dummy')
        self.dis = AttrDict(type='imaginaire.discriminators.dummy')

        # Optimizers.
        self.gen_opt = AttrDict(type='adam',
                                fused_opt=False,
                                lr=0.0001,
                                adam_beta1=0.0,
                                adam_beta2=0.999,
                                eps=1e-8,
                                lr_policy=AttrDict(iteration_mode=False,
                                                   type='step',
                                                   step_size=large_number,
                                                   gamma=1))
        self.dis_opt = AttrDict(type='adam',
                                fused_opt=False,
                                lr=0.0001,
                                adam_beta1=0.0,
                                adam_beta2=0.999,
                                eps=1e-8,
                                lr_policy=AttrDict(iteration_mode=False,
                                                   type='step',
                                                   step_size=large_number,
                                                   gamma=1))

        # Data.
        self.data = AttrDict(name='dummy',
                             type='imaginaire.datasets.images',
                             num_workers=0)
        self.test_data = AttrDict(name='dummy',
                                  type='imaginaire.datasets.images',
                                  num_workers=0,
                                  test=AttrDict(is_lmdb=False,
                                                roots='',
                                                batch_size=1))

        # Cudnn.
        self.cudnn = AttrDict(deterministic=False,
                              benchmark=True)

        # Others.
        self.pretrained_weight = ''
        self.inference_args = AttrDict()

        # Update with given configurations.
        cfg_dict = {}
        if filename is not None:
            assert os.path.exists(filename), \
                'File {} not exist.'.format(filename)

            # Private SafeLoader subclass that also resolves scientific
            # notation without a decimal point (e.g. ``1e-4``) as a float.
            # Registering the resolver on a subclass gives it its own copy
            # of the resolver table. This leaves the shared
            # ``yaml.SafeLoader`` untouched for every other caller, and it
            # avoids stacking a duplicate resolver each time a ``Config`` is
            # built.
            class _Loader(yaml.SafeLoader):
                pass

            _Loader.add_implicit_resolver(
                u'tag:yaml.org,2002:float',
                re.compile(u'''^(?:
                 [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
                |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
                |\\.[0-9_]+(?:[eE][-+][0-9]+)?
                |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
                |[-+]?\\.(?:inf|Inf|INF)
                |\\.(?:nan|NaN|NAN))$''', re.X),
                list(u'-+0123456789.'))

            try:
                with open(filename, 'r') as f:
                    # An empty file loads as ``None``: treat as no overrides.
                    cfg_dict = yaml.load(f, Loader=_Loader) or {}
            except EnvironmentError:
                # Report the offending file, then propagate the real error.
                # Previously execution continued with ``cfg_dict`` unbound.
                print('Please check the file with name of "%s"' % filename)
                raise
            recursive_update(self, cfg_dict)

        # Put common opts in both gen and dis.
        if 'common' in cfg_dict:
            self.common = AttrDict(**cfg_dict['common'])
            self.gen.common = self.common
            self.dis.common = self.common

        if verbose:
            print(' imaginaire config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))
def rsetattr(obj, attr, val):
    """Set a possibly dotted attribute path on ``obj``.

    ``rsetattr(o, 'a.b.c', v)`` behaves like ``o.a.b.c = v``. Everything
    before the last dot is resolved with ``rgetattr``, and the final
    component is assigned with ``setattr``. Returns ``None``.
    """
    head, _sep, tail = attr.rpartition('.')
    if head:
        obj = rgetattr(obj, head)
    setattr(obj, tail, val)


def rgetattr(obj, attr, *args):
    """Resolve a possibly dotted attribute path on ``obj``.

    ``rgetattr(o, 'a.b.c')`` returns ``o.a.b.c``. An optional default given
    after ``attr`` is forwarded to every ``getattr`` step. A missing component
    therefore yields the default, and any remaining components are then
    looked up on that default.
    """
    current = obj
    for name in attr.split('.'):
        current = getattr(current, name, *args)
    return current
def recursive_update(d, u):
    """Recursively update AttrDict d with AttrDict u.

    For every ``key, value`` in ``u``:

    * a mapping value is merged recursively into ``d[key]``. If ``d[key]``
      is not already an ``AttrDict``, the merge starts from a copy of it (when
      it is a mapping) or from an empty ``AttrDict`` (when it is a scalar or
      missing).
    * a non-empty list/tuple whose first element is a ``dict`` is stored as
      a list of ``AttrDict``.
    * anything else, including empty lists/tuples, is stored as-is.

    Args:
        d (AttrDict): Target, updated in place.
        u (Mapping): Source of the new values.

    Returns:
        AttrDict: ``d`` itself, after the update.
    """
    # Import the submodule explicitly rather than relying on
    # ``collections.abc`` having been loaded elsewhere.
    from collections.abc import Mapping

    for key, value in u.items():
        if isinstance(value, Mapping):
            existing = d.get(key)
            if not isinstance(existing, AttrDict):
                # A scalar default (e.g. ``None``) or a plain dict used to
                # crash here with AttributeError on ``.__dict__``/``.get``.
                if isinstance(existing, Mapping):
                    existing = AttrDict(existing)
                else:
                    existing = AttrDict({})
            d.__dict__[key] = recursive_update(existing, value)
        elif (isinstance(value, (list, tuple)) and value
              and isinstance(value[0], dict)):
            # ``value and ...`` avoids IndexError on empty sequences.
            d.__dict__[key] = [AttrDict(item) for item in value]
        else:
            d.__dict__[key] = value
    return d