Spaces:
Configuration error
Configuration error
| import os | |
| import logging | |
| from pathlib import Path | |
| from functools import reduce | |
| from operator import getitem | |
| from datetime import datetime | |
| from logger import setup_logging | |
| from utils import read_json, write_json | |
class ConfigParser:
    """Process-wide singleton holding the parsed experiment configuration.

    The configuration is read from the JSON file given on the command line
    (``args.config``) or, when resuming, from the ``config.json`` stored next to
    the checkpoint (``args.resume``). It is then overridden by custom CLI
    options. Direct construction (``ConfigParser(...)``) is disabled; obtain
    the instance through :meth:`get_instance`.
    """
    __instance = None

    def __new__(cls, args, options='', timestamp=True):
        # Forbid direct construction so every caller shares the instance built
        # by get_instance().
        raise NotImplementedError('Cannot initialize via Constructor')

    @classmethod
    def __internal_new__(cls):
        """Allocate a bare instance, bypassing the disabled ``__new__``."""
        return super().__new__(cls)

    @classmethod
    def get_instance(cls, args=None, options='', timestamp=True):
        """Return the shared instance, creating it on the first call.

        Args:
            args: parser object exposing ``add_argument`` / ``parse_args``
                (e.g. ``argparse.ArgumentParser``); required on the first call.
            options: iterable of custom option records with ``flags``, ``type``
                and ``target`` (optionally ``target2`` / ``target3``).
            timestamp: when true, output directories get a ``%m%d_%H%M%S``
                sub-directory.

        Only the first call's arguments are used; later calls ignore them.

        Raises:
            NotImplementedError: if no instance exists yet and ``args`` is None.
        """
        if cls.__instance is None:
            if args is None:
                # This exception used to be constructed but never raised.
                raise NotImplementedError('Cannot initialize without args')
            instance = cls.__internal_new__()
            # Forward ``timestamp`` (previously dropped). Publish the singleton
            # only after __init__ succeeds, so a failed init is not cached.
            instance.__init__(args, options, timestamp)
            cls.__instance = instance
        return cls.__instance

    def __init__(self, args, options='', timestamp=True):
        """Parse the CLI, load/merge the JSON config and create output dirs.

        The parsed namespace must provide ``device``, ``resume`` and ``config``
        attributes (read below).
        """
        # Register custom CLI options; their values override config entries.
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        args = args.parse_args()

        if args.device:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.device

        if args.resume is None:
            msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
            # Explicit raise (not ``assert``) so the check survives ``python -O``;
            # AssertionError is kept for backward compatibility.
            if args.config is None:
                raise AssertionError(msg_no_cfg)
            self.cfg_fname = Path(args.config)
            config = read_json(self.cfg_fname)
            self.resume = None
        else:
            self.resume = Path(args.resume)
            resume_cfg_fname = self.resume.parent / 'config.json'
            config = read_json(resume_cfg_fname)
            if args.config is not None:
                # A freshly supplied config overrides top-level keys (shallow
                # dict.update) of the checkpoint's config.
                config.update(read_json(Path(args.config)))

        # Apply custom CLI option values on top of the loaded config.
        self._config = _update_config(config, options, args)

        # Directories where the trained model and logs are saved.
        save_dir = Path(self.config['trainer']['save_dir'])
        run_id = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
        exper_name = self._experiment_name(self.config)
        self._save_dir = save_dir / 'models' / exper_name / run_id
        self._log_dir = save_dir / 'log' / exper_name / run_id
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Persist the final (overridden) config alongside the checkpoints.
        write_json(self.config, self.save_dir / 'config.json')

        # Configure the logging module.
        setup_logging(self.log_dir)
        self.log_levels = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG
        }

    @staticmethod
    def _experiment_name(config):
        """Return ``'<name>_asym_<pct>'`` or ``'<name>_sym_<pct>'``.

        ``pct`` is ``config['trainer']['percent'] * 100`` truncated to an int.
        The product is first rounded to 6 decimals, so binary floating-point
        error does not drop a whole point (plain ``int(0.29 * 100)`` gives 28).
        """
        trainer = config['trainer']
        mode = 'asym' if trainer['asym'] else 'sym'
        percent = int(round(trainer['percent'] * 100, 6))
        return config['name'] + '_' + mode + '_' + str(percent)

    def initialize(self, name, module, *args, **kwargs):
        """
        finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding keyword args given as 'args'.

        Positional ``args`` are forwarded unchanged. ``kwargs`` are merged with
        the config's ``args`` and may not repeat any of its keys.

        Raises:
            AssertionError: if ``kwargs`` would overwrite a config-given kwarg.
        """
        module_name = self[name]['type']
        module_args = dict(self[name]['args'])
        if any(k in module_args for k in kwargs):
            raise AssertionError('Overwriting kwargs given in config file is not allowed')
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def __getitem__(self, name):
        """Access a top-level config entry, e.g. ``config['trainer']``."""
        return self.config[name]

    def get_logger(self, name, verbosity=2):
        """Return ``logging.getLogger(name)`` with its level set from ``verbosity`` (0, 1 or 2)."""
        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
                                                                                       self.log_levels.keys())
        if verbosity not in self.log_levels:
            raise AssertionError(msg_verbosity)
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    # setting read-only attributes
    @property
    def config(self):
        return self._config

    @property
    def save_dir(self):
        return self._save_dir

    @property
    def log_dir(self):
        return self._log_dir
# Helpers that write custom CLI option values into the nested config dict.
def _update_config(config, options, args):
    """Copy every supplied CLI option value into ``config`` in place.

    Options whose parsed value is ``None`` are skipped. Otherwise the value is
    written at ``opt.target``, and also at ``opt.target2`` / ``opt.target3``
    when the option record declares those fields.

    Returns:
        The same ``config`` object, after mutation.
    """
    for opt in options:
        value = getattr(args, _get_opt_name(opt.flags))
        if value is None:
            continue
        _set_by_path(config, opt.target, value)
        for extra_field in ('target2', 'target3'):
            if extra_field in opt._fields:
                _set_by_path(config, getattr(opt, extra_field), value)
    return config
| def _get_opt_name(flags): | |
| for flg in flags: | |
| if flg.startswith('--'): | |
| return flg.replace('--', '') | |
| return flags[0].replace('--', '') | |
def _set_by_path(tree, keys, value):
    """Store ``value`` at the nested position addressed by ``keys`` inside ``tree``.

    Every key except the last selects a container; the last key is assigned on
    that container.
    """
    container = _get_by_path(tree, keys[:-1])
    container[keys[-1]] = value
| def _get_by_path(tree, keys): | |
| """Access a nested object in tree by sequence of keys.""" | |
| return reduce(getitem, keys, tree) | |