import argparse
import pprint
import random

import numpy as np
import torch
import yaml


def str2bool(v):
    """Parse common truthy/falsy strings (case-insensitive) into a bool for argparse."""
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
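
# Usage sketch for str2bool (the '--cuda' flag is illustrative, not defined
# in this module): register it as the argparse `type` so text like "yes"
# or "0" parses to a real boolean instead of a truthy string:
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--cuda', type=str2bool, default=True)
#   parser.parse_args(['--cuda', 'no']).cuda  # -> False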


def is_interactive():
    """Return True when no script file is being run (e.g., inside Jupyter)."""
    import __main__ as main
    return not hasattr(main, '__file__')


def get_optimizer(optim, verbose=False):
    # Bind the optimizer
    if optim == 'rms':
        if verbose:
            print("Optimizer: Using RMSProp")
        optimizer = torch.optim.RMSprop
    elif optim == 'adam':
        if verbose:
            print("Optimizer: Using Adam")
        optimizer = torch.optim.Adam
    elif optim == 'adamw':
        if verbose:
            print("Optimizer: Using AdamW")
        # Returns the string 'adamw' rather than torch.optim.AdamW;
        # the caller is expected to bind the actual AdamW implementation.
        # optimizer = torch.optim.AdamW
        optimizer = 'adamw'
    elif optim == 'adamax':
        if verbose:
            print("Optimizer: Using Adamax")
        optimizer = torch.optim.Adamax
    elif optim == 'sgd':
        if verbose:
            print("Optimizer: Using SGD")
        optimizer = torch.optim.SGD
    else:
        raise NotImplementedError(f"Please add your optimizer '{optim}' to the list.")
    return optimizer
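
# Usage sketch for get_optimizer (`model` is hypothetical): most branches
# return an optimizer class that the caller instantiates, e.g.
#
#   optim_cls = get_optimizer('adam', verbose=True)
#   optimizer = optim_cls(model.parameters(), lr=1e-4)
#
# 'adamw' is the exception: it returns the string 'adamw', so callers
# must bind an AdamW implementation themselves.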


def parse_args(parse=True, **optional_kwargs):
    parser = argparse.ArgumentParser()

    parser.add_argument('--seed', type=int, default=9595, help='random seed')

    # Data Splits
    parser.add_argument("--train", default='karpathy_train')
    parser.add_argument("--valid", default='karpathy_val')
    parser.add_argument("--test", default='karpathy_test')
    # parser.add_argument('--test_only', action='store_true')

    # Quick experiments
    parser.add_argument('--train_topk', type=int, default=-1)
    parser.add_argument('--valid_topk', type=int, default=-1)

    # Checkpoint
    parser.add_argument('--output', type=str, default='snap/test')
    parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).')
    parser.add_argument('--from_scratch', action='store_true')

    # CPU/GPU
| parser.add_argument("--multiGPU", action='store_const', default=False, const=True) | |
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument("--distributed", action='store_true')
    parser.add_argument("--num_workers", default=0, type=int)
    parser.add_argument('--local_rank', type=int, default=-1)
    # parser.add_argument('--rank', type=int, default=-1)

    # Model Config
    # parser.add_argument('--encoder_backbone', type=str, default='openai/clip-vit-base-patch32')
    # parser.add_argument('--decoder_backbone', type=str, default='bert-base-uncased')
    parser.add_argument('--tokenizer', type=str, default='openai/clip-vit-base-patch32')
    # parser.add_argument('--position_embedding_type', type=str, default='absolute')
    # parser.add_argument('--encoder_transform', action='store_true')
    parser.add_argument('--max_text_length', type=int, default=40)
    # parser.add_argument('--image_size', type=int, default=224)
    # parser.add_argument('--patch_size', type=int, default=32)
    # parser.add_argument('--decoder_num_layers', type=int, default=12)

    # Training
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--valid_batch_size', type=int, default=None)
    parser.add_argument('--optim', default='adamw')
    parser.add_argument('--warmup_ratio', type=float, default=0.05)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--clip_grad_norm', type=float, default=-1.0)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--adam_eps', type=float, default=1e-6)
    parser.add_argument('--adam_beta1', type=float, default=0.9)
    parser.add_argument('--adam_beta2', type=float, default=0.999)
    parser.add_argument('--epochs', type=int, default=20)
    # parser.add_argument('--dropout', type=float, default=0.1)

    # Inference
    # parser.add_argument('--num_beams', type=int, default=1)
    # parser.add_argument('--gen_max_length', type=int, default=20)
    parser.add_argument('--start_from', type=str, default=None)

    # Data
    # parser.add_argument('--do_lower_case', type=str2bool, default=None)
    # parser.add_argument('--prefix', type=str, default=None)

    # COCO Caption
    # parser.add_argument('--no_prefix', action='store_true')
    parser.add_argument('--no_cls', action='store_true')

    parser.add_argument('--cfg', type=str, default=None)
    parser.add_argument('--id', type=str, default=None)

    # Etc.
    parser.add_argument('--comment', type=str, default='')
    parser.add_argument("--dry", action='store_true')

    # Parse the arguments.
    if parse:
        args = parser.parse_args()
    else:
        # For interactive environments (e.g., Jupyter)
        args = parser.parse_known_args()[0]
    loaded_kwargs = {}
    if args.cfg is not None:
        cfg_path = f'configs/{args.cfg}.yaml'
        with open(cfg_path, 'r') as f:
            loaded_kwargs = yaml.safe_load(f)

    # Namespace => Dictionary
    parsed_kwargs = vars(args)
    parsed_kwargs.update(optional_kwargs)

    # Precedence: values from the YAML config override both the CLI args
    # and the optional keyword overrides, since they are merged last.
    kwargs = {}
    kwargs.update(parsed_kwargs)
    kwargs.update(loaded_kwargs)

    args = Config(**kwargs)
    # Bind optimizer class.
    verbose = False
    args.optimizer = get_optimizer(args.optim, verbose=verbose)

    # Set seeds
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    return args
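
# Usage sketch for parse_args: from the command line, use the default
# parse=True; in a notebook, parse=False tolerates unknown kernel argv
# entries, and keyword overrides (illustrative values below) are merged
# into the resulting Config:
#
#   args = parse_args(parse=False, batch_size=32, lr=5e-5)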


class Config(object):
    def __init__(self, **kwargs):
        """Configuration Class: set kwargs as class attributes with setattr"""
        for k, v in kwargs.items():
            setattr(self, k, v)

    def config_str(self):
        return pprint.pformat(self.__dict__)
    def __repr__(self):
        """Pretty-print configurations in alphabetical order"""
        config_str = 'Configurations\n'
        config_str += self.config_str()
        return config_str
    # def update(self, **kwargs):
    #     for k, v in kwargs.items():
    #         setattr(self, k, v)

    # def save(self, path):
    #     with open(path, 'w') as f:
    #         yaml.dump(self.__dict__, f, default_flow_style=False)

    # @classmethod
    # def load(cls, path):
    #     with open(path, 'r') as f:
    #         kwargs = yaml.load(f)
    #     return Config(**kwargs)
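
# Usage sketch for Config: any keyword becomes an attribute, and printing
# the object pretty-prints the attribute dict (keys are illustrative):
#
#   cfg = Config(lr=1e-4, optim='adamw')
#   cfg.lr      # -> 0.0001
#   print(cfg)  # "Configurations\n{...}"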


if __name__ == '__main__':
    args = parse_args(True)
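    # Pretty-print the resulting configuration via Config.__repr__.
    print(args)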