Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import os | |
| import torch | |
| import random | |
| import string | |
| import yaml | |
| from easydict import EasyDict as edict | |
| import utils | |
| from utils import log | |
def parse_arguments(args):
    """
    Parse arguments from command line.
    Syntax: --key1.key2.key3=value --> value
            --key1.key2.key3=      --> None
            --key1.key2.key3       --> True
            --key1.key2.key3!      --> False

    Dotted keys build nested dictionaries, and every value is decoded with
    yaml.safe_load. Only the first "=" separates key from value, so a value
    may itself contain "=" (e.g. --data.filter=a=b).

    Args:
        args: iterable of argument strings, each starting with "--".

    Returns:
        An EasyDict holding the (nested) options given on the command line.

    Raises:
        AssertionError: if an argument does not start with "--" or if the
            same leaf key is given twice.
    """
    opt_cmd = {}
    for arg in args:
        assert arg.startswith("--"), arg
        body = arg[2:]
        if "=" in body:
            # split on the first "=" only; a plain split("=") cannot be unpacked
            # into two names when the value contains another "="
            key_str, _, value = body.partition("=")
        elif body.endswith("!"):
            key_str, value = body[:-1], "false"
        else:
            key_str, value = body, "true"
        keys_sub = key_str.split(".")
        # walk down (creating as needed) to the dictionary that owns the leaf key
        opt_sub = opt_cmd
        for k in keys_sub[:-1]:
            if k not in opt_sub:
                opt_sub[k] = {}
            opt_sub = opt_sub[k]
        # refuse to silently overwrite a key that was already given
        assert keys_sub[-1] not in opt_sub, keys_sub[-1]
        opt_sub[keys_sub[-1]] = yaml.safe_load(value)
    opt_cmd = edict(opt_cmd)
    return opt_cmd
def set(opt_cmd={}):
    """Build the final options from a YAML file plus command-line overrides.

    Args:
        opt_cmd: options parsed from the command line; must contain a "yaml"
            entry (read as opt_cmd.yaml) naming the file options/<yaml>.yaml.

    Returns:
        The merged options, after process_options() has been applied to them
        and they have been passed to log.options().
    """
    log.info("setting configurations...")
    # the base configuration file is selected by the mandatory "yaml" entry
    assert "yaml" in opt_cmd
    base_options = load_options("options/{}.yaml".format(opt_cmd.yaml))
    # command-line values win over file values; unknown keys trigger the safe check
    merged = override_options(base_options, opt_cmd, key_stack=[], safe_check=True)
    process_options(merged)
    log.options(merged)
    return merged
def load_options(fname):
    """Load a YAML options file, resolving its "_parent_" entries recursively.

    Args:
        fname: path of the YAML file to read.

    Returns:
        An EasyDict of options. When the file names parent file(s) under
        "_parent_", each parent is loaded first and then overridden by the
        options gathered so far.
    """
    with open(fname) as stream:
        loaded = edict(yaml.safe_load(stream))
    if "_parent_" in loaded:
        parents = loaded.pop("_parent_")
        # a single parent may be written as a bare string instead of a list
        if type(parents) is str:
            parents = [parents]
        for parent_path in parents:
            # the parent provides the base; what has been loaded so far overrides it
            loaded = override_options(load_options(parent_path), loaded, key_stack=[])
    print("loading {}...".format(fname))
    return loaded
def override_options(opt, opt_over, key_stack=None, safe_check=False):
    """Recursively merge opt_over into opt (in place) and return opt.

    Args:
        opt: mapping of base options; it is modified in place.
        opt_over: mapping whose entries take precedence over those in opt.
        key_stack: list of keys leading from the root to opt, used to build the
            dotted name shown in the safe-check prompt. None means the root.
        safe_check: when true, ask on stdin before adding a leaf key that opt
            does not already contain; answering "n" terminates the program.

    Returns:
        The updated opt.

    Raises:
        SystemExit: if the user answers "n" to the safe-check prompt.
    """
    # A missing key stack means we are at the root. Without this, the default
    # None was concatenated with a list below and raised TypeError.
    if key_stack is None:
        key_stack = []
    for key, value in opt_over.items():
        print(key, value)
        if isinstance(value, dict):
            # parse child options (until leaf nodes are reached)
            opt[key] = override_options(
                opt.get(key, dict()),
                value,
                key_stack=key_stack + [key],
                safe_check=safe_check,
            )
        else:
            # ensure command line argument to override is also in yaml file
            if safe_check and key not in opt:
                add_new = None
                while add_new not in ["y", "n"]:
                    key_str = ".".join(key_stack + [key])
                    add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str))
                if add_new == "n":
                    print("safe exiting...")
                    # same effect as exit(), without relying on the interactive site helper
                    raise SystemExit
            opt[key] = value
    return opt
def process_options(opt):
    """Apply the seed / run-name / device post-processing to opt in place.

    With a seed, the random, NumPy and torch (CPU and CUDA) generators are
    seeded with it. Without one, four random uppercase letters are appended
    to opt.name as a run ID. opt.device is then set to "cpu" when opt.cpu is
    truthy or CUDA is unavailable, otherwise to "cuda:<opt.gpu>".

    Args:
        opt: options object exposing seed, name, gpu and cpu as attributes.

    Raises:
        AssertionError: if opt.gpu is not an int (single GPU only).
    """
    seed = opt.seed
    if seed is None:
        # no seed given: tag the run name with a random 4-letter ID instead
        letters = [random.choice(string.ascii_uppercase) for _ in range(4)]
        opt.name = str(opt.name) + "_{}".format("".join(letters))
    else:
        # seed every random number generator in use with the same value
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
            seed_fn(seed)
    # disable multi-GPU support for now, single is enough
    assert isinstance(opt.gpu, int)
    use_cpu = opt.cpu or not torch.cuda.is_available()
    opt.device = "cpu" if use_cpu else "cuda:{}".format(opt.gpu)
def save_options_file(opt, output_path):
    """Write opt to <output_path>/options.yaml, confirming before replacing a different file.

    If an options file already exists and differs from opt, the current options
    are dumped to a temporary file, the two files are diffed on the terminal,
    and the user is asked on stdin whether to override. Answering "n"
    terminates the program; otherwise (or when the file is new or identical)
    the options file is (re)written.

    Args:
        opt: options to save; converted with utils.to_dict() before dumping.
        output_path: directory that holds options.yaml.

    Raises:
        SystemExit: if the user declines to override an existing, different file.
    """
    import subprocess  # local import so the block does not depend on a file-level import

    opt_fname = "{}/options.yaml".format(output_path)
    if os.path.isfile(opt_fname):
        with open(opt_fname) as file:
            opt_old = yaml.safe_load(file)
        if opt != opt_old:
            # prompt if options are not identical
            opt_new_fname = "{}/options_temp.yaml".format(output_path)
            with open(opt_new_fname, "w") as file:
                yaml.safe_dump(utils.to_dict(opt), file, default_flow_style=False, indent=4)
            print("existing options file found (different from current one)...")
            try:
                # Pass the paths as an argument list (no shell) so spaces or shell
                # metacharacters in output_path cannot break or alter the command.
                subprocess.run(["diff", "--", opt_fname, opt_new_fname])
            except OSError:
                # showing the diff is best effort; continue to the prompt without it
                print("(diff utility not available, skipping comparison)")
            finally:
                # portable replacement for shelling out to "rm"
                os.remove(opt_new_fname)
            override = None
            while override not in ["y", "n"]:
                override = input("override? (y/n) ")
            if override == "n":
                print("safe exiting...")
                # same effect as exit(), without relying on the interactive site helper
                raise SystemExit
        else:
            print("existing options file found (identical)")
    else:
        print("(creating new options file...)")
    with open(opt_fname, "w") as file:
        yaml.safe_dump(utils.to_dict(opt), file, default_flow_style=False, indent=4)