| #!/usr/bin/env python | |
| # -*- coding:utf-8 -*- | |
| # Power by Zongsheng Yue 2023-10-26 20:20:36 | |
| import argparse | |
| from omegaconf import OmegaConf | |
| from utils.util_common import get_obj_from_str | |
| from utils.util_opts import str2bool | |
def get_parser(*, argv=None, **parser_kwargs):
    """Build the training command-line interface and parse the arguments.

    Despite its name, this function returns the *parsed* ``argparse.Namespace``,
    not the ``ArgumentParser`` object.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which matches the
            original behaviour. Passing an explicit list lets tests and other
            scripts call this function without touching ``sys.argv``.
        **parser_kwargs: Forwarded verbatim to ``argparse.ArgumentParser``.

    Returns:
        argparse.Namespace with attributes:
            save_dir (str): output folder, default ``"./save_dir"``.
            resume (str | bool): ``""`` when the flag is absent, ``True`` when
                ``--resume`` is given without a value, otherwise the given path.
            cfg_path (str): path to the YAML config file.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./save_dir",
        help="Folder to save the checkpoints and training log",
    )
    # nargs="?" makes the value optional. A bare ``--resume`` yields ``const``
    # (the bool True). argparse only applies ``type`` to string values, so it
    # stays a bool. ``--resume PATH`` yields the string PATH. Omitting the flag
    # yields the falsy default "".
    parser.add_argument(
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from the save_dir or checkpoint",
    )
    parser.add_argument(
        "--cfg_path",
        type=str,
        default="./configs/training/ffhq256_bicubic8.yaml",
        help="Configs of yaml file",
    )
    # argv=None -> argparse reads sys.argv[1:], preserving the original behaviour.
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Parse the command line, then load the YAML config file it names.
    args = get_parser()
    configs = OmegaConf.load(args.cfg_path)

    # Copy the CLI-only options into the loaded config, presumably so the
    # trainer can read them from ``configs``. Walking vars(args) in argparse's
    # own order keeps the insertion order of any new config keys unchanged.
    cli_keys = ('cfg_path', 'save_dir', 'resume', )
    for name, value in vars(args).items():
        if name in cli_keys:
            configs[name] = value

    # ``configs.trainer.target`` is presumably a dotted import path that
    # get_obj_from_str resolves to a trainer class. That class is built from
    # the merged config, and then training starts.
    trainer_cls = get_obj_from_str(configs.trainer.target)
    trainer = trainer_cls(configs)
    trainer.train()