Spaces:
Sleeping
Sleeping
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```
from __future__ import print_function

import argparse
import os
import sys

import torch
import yaml

from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
def get_args():
    """Parse the command-line options of the TorchScript export tool.

    Returns:
        argparse.Namespace with the attributes ``config`` and
        ``checkpoint`` (both mandatory) plus ``output_file`` and
        ``output_quant_file``, which are ``None`` when not supplied.
    """
    cli = argparse.ArgumentParser(description="export your script model")

    # Mandatory inputs: the training config and the trained weights.
    mandatory = (
        ("--config", "config file"),
        ("--checkpoint", "checkpoint model"),
    )
    for flag, text in mandatory:
        cli.add_argument(flag, required=True, help=text)

    # Optional destinations for the exported artifacts.
    optional = (
        ("--output_file", "output file"),
        ("--output_quant_file", "output quantized model file"),
    )
    for flag, text in optional:
        cli.add_argument(flag, default=None, help=text)

    return cli.parse_args()
def main():
    """Export a trained WeNet model to TorchScript.

    Reads the command-line options via ``get_args()``, builds the model
    described by the YAML ``--config``, loads the ``--checkpoint`` weights
    and then, for each output path that was supplied:

    * ``--output_file``: saves the ``torch.jit.script``-ed model;
    * ``--output_quant_file``: applies dynamic int8 quantization to every
      ``torch.nn.Linear`` layer and saves the scripted quantized model.

    If neither output path is supplied, a message is printed to stderr and
    nothing is loaded or written.
    """
    args = get_args()
    if not args.output_file and not args.output_quant_file:
        # Building the model and loading a checkpoint only to discard them
        # is wasted work; tell the user instead of exiting silently.
        print(
            "Nothing to export: pass --output_file and/or "
            "--output_quant_file",
            file=sys.stderr,
        )
        return

    # No need gpu for model export. torch initializes CUDA lazily, so hiding
    # the devices here still takes effect as long as nothing has used CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # NOTE(review): yaml.load with FullLoader is acceptable only because the
    # config is a trusted local file; do not point this at untrusted YAML.
    with open(args.config, "r", encoding="utf-8") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    model = init_model(configs)
    print(model)
    load_checkpoint(model, args.checkpoint)
    # Export for inference. torch.jit.script keeps the module's current
    # training flag, so without eval() the exported model would still apply
    # dropout. Dynamic quantization is also an inference-time transform.
    model.eval()

    # Export jit torch script model
    if args.output_file:
        script_model = torch.jit.script(model)
        script_model.save(args.output_file)
        print("Export model successfully, see {}".format(args.output_file))

    # Export quantized jit torch script model. quantize_dynamic returns a
    # quantized copy by default, so `model` itself is left unchanged.
    if args.output_quant_file:
        quantized_model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        print(quantized_model)
        script_quant_model = torch.jit.script(quantized_model)
        script_quant_model.save(args.output_quant_file)
        print(
            "Export quantized model successfully, "
            "see {}".format(args.output_quant_file)
        )
if __name__ == "__main__":
    # Run the exporter only when executed as a script, never on import.
    main()