Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| import os | |
| import shutil | |
| import torch | |
| def get_checkpoint_files(s): | |
| s = s.strip() | |
| if ',' in s: | |
| return [get_checkpoint_files(chunk) for chunk in s.split(',')] | |
| return 'last.ckpt' if s == 'last' else f'{s}.ckpt' | |
| def main(args): | |
| checkpoint_fnames = get_checkpoint_files(args.epochs) | |
| if isinstance(checkpoint_fnames, str): | |
| checkpoint_fnames = [checkpoint_fnames] | |
| assert len(checkpoint_fnames) >= 1 | |
| checkpoint_path = os.path.join(args.indir, 'models', checkpoint_fnames[0]) | |
| checkpoint = torch.load(checkpoint_path, map_location='cpu') | |
| del checkpoint['optimizer_states'] | |
| if len(checkpoint_fnames) > 1: | |
| for fname in checkpoint_fnames[1:]: | |
| print('sum', fname) | |
| sum_tensors_cnt = 0 | |
| other_cp = torch.load(os.path.join(args.indir, 'models', fname), map_location='cpu') | |
| for k in checkpoint['state_dict'].keys(): | |
| if checkpoint['state_dict'][k].dtype is torch.float: | |
| checkpoint['state_dict'][k].data.add_(other_cp['state_dict'][k].data) | |
| sum_tensors_cnt += 1 | |
| print('summed', sum_tensors_cnt, 'tensors') | |
| for k in checkpoint['state_dict'].keys(): | |
| if checkpoint['state_dict'][k].dtype is torch.float: | |
| checkpoint['state_dict'][k].data.mul_(1 / float(len(checkpoint_fnames))) | |
| state_dict = checkpoint['state_dict'] | |
| if not args.leave_discriminators: | |
| for k in list(state_dict.keys()): | |
| if k.startswith('discriminator.'): | |
| del state_dict[k] | |
| if not args.leave_losses: | |
| for k in list(state_dict.keys()): | |
| if k.startswith('loss_'): | |
| del state_dict[k] | |
| out_checkpoint_path = os.path.join(args.outdir, 'models', 'best.ckpt') | |
| os.makedirs(os.path.dirname(out_checkpoint_path), exist_ok=True) | |
| torch.save(checkpoint, out_checkpoint_path) | |
| shutil.copy2(os.path.join(args.indir, 'config.yaml'), | |
| os.path.join(args.outdir, 'config.yaml')) | |
| if __name__ == '__main__': | |
| import argparse | |
| aparser = argparse.ArgumentParser() | |
| aparser.add_argument('indir', | |
| help='Path to directory with output of training ' | |
| '(i.e. directory, which has samples, modules, config.yaml and train.log') | |
| aparser.add_argument('outdir', | |
| help='Where to put minimal checkpoint, which can be consumed by "bin/predict.py"') | |
| aparser.add_argument('--epochs', type=str, default='last', | |
| help='Which checkpoint to take. ' | |
| 'Can be "last" or integer - number of epoch') | |
| aparser.add_argument('--leave-discriminators', action='store_true', | |
| help='If enabled, the state of discriminators will not be removed from the checkpoint') | |
| aparser.add_argument('--leave-losses', action='store_true', | |
| help='If enabled, weights of nn-based losses (e.g. perceptual) will not be removed') | |
| main(aparser.parse_args()) | |