Spaces:
Sleeping
Sleeping
| # This module is from [WeNet](https://github.com/wenet-e2e/wenet). | |
| # ## Citations | |
| # ```bibtex | |
| # @inproceedings{yao2021wenet, | |
| # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, | |
| # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, | |
| # booktitle={Proc. Interspeech}, | |
| # year={2021}, | |
| # address={Brno, Czech Republic }, | |
| # organization={IEEE} | |
| # } | |
| # @article{zhang2022wenet, | |
| # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, | |
| # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, | |
| # journal={arXiv preprint arXiv:2203.15455}, | |
| # year={2022} | |
| # } | |
| # ``` | |
| import os | |
| import argparse | |
| import glob | |
| import yaml | |
| import numpy as np | |
| import torch | |
def get_args(argv=None):
    """Parse the command-line options of the checkpoint-averaging script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``dst_model``, ``src_path``,
        ``val_best``, ``num``, ``min_epoch`` and ``max_epoch``.
    """
    parser = argparse.ArgumentParser(description="average model")
    parser.add_argument("--dst_model", required=True, help="averaged model")
    parser.add_argument(
        "--src_path", required=True, help="src model path for average"
    )
    # The help text used to be a copy of the --dst_model one; this flag
    # switches main() to ranking checkpoints by their recorded cv_loss.
    parser.add_argument(
        "--val_best",
        action="store_true",
        help="average the checkpoints with the best (lowest) validation loss",
    )
    parser.add_argument(
        "--num", default=5, type=int, help="nums for averaged model"
    )
    parser.add_argument(
        "--min_epoch",
        default=0,
        type=int,
        help="min epoch used for averaging model",
    )
    parser.add_argument(
        "--max_epoch",
        default=65536,
        type=int,
        help="max epoch used for averaging model",
    )
    args = parser.parse_args(argv)
    # Echo the effective configuration so it ends up in the run log.
    print(args)
    return args
def _select_checkpoints(args):
    """Return the checkpoint paths to average, as chosen by the CLI options."""
    if not args.val_best:
        # Default mode: the `num` most recently modified numbered checkpoints.
        # min_epoch / max_epoch are not consulted in this mode.
        path_list = glob.glob("{}/[0-9]*.pt".format(args.src_path))
        path_list = sorted(path_list, key=os.path.getmtime)
        path_list = path_list[-args.num :]
        print(path_list)
        return path_list

    # NOTE: "[!train]" is a glob character class, not a negated word: it skips
    # every yaml whose name starts with one of the letters t, r, a, i, n.
    yamls = glob.glob("{}/[!train]*.yaml".format(args.src_path))
    val_scores = []
    for y in yamls:
        with open(y, "r") as f:
            # NOTE(review): FullLoader can build arbitrary Python objects;
            # only point this script at yaml files you trust.
            dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
        loss = dic_yaml["cv_loss"]
        epoch = dic_yaml["epoch"]
        if args.min_epoch <= epoch <= args.max_epoch:
            val_scores.append([epoch, loss])
    if not val_scores:
        # Without this, the column indexing below fails with an opaque
        # IndexError on an empty array.
        raise ValueError(
            "no validation yaml in {} has an epoch within [{}, {}]".format(
                args.src_path, args.min_epoch, args.max_epoch
            )
        )
    val_scores = np.array(val_scores)
    # Ascending sort on the loss column: lowest (best) cv_loss first.
    sort_idx = np.argsort(val_scores[:, -1])
    best = val_scores[sort_idx][: args.num]
    print("best val scores = " + str(best[:, 1]))
    print("selected epochs = " + str(best[:, 0].astype(np.int64)))
    return [args.src_path + "/{}.pt".format(int(epoch)) for epoch in best[:, 0]]


def _average_states(path_list):
    """Load every checkpoint in path_list on CPU and return their key-wise mean."""
    avg = None
    for path in path_list:
        print("Processing {}".format(path))
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        states = torch.load(path, map_location=torch.device("cpu"))
        if avg is None:
            # The first checkpoint doubles as the running-sum accumulator.
            avg = states
            continue
        for k in avg.keys():
            # Skip None entries here too, matching the division step below.
            if avg[k] is not None:
                avg[k] += states[k]
    count = len(path_list)
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], count)
    return avg


def main():
    """Average several model checkpoints and save the result.

    Options come from ``get_args()``. With ``--val_best`` the ``num`` epochs
    with the lowest ``cv_loss`` (read from the yaml files in ``src_path`` and
    filtered to ``[min_epoch, max_epoch]``) are averaged; otherwise the
    ``num`` most recently modified ``[0-9]*.pt`` files are used. The averaged
    state is written to ``dst_model`` with ``torch.save``.

    Raises:
        ValueError: if ``num`` is not positive, if no validation score
            survives the epoch filter, or if fewer than ``num`` checkpoints
            are available.
    """
    args = get_args()
    num = args.num
    if num < 1:
        raise ValueError("--num must be a positive integer, got {}".format(num))

    path_list = _select_checkpoints(args)
    # Explicit check instead of `assert`: asserts are stripped under -O, which
    # would silently divide the sum by the wrong count.
    if len(path_list) != num:
        raise ValueError(
            "expected {} checkpoints to average but found {} in {}".format(
                num, len(path_list), args.src_path
            )
        )

    avg = _average_states(path_list)
    print("Saving to {}".format(args.dst_model))
    torch.save(avg, args.dst_model)
# Run the averaging routine only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()