Spaces:
Runtime error
Runtime error
| import os | |
| import os.path as osp | |
| import shutil | |
| import copy | |
| import time | |
| import pprint | |
| import numpy as np | |
| import torch | |
| import argparse | |
| import json | |
| import yaml | |
| from easydict import EasyDict as edict | |
| from .model_zoo import get_model | |
| ############ | |
| # cfg_bank # | |
| ############ | |
| def cfg_solvef(cmd, root): | |
| if not isinstance(cmd, str): | |
| return cmd | |
| if cmd.find('SAME')==0: | |
| zoom = root | |
| p = cmd[len('SAME'):].strip('()').split('.') | |
| p = [pi.strip() for pi in p] | |
| for pi in p: | |
| try: | |
| pi = int(pi) | |
| except: | |
| pass | |
| try: | |
| zoom = zoom[pi] | |
| except: | |
| return cmd | |
| return cfg_solvef(zoom, root) | |
| if cmd.find('SEARCH')==0: | |
| zoom = root | |
| p = cmd[len('SEARCH'):].strip('()').split('.') | |
| p = [pi.strip() for pi in p] | |
| find = True | |
| # Depth first search | |
| for pi in p: | |
| try: | |
| pi = int(pi) | |
| except: | |
| pass | |
| try: | |
| zoom = zoom[pi] | |
| except: | |
| find = False | |
| break | |
| if find: | |
| return cfg_solvef(zoom, root) | |
| else: | |
| if isinstance(root, dict): | |
| for ri in root: | |
| rv = cfg_solvef(cmd, root[ri]) | |
| if rv != cmd: | |
| return rv | |
| if isinstance(root, list): | |
| for ri in root: | |
| rv = cfg_solvef(cmd, ri) | |
| if rv != cmd: | |
| return rv | |
| return cmd | |
| if cmd.find('MODEL')==0: | |
| goto = cmd[len('MODEL'):].strip('()') | |
| return model_cfg_bank()(goto) | |
| if cmd.find('DATASET')==0: | |
| goto = cmd[len('DATASET'):].strip('()') | |
| return dataset_cfg_bank()(goto) | |
| return cmd | |
| def cfg_solve(cfg, cfg_root): | |
| # The function solve cfg element such that | |
| # all sorrogate input are settled. | |
| # (i.e. SAME(***) ) | |
| if isinstance(cfg, list): | |
| for i in range(len(cfg)): | |
| if isinstance(cfg[i], (list, dict)): | |
| cfg[i] = cfg_solve(cfg[i], cfg_root) | |
| else: | |
| cfg[i] = cfg_solvef(cfg[i], cfg_root) | |
| if isinstance(cfg, dict): | |
| for k in cfg: | |
| if isinstance(cfg[k], (list, dict)): | |
| cfg[k] = cfg_solve(cfg[k], cfg_root) | |
| else: | |
| cfg[k] = cfg_solvef(cfg[k], cfg_root) | |
| return cfg | |
| class model_cfg_bank(object): | |
| def __init__(self): | |
| self.cfg_dir = osp.join('configs', 'model') | |
| self.cfg_bank = edict() | |
| def __call__(self, name): | |
| if name not in self.cfg_bank: | |
| cfg_path = self.get_yaml_path(name) | |
| with open(cfg_path, 'r') as f: | |
| cfg_new = yaml.load( | |
| f, Loader=yaml.FullLoader) | |
| cfg_new = edict(cfg_new) | |
| self.cfg_bank.update(cfg_new) | |
| cfg = self.cfg_bank[name] | |
| cfg.name = name | |
| if 'super_cfg' not in cfg: | |
| cfg = cfg_solve(cfg, cfg) | |
| self.cfg_bank[name] = cfg | |
| return copy.deepcopy(cfg) | |
| super_cfg = self.__call__(cfg.super_cfg) | |
| # unlike other field, | |
| # args will not be replaced but update. | |
| if 'args' in cfg: | |
| if 'args' in super_cfg: | |
| super_cfg.args.update(cfg.args) | |
| else: | |
| super_cfg.args = cfg.args | |
| cfg.pop('args') | |
| super_cfg.update(cfg) | |
| super_cfg.pop('super_cfg') | |
| cfg = super_cfg | |
| try: | |
| delete_args = cfg.pop('delete_args') | |
| except: | |
| delete_args = [] | |
| for dargs in delete_args: | |
| cfg.args.pop(dargs) | |
| cfg = cfg_solve(cfg, cfg) | |
| self.cfg_bank[name] = cfg | |
| return copy.deepcopy(cfg) | |
| def get_yaml_path(self, name): | |
| if name.find('openai_unet')==0: | |
| return osp.join( | |
| self.cfg_dir, 'openai_unet.yaml') | |
| elif (name.find('clip')==0) or (name.find('openclip')==0): | |
| return osp.join( | |
| self.cfg_dir, 'clip.yaml') | |
| elif name.find('vd')==0: | |
| return osp.join( | |
| self.cfg_dir, 'vd.yaml') | |
| elif name.find('optimus')==0: | |
| return osp.join( | |
| self.cfg_dir, 'optimus.yaml') | |
| elif name.find('autokl')==0: | |
| return osp.join( | |
| self.cfg_dir, 'autokl.yaml') | |
| else: | |
| raise ValueError | |
| class dataset_cfg_bank(object): | |
| def __init__(self): | |
| self.cfg_dir = osp.join('configs', 'dataset') | |
| self.cfg_bank = edict() | |
| def __call__(self, name): | |
| if name not in self.cfg_bank: | |
| cfg_path = self.get_yaml_path(name) | |
| with open(cfg_path, 'r') as f: | |
| cfg_new = yaml.load( | |
| f, Loader=yaml.FullLoader) | |
| cfg_new = edict(cfg_new) | |
| self.cfg_bank.update(cfg_new) | |
| cfg = self.cfg_bank[name] | |
| cfg.name = name | |
| if cfg.get('super_cfg', None) is None: | |
| cfg = cfg_solve(cfg, cfg) | |
| self.cfg_bank[name] = cfg | |
| return copy.deepcopy(cfg) | |
| super_cfg = self.__call__(cfg.super_cfg) | |
| super_cfg.update(cfg) | |
| cfg = super_cfg | |
| cfg.super_cfg = None | |
| try: | |
| delete = cfg.pop('delete') | |
| except: | |
| delete = [] | |
| for dargs in delete: | |
| cfg.pop(dargs) | |
| cfg = cfg_solve(cfg, cfg) | |
| self.cfg_bank[name] = cfg | |
| return copy.deepcopy(cfg) | |
| def get_yaml_path(self, name): | |
| if name.find('laion2b')==0: | |
| return osp.join( | |
| self.cfg_dir, 'laion2b.yaml') | |
| else: | |
| raise ValueError | |
| class experiment_cfg_bank(object): | |
| def __init__(self): | |
| self.cfg_dir = osp.join('configs', 'experiment') | |
| self.cfg_bank = edict() | |
| def __call__(self, name): | |
| if name not in self.cfg_bank: | |
| cfg_path = self.get_yaml_path(name) | |
| with open(cfg_path, 'r') as f: | |
| cfg = yaml.load( | |
| f, Loader=yaml.FullLoader) | |
| cfg = edict(cfg) | |
| cfg = cfg_solve(cfg, cfg) | |
| cfg = cfg_solve(cfg, cfg) | |
| # twice for SEARCH | |
| self.cfg_bank[name] = cfg | |
| return copy.deepcopy(cfg) | |
| def get_yaml_path(self, name): | |
| return osp.join( | |
| self.cfg_dir, name+'.yaml') | |
| def load_cfg_yaml(path): | |
| if osp.isfile(path): | |
| cfg_path = path | |
| elif osp.isfile(osp.join('configs', 'experiment', path)): | |
| cfg_path = osp.join('configs', 'experiment', path) | |
| elif osp.isfile(osp.join('configs', 'experiment', path+'.yaml')): | |
| cfg_path = osp.join('configs', 'experiment', path+'.yaml') | |
| else: | |
| assert False, 'No such config!' | |
| with open(cfg_path, 'r') as f: | |
| cfg = yaml.load(f, Loader=yaml.FullLoader) | |
| cfg = edict(cfg) | |
| cfg = cfg_solve(cfg, cfg) | |
| cfg = cfg_solve(cfg, cfg) | |
| return cfg | |
| ############## | |
| # cfg_helper # | |
| ############## | |
| def get_experiment_id(ref=None): | |
| if ref is None: | |
| time.sleep(0.5) | |
| return int(time.time()*100) | |
| else: | |
| try: | |
| return int(ref) | |
| except: | |
| pass | |
| _, ref = osp.split(ref) | |
| ref = ref.split('_')[0] | |
| try: | |
| return int(ref) | |
| except: | |
| assert False, 'Invalid experiment ID!' | |
| def record_resume_cfg(path): | |
| cnt = 0 | |
| while True: | |
| if osp.exists(path+'.{:04d}'.format(cnt)): | |
| cnt += 1 | |
| continue | |
| shutil.copyfile(path, path+'.{:04d}'.format(cnt)) | |
| break | |
| def get_command_line_args(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--debug', action='store_true', default=False) | |
| parser.add_argument('--config', type=str) | |
| parser.add_argument('--gpu', nargs='+', type=int) | |
| parser.add_argument('--node_rank', type=int) | |
| parser.add_argument('--node_list', nargs='+', type=str) | |
| parser.add_argument('--nodes', type=int) | |
| parser.add_argument('--addr', type=str, default='127.0.0.1') | |
| parser.add_argument('--port', type=int, default=11233) | |
| parser.add_argument('--signature', nargs='+', type=str) | |
| parser.add_argument('--seed', type=int) | |
| parser.add_argument('--eval', type=str) | |
| parser.add_argument('--eval_subdir', type=str) | |
| parser.add_argument('--pretrained', type=str) | |
| parser.add_argument('--resume_dir', type=str) | |
| parser.add_argument('--resume_step', type=int) | |
| parser.add_argument('--resume_weight', type=str) | |
| args = parser.parse_args() | |
| # Special handling the resume | |
| if args.resume_dir is not None: | |
| cfg = edict() | |
| cfg.env = edict() | |
| cfg.env.debug = args.debug | |
| cfg.env.resume = edict() | |
| cfg.env.resume.dir = args.resume_dir | |
| cfg.env.resume.step = args.resume_step | |
| cfg.env.resume.weight = args.resume_weight | |
| return cfg | |
| cfg = load_cfg_yaml(args.config) | |
| cfg.env.debug = args.debug | |
| cfg.env.gpu_device = [0] if args.gpu is None else list(args.gpu) | |
| cfg.env.master_addr = args.addr | |
| cfg.env.master_port = args.port | |
| cfg.env.dist_url = 'tcp://{}:{}'.format(args.addr, args.port) | |
| if args.node_list is None: | |
| cfg.env.node_rank = 0 if args.node_rank is None else args.node_rank | |
| cfg.env.nodes = 1 if args.nodes is None else args.nodes | |
| else: | |
| import socket | |
| hostname = socket.gethostname() | |
| assert cfg.env.master_addr == args.node_list[0] | |
| cfg.env.node_rank = args.node_list.index(hostname) | |
| cfg.env.nodes = len(args.node_list) | |
| cfg.env.node_list = args.node_list | |
| istrain = False if args.eval is not None else True | |
| isdebug = cfg.env.debug | |
| if istrain: | |
| if isdebug: | |
| cfg.env.experiment_id = 999999999999 | |
| cfg.train.signature = ['debug'] | |
| else: | |
| cfg.env.experiment_id = get_experiment_id() | |
| if args.signature is not None: | |
| cfg.train.signature = args.signature | |
| else: | |
| if 'train' in cfg: | |
| cfg.pop('train') | |
| cfg.env.experiment_id = get_experiment_id(args.eval) | |
| if args.signature is not None: | |
| cfg.eval.signature = args.signature | |
| if isdebug and (args.eval is None): | |
| cfg.env.experiment_id = 999999999999 | |
| cfg.eval.signature = ['debug'] | |
| if args.eval_subdir is not None: | |
| if isdebug: | |
| cfg.eval.eval_subdir = 'debug' | |
| else: | |
| cfg.eval.eval_subdir = args.eval_subdir | |
| if args.pretrained is not None: | |
| cfg.eval.pretrained = args.pretrained | |
| # The override pretrained over the setting in cfg.model | |
| if args.seed is not None: | |
| cfg.env.rnd_seed = args.seed | |
| return cfg | |
| def cfg_initiates(cfg): | |
| cfge = cfg.env | |
| isdebug = cfge.debug | |
| isresume = 'resume' in cfge | |
| istrain = 'train' in cfg | |
| haseval = 'eval' in cfg | |
| cfgt = cfg.train if istrain else None | |
| cfgv = cfg.eval if haseval else None | |
| ############################### | |
| # get some environment params # | |
| ############################### | |
| cfge.computer = os.uname() | |
| cfge.torch_version = str(torch.__version__) | |
| ########## | |
| # resume # | |
| ########## | |
| if isresume: | |
| resume_cfg_path = osp.join(cfge.resume.dir, 'config.yaml') | |
| record_resume_cfg(resume_cfg_path) | |
| with open(resume_cfg_path, 'r') as f: | |
| cfg_resume = yaml.load(f, Loader=yaml.FullLoader) | |
| cfg_resume = edict(cfg_resume) | |
| cfg_resume.env.update(cfge) | |
| cfg = cfg_resume | |
| cfge = cfg.env | |
| log_file = cfg.train.log_file | |
| print('') | |
| print('##########') | |
| print('# resume #') | |
| print('##########') | |
| print('') | |
| with open(log_file, 'a') as f: | |
| print('', file=f) | |
| print('##########', file=f) | |
| print('# resume #', file=f) | |
| print('##########', file=f) | |
| print('', file=f) | |
| pprint.pprint(cfg) | |
| with open(log_file, 'a') as f: | |
| pprint.pprint(cfg, f) | |
| #################### | |
| # node distributed # | |
| #################### | |
| if cfg.env.master_addr!='127.0.0.1': | |
| os.environ['MASTER_ADDR'] = cfge.master_addr | |
| os.environ['MASTER_PORT'] = '{}'.format(cfge.master_port) | |
| if cfg.env.dist_backend=='nccl': | |
| os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET' | |
| if cfg.env.dist_backend=='gloo': | |
| os.environ['GLOO_SOCKET_FAMILY'] = 'AF_INET' | |
| ####################### | |
| # cuda visible device # | |
| ####################### | |
| os.environ["CUDA_VISIBLE_DEVICES"] = ','.join( | |
| [str(gid) for gid in cfge.gpu_device]) | |
| ##################### | |
| # return resume cfg # | |
| ##################### | |
| if isresume: | |
| return cfg | |
| ############################################# | |
| # some misc setting that not need in resume # | |
| ############################################# | |
| cfgm = cfg.model | |
| cfge.gpu_count = len(cfge.gpu_device) | |
| ########################################## | |
| # align batch size and num worker config # | |
| ########################################## | |
| gpu_n = cfge.gpu_count * cfge.nodes | |
| def align_batch_size(bs, bs_per_gpu): | |
| assert (bs is not None) or (bs_per_gpu is not None) | |
| bs = bs_per_gpu * gpu_n if bs is None else bs | |
| bs_per_gpu = bs // gpu_n if bs_per_gpu is None else bs_per_gpu | |
| assert (bs == bs_per_gpu * gpu_n) | |
| return bs, bs_per_gpu | |
| if istrain: | |
| cfgt.batch_size, cfgt.batch_size_per_gpu = \ | |
| align_batch_size(cfgt.batch_size, cfgt.batch_size_per_gpu) | |
| cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu = \ | |
| align_batch_size(cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu) | |
| if haseval: | |
| cfgv.batch_size, cfgv.batch_size_per_gpu = \ | |
| align_batch_size(cfgv.batch_size, cfgv.batch_size_per_gpu) | |
| cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu = \ | |
| align_batch_size(cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu) | |
| ################## | |
| # create log dir # | |
| ################## | |
| if istrain: | |
| if not isdebug: | |
| sig = cfgt.get('signature', []) | |
| sig = sig + ['s{}'.format(cfge.rnd_seed)] | |
| else: | |
| sig = ['debug'] | |
| log_dir = [ | |
| cfge.log_root_dir, | |
| '{}_{}'.format(cfgm.symbol, cfgt.dataset.symbol), | |
| '_'.join([str(cfge.experiment_id)] + sig) | |
| ] | |
| log_dir = osp.join(*log_dir) | |
| log_file = osp.join(log_dir, 'train.log') | |
| if not osp.exists(log_file): | |
| os.makedirs(osp.dirname(log_file)) | |
| cfgt.log_dir = log_dir | |
| cfgt.log_file = log_file | |
| if haseval: | |
| cfgv.log_dir = log_dir | |
| cfgv.log_file = log_file | |
| else: | |
| model_symbol = cfgm.symbol | |
| if cfgv.get('dataset', None) is None: | |
| dataset_symbol = 'nodataset' | |
| else: | |
| dataset_symbol = cfgv.dataset.symbol | |
| log_dir = osp.join(cfge.log_root_dir, '{}_{}'.format(model_symbol, dataset_symbol)) | |
| exp_dir = search_experiment_folder(log_dir, cfge.experiment_id) | |
| if exp_dir is None: | |
| if not isdebug: | |
| sig = cfgv.get('signature', []) + ['evalonly'] | |
| else: | |
| sig = ['debug'] | |
| exp_dir = '_'.join([str(cfge.experiment_id)] + sig) | |
| eval_subdir = cfgv.get('eval_subdir', None) | |
| # override subdir in debug mode (if eval_subdir is set) | |
| eval_subdir = 'debug' if (eval_subdir is not None) and isdebug else eval_subdir | |
| if eval_subdir is not None: | |
| log_dir = osp.join(log_dir, exp_dir, eval_subdir) | |
| else: | |
| log_dir = osp.join(log_dir, exp_dir) | |
| disable_log_override = cfgv.get('disable_log_override', False) | |
| if osp.isdir(log_dir): | |
| if disable_log_override: | |
| assert False, 'Override an exsited log_dir is disabled at [{}]'.format(log_dir) | |
| else: | |
| os.makedirs(log_dir) | |
| log_file = osp.join(log_dir, 'eval.log') | |
| cfgv.log_dir = log_dir | |
| cfgv.log_file = log_file | |
| ###################### | |
| # print and save cfg # | |
| ###################### | |
| pprint.pprint(cfg) | |
| if cfge.node_rank==0: | |
| with open(log_file, 'w') as f: | |
| pprint.pprint(cfg, f) | |
| with open(osp.join(log_dir, 'config.yaml'), 'w') as f: | |
| yaml.dump(edict_2_dict(cfg), f) | |
| else: | |
| with open(osp.join(log_dir, 'config.yaml.{}'.format(cfge.node_rank)), 'w') as f: | |
| yaml.dump(edict_2_dict(cfg), f) | |
| ############# | |
| # save code # | |
| ############# | |
| save_code = False | |
| if istrain: | |
| save_code = cfgt.get('save_code', False) | |
| elif haseval: | |
| save_code = cfgv.get('save_code', False) | |
| save_code = save_code and (cfge.node_rank==0) | |
| if save_code: | |
| codedir = osp.join(log_dir, 'code') | |
| if osp.exists(codedir): | |
| shutil.rmtree(codedir) | |
| for d in ['configs', 'lib']: | |
| fromcodedir = d | |
| tocodedir = osp.join(codedir, d) | |
| shutil.copytree( | |
| fromcodedir, tocodedir, | |
| ignore=shutil.ignore_patterns( | |
| '*__pycache__*', '*build*')) | |
| for codei in os.listdir('.'): | |
| if osp.splitext(codei)[1] == 'py': | |
| shutil.copy(codei, codedir) | |
| ####################### | |
| # set matplotlib mode # | |
| ####################### | |
| if 'matplotlib_mode' in cfge: | |
| try: | |
| matplotlib.use(cfge.matplotlib_mode) | |
| except: | |
| print('Warning: matplotlib mode [{}] failed to be set!'.format(cfge.matplotlib_mode)) | |
| return cfg | |
| def edict_2_dict(x): | |
| if isinstance(x, dict): | |
| xnew = {} | |
| for k in x: | |
| xnew[k] = edict_2_dict(x[k]) | |
| return xnew | |
| elif isinstance(x, list): | |
| xnew = [] | |
| for i in range(len(x)): | |
| xnew.append( edict_2_dict(x[i]) ) | |
| return xnew | |
| else: | |
| return x | |
| def search_experiment_folder(root, exid): | |
| target = None | |
| for fi in os.listdir(root): | |
| if not osp.isdir(osp.join(root, fi)): | |
| continue | |
| if int(fi.split('_')[0]) == exid: | |
| if target is not None: | |
| return None # duplicated | |
| elif target is None: | |
| target = fi | |
| return target | |