Spaces:
Build error
Build error
| """ | |
| @date: 2021/11/22 | |
| @description: Convert a training checkpoint into an inference checkpoint | |
| """ | |
| import argparse | |
| import os | |
| import torch | |
| from config.defaults import merge_from_file | |
def parse_option(argv=None):
    """Parse and echo the command-line options of the conversion script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, so
            existing ``parse_option()`` callers behave exactly as before.

    Returns:
        argparse.Namespace with two attributes:
            cfg: path of the config file (required).
            output_path: path of the output checkpoint, or ``None`` when the
                option is not given.

    Raises:
        SystemExit: raised by ``argparse`` when ``--cfg`` is missing or the
            arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(description='Conversion training ckpt into inference ckpt')
    parser.add_argument('--cfg',
                        type=str,
                        required=True,
                        metavar='FILE',
                        help='path of config file')
    parser.add_argument('--output_path',
                        type=str,
                        help='path of output ckpt')
    args = parser.parse_args(argv)

    # Echo every parsed option so the run is self-describing in the logs.
    print("arguments:")
    for name, value in vars(args).items():
        print(name, ":", value)
    print("-" * 50)
    return args
def convert_ckpt():
    """Extract the network entry from the best training checkpoint.

    Steps:
        1. Parse the command line and load the config named by ``--cfg``.
        2. Build the checkpoint directory
           ``checkpoints/<decoder_name>_<output_name>_Net/<TAG>`` from the
           first entry of ``config.MODEL.ARGS`` and ``config.TAG``.
        3. Load the first file (in sorted order) whose name contains
           ``_best_`` and take its ``'net'`` entry.
        4. Save that entry to ``--output_path``, or to ``best.pkl`` inside the
           checkpoint directory when the option is not given.

    Returns:
        None. Prints a message and returns early when no ``_best_`` file is
        found in the checkpoint directory.

    Raises:
        FileNotFoundError: if the checkpoint directory does not exist
            (propagated from ``os.listdir``).
        KeyError: if the loaded checkpoint has no ``'net'`` entry.
    """
    args = parse_option()
    config = merge_from_file(args.cfg)

    model_args = config.MODEL.ARGS[0]
    ck_dir = os.path.join(
        "checkpoints",
        f"{model_args['decoder_name']}_{model_args['output_name']}_Net",
        config.TAG,
    )
    print(f"Processing {ck_dir}")

    # os.listdir returns entries in arbitrary order; sort so the chosen
    # checkpoint is deterministic when several '_best_' files exist.
    model_paths = sorted(name for name in os.listdir(ck_dir) if '_best_' in name)
    if not model_paths:
        print("Not find best ckpt")
        return

    model_path = os.path.join(ck_dir, model_paths[0])
    print(f"Loading {model_path}")

    # Keep the original behaviour (map onto cuda:0) when a GPU is present,
    # but fall back to CPU so the conversion also works on GPU-less machines.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    net = checkpoint['net']

    if args.output_path is None:
        output_path = os.path.join(ck_dir, 'best.pkl')
    else:
        output_path = args.output_path
    print(f"Save on: {output_path}")

    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the parent directory when there is one.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(net, output_path)
# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if "__main__" == __name__:
    convert_ckpt()