#!/usr/bin/env python3
"""Checkpoint Averaging Script

This script averages all model weights for checkpoints in a specified path that
match the specified filter wildcard. All checkpoints must be from the exact same
model.

For any hope of decent results, the checkpoints should be from the same or a child
(via resumes) training session. This can be viewed as similar to maintaining a
running EMA (exponential moving average) of the model weights, or performing SWA
(stochastic weight averaging), but post-training.

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import glob
import os
import re

import torch
import torch.nn as nn
from safetensors.torch import load_file as safetensors_load_file
from safetensors.torch import save_file as safetensors_save_file
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default=[], nargs='+', type=str, metavar='PATHS',
                    help='path(s) to base input folder(s) containing checkpoints')
parser.add_argument('--output', type=str, default='avgmodel.pt', metavar='PATH',
                    help='output file name of the averaged checkpoint')
parser.add_argument('--suffix', default='', type=str, metavar='WILDCARD',
                    help='checkpoint suffix, e.g. .pt or .safetensors')
parser.add_argument('--min', type=int, default=500,
                    help='minimum iteration of checkpoints to average')
parser.add_argument('--max', type=int, default=-1,
                    help='maximum iteration of checkpoints to average (-1 for no limit)')


def main():
    args = parser.parse_args()
    patterns = args.input
    sel_checkpoint_filenames = []
    for pattern in patterns:
        # Build a glob pattern: append '*' (unless the suffix already leads with
        # one), then the suffix, e.g. './checkpoints/' -> './checkpoints/*.pt'.
        if not args.suffix.startswith('*'):
            pattern += '*'
        pattern += args.suffix
        checkpoint_filenames = glob.glob(pattern, recursive=True)
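        # Apply the --min/--max iteration bounds. This sketch assumes the
        # checkpoint's iteration number is the last run of digits in its
        # basename (e.g. 'checkpoint_001000.pt'), a convention not guaranteed
        # by this script; files without any digits are kept unconditionally.
        filtered = []
        for f in checkpoint_filenames:
            digits = re.findall(r'\d+', os.path.basename(f))
            if digits:
                it = int(digits[-1])
                if it < args.min or (0 <= args.max < it):
                    continue
            filtered.append(f)
        checkpoint_filenames = filtered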
        if len(checkpoint_filenames) == 0:
            print("WARNING: No checkpoints matching '{}' with iteration >= {} found".format(
                pattern, args.min))
            continue
        sel_checkpoint_filenames += checkpoint_filenames
    avg_ckpt = {}
    avg_counts = {}
    for c in sel_checkpoint_filenames:
        if c.endswith(".safetensors"):
            checkpoint = safetensors_load_file(c)
        else:
            checkpoint = torch.load(c, map_location='cpu')
        print(c)
        for k in checkpoint:
            # Skip EMA weights.
            if k.startswith("model_ema."):
                continue
            if k not in avg_ckpt:
                # First occurrence of this key: seed the running sum with it.
                avg_ckpt[k] = checkpoint[k]
                print(f"Copy {k}")
                avg_counts[k] = 1
            elif isinstance(checkpoint[k], nn.Module):
                # Another occurrence of a previously seen nn.Module:
                # accumulate tensor-by-tensor through its state_dict.
                avg_state_dict = avg_ckpt[k].state_dict()
                for m_k, m_v in checkpoint[k].state_dict().items():
                    if m_k not in avg_state_dict:
                        avg_state_dict[m_k] = m_v.clone()
                        print(f"Copy {k}.{m_k}")
                    else:
                        avg_state_dict[m_k].data += m_v
                        print(f"Accumulate {k}.{m_k}")
                avg_ckpt[k].load_state_dict(avg_state_dict)
                avg_counts[k] += 1
            elif isinstance(checkpoint[k], (nn.Parameter, torch.Tensor)):
                # Another occurrence of a previously seen nn.Parameter/tensor.
                avg_ckpt[k].data += checkpoint[k].data
                avg_counts[k] += 1
            else:
                # Non-tensor entries (e.g. metadata) are kept from the first
                # checkpoint only and never accumulated.
                print(f"NOT accumulating {type(checkpoint[k])}: {k}")
    for k in avg_ckpt:
        # safetensors checkpoints hold torch.Tensor rather than nn.Parameter.
        if isinstance(avg_ckpt[k], (nn.Parameter, torch.Tensor)):
            print(f"Averaging nn.Parameter: {k}")
            # Divide the running sum by the count, casting back to the original
            # dtype (an in-place '/=' would fail on integer tensors such as
            # BatchNorm's num_batches_tracked).
            avg_ckpt[k].data = (avg_ckpt[k].data / avg_counts[k]).to(avg_ckpt[k].data.dtype)
        elif isinstance(avg_ckpt[k], nn.Module):
            print(f"Averaging nn.Module: {k}")
            avg_state_dict = avg_ckpt[k].state_dict()
            for m_k, m_v in avg_state_dict.items():
                m_v.data = (m_v.data / avg_counts[k]).to(m_v.data.dtype)
            avg_ckpt[k].load_state_dict(avg_state_dict)
        else:
            print(f"NOT averaging {type(avg_ckpt[k])}: {k}")
    if args.output.endswith(".safetensors"):
        # NOTE: safetensors can only serialize a flat dict of tensors; use a
        # '.pt' output instead if the checkpoint contains nn.Module entries.
        safetensors_save_file(avg_ckpt, args.output)
    else:
        torch.save(avg_ckpt, args.output)
    print("=> Saved state_dict to '{}'".format(args.output))


if __name__ == '__main__':
    main()
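
# A minimal loading sketch (the model factory and path are hypothetical, not
# part of this script): assuming the input checkpoints were flat state_dicts,
# the averaged file loads the same way any one of them would.
#
#   import torch
#   import timm  # assumes the averaged checkpoints came from a timm model
#   model = timm.create_model('resnet50')
#   model.load_state_dict(torch.load('avgmodel.pt', map_location='cpu'))
#   model.eval()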