# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```
from __future__ import print_function

import argparse
import copy
import logging
import os

import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import (
    load_checkpoint,
    save_checkpoint,
    load_trained_modules,
)
from wenet.utils.executor import Executor
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model

def get_args():
    parser = argparse.ArgumentParser(description="training your network")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument(
        "--data_type",
        default="raw",
        choices=["raw", "shard"],
        help="train and cv data type",
    )
    parser.add_argument("--train_data", required=True, help="train data file")
    parser.add_argument("--cv_data", required=True, help="cv data file")
    parser.add_argument(
        "--gpu", type=int, default=-1, help="gpu id for this local rank, -1 for cpu"
    )
    parser.add_argument("--model_dir", required=True, help="save model dir")
    parser.add_argument("--checkpoint", help="checkpoint model")
    parser.add_argument(
        "--tensorboard_dir", default="tensorboard", help="tensorboard log dir"
    )
    parser.add_argument(
        "--ddp.rank",
        dest="rank",
        default=0,
        type=int,
        help="global rank for distributed training",
    )
    parser.add_argument(
        "--ddp.world_size",
        dest="world_size",
        default=-1,
        type=int,
        help="number of total processes/gpus for distributed training",
    )
    parser.add_argument(
        "--ddp.dist_backend",
        dest="dist_backend",
        default="nccl",
        choices=["nccl", "gloo"],
        help="distributed backend",
    )
    parser.add_argument(
        "--ddp.init_method", dest="init_method", default=None, help="ddp init method"
    )
    parser.add_argument(
        "--num_workers",
        default=0,
        type=int,
        help="num of subprocess workers for reading",
    )
    parser.add_argument(
        "--pin_memory",
        action="store_true",
        default=False,
        help="Use pinned memory buffers for reading",
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        default=False,
        help="Use automatic mixed precision training",
    )
    parser.add_argument(
        "--fp16_grad_sync",
        action="store_true",
        default=False,
        help="Use fp16 gradient sync for ddp",
    )
    parser.add_argument("--cmvn", default=None, help="global cmvn file")
    parser.add_argument(
        "--symbol_table", required=True, help="model unit symbol table for training"
    )
    parser.add_argument(
        "--non_lang_syms", help="non-linguistic symbol file. One symbol per line."
    )
    parser.add_argument("--prefetch", default=100, type=int, help="prefetch number")
    parser.add_argument(
        "--bpe_model", default=None, type=str, help="bpe model for english part"
    )
    parser.add_argument(
        "--override_config", action="append", default=[], help="override yaml config"
    )
    parser.add_argument(
        "--enc_init",
        default=None,
        type=str,
        help="Pre-trained model to initialize encoder",
    )
    parser.add_argument(
        "--enc_init_mods",
        default="encoder.",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="List of encoder modules to initialize, separated by commas",
    )
    parser.add_argument("--lfmmi_dir", default="", required=False, help="LF-MMI dir")
    args = parser.parse_args()
    return args

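# Illustrative single-GPU invocation (paths and values below are placeholders,
# not shipped defaults); WeNet recipes normally drive this script from run.sh:
#
#   python train.py --config conf/train_conformer.yaml \
#       --data_type shard \
#       --train_data data/train/data.list \
#       --cv_data data/dev/data.list \
#       --symbol_table data/dict/lang_char.txt \
#       --model_dir exp/conformer \
#       --gpu 0
#
# For multi-GPU DDP, each process additionally passes its own --ddp.rank,
# plus --ddp.world_size, --ddp.dist_backend and --ddp.init_method.
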
def main():
    args = get_args()
    logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    # Set random seed
    torch.manual_seed(777)
    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)
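    # The YAML config is expected to provide at least the keys read below;
    # a minimal sketch (values are illustrative placeholders, not defaults):
    #
    #   dataset_conf:
    #     fbank_conf:
    #       num_mel_bins: 80
    #     speed_perturb: true
    #     spec_aug: true
    #     shuffle: true
    #   optim: adam            # or adamw
    #   optim_conf:
    #     lr: 0.001
    #   scheduler: warmuplr    # or NoamHoldAnnealing
    #   scheduler_conf:
    #     warmup_steps: 25000
    #   max_epoch: 100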
    distributed = args.world_size > 1
    if distributed:
        logging.info("training on multiple gpus, this gpu {}".format(args.gpu))
        dist.init_process_group(
            args.dist_backend,
            init_method=args.init_method,
            world_size=args.world_size,
            rank=args.rank,
        )
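        # Every process runs this script with its own --ddp.rank;
        # init_process_group blocks until all world_size processes have
        # joined the group (nccl for GPU training, gloo as a CPU fallback).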
    symbol_table = read_symbol_table(args.symbol_table)
    train_conf = configs["dataset_conf"]
    cv_conf = copy.deepcopy(train_conf)
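    # Turn off data augmentation and shuffling for the CV set so the
    # validation loss is comparable from epoch to epoch.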
| cv_conf["speed_perturb"] = False | |
| cv_conf["spec_aug"] = False | |
| cv_conf["spec_sub"] = False | |
| cv_conf["spec_trim"] = False | |
| cv_conf["shuffle"] = False | |
| non_lang_syms = read_non_lang_symbols(args.non_lang_syms) | |
| train_dataset = Dataset( | |
| args.data_type, | |
| args.train_data, | |
| symbol_table, | |
| train_conf, | |
| args.bpe_model, | |
| non_lang_syms, | |
| True, | |
| ) | |
| cv_dataset = Dataset( | |
| args.data_type, | |
| args.cv_data, | |
| symbol_table, | |
| cv_conf, | |
| args.bpe_model, | |
| non_lang_syms, | |
| partition=False, | |
| ) | |
| train_data_loader = DataLoader( | |
| train_dataset, | |
| batch_size=None, | |
| pin_memory=args.pin_memory, | |
| num_workers=args.num_workers, | |
| prefetch_factor=args.prefetch, | |
| ) | |
| cv_data_loader = DataLoader( | |
| cv_dataset, | |
| batch_size=None, | |
| pin_memory=args.pin_memory, | |
| num_workers=args.num_workers, | |
| prefetch_factor=args.prefetch, | |
| ) | |
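    # batch_size=None because WeNet's Dataset is an IterableDataset that
    # already yields padded batches; the DataLoader only adds worker
    # processes, pinned memory, and prefetching on top of it.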
| if "fbank_conf" in configs["dataset_conf"]: | |
| input_dim = configs["dataset_conf"]["fbank_conf"]["num_mel_bins"] | |
| else: | |
| input_dim = configs["dataset_conf"]["mfcc_conf"]["num_mel_bins"] | |
| vocab_size = len(symbol_table) | |
| # Save configs to model_dir/train.yaml for inference and export | |
| configs["input_dim"] = input_dim | |
| configs["output_dim"] = vocab_size | |
| configs["cmvn_file"] = args.cmvn | |
| configs["is_json_cmvn"] = True | |
| configs["lfmmi_dir"] = args.lfmmi_dir | |
| if args.rank == 0: | |
| saved_config_path = os.path.join(args.model_dir, "train.yaml") | |
| with open(saved_config_path, "w") as fout: | |
| data = yaml.dump(configs) | |
| fout.write(data) | |
| # Init asr model from configs | |
| model = init_model(configs) | |
| print(model) | |
| num_params = sum(p.numel() for p in model.parameters()) | |
| print("the number of model params: {:,d}".format(num_params)) | |
    # !!!IMPORTANT!!!
    # Try to export the model with TorchScript; if it fails, refine the code
    # to satisfy the script export requirements.
    if args.rank == 0:
        script_model = torch.jit.script(model)
        script_model.save(os.path.join(args.model_dir, "init.zip"))
    executor = Executor()
    # If a checkpoint is given, restore training info from it
    if args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    elif args.enc_init is not None:
        logging.info("load pretrained encoders: {}".format(args.enc_init))
        infos = load_trained_modules(model, args)
    else:
        infos = {}
    start_epoch = infos.get("epoch", -1) + 1
    cv_loss = infos.get("cv_loss", 0.0)
    step = infos.get("step", -1)
    num_epochs = configs.get("max_epoch", 100)
    model_dir = args.model_dir
    writer = None
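    # Only rank 0 creates the model dir and the TensorBoard writer; other
    # ranks keep writer=None so checkpoints and logs are not duplicated.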
    if args.rank == 0:
        os.makedirs(model_dir, exist_ok=True)
        exp_id = os.path.basename(model_dir)
        writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))
    if distributed:
        assert torch.cuda.is_available()
        # cuda model is required for nn.parallel.DistributedDataParallel
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True
        )
        device = torch.device("cuda")
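        # Optional: compress gradients to fp16 during allreduce to roughly
        # halve DDP communication volume (torch's fp16_compress_hook).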
        if args.fp16_grad_sync:
            from torch.distributed.algorithms.ddp_comm_hooks import (
                default as comm_hooks,
            )
            model.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    else:
        use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        model = model.to(device)
| if configs["optim"] == "adam": | |
| optimizer = optim.Adam(model.parameters(), **configs["optim_conf"]) | |
| elif configs["optim"] == "adamw": | |
| optimizer = optim.AdamW(model.parameters(), **configs["optim_conf"]) | |
| else: | |
| raise ValueError("unknown optimizer: " + configs["optim"]) | |
| if configs["scheduler"] == "warmuplr": | |
| scheduler = WarmupLR(optimizer, **configs["scheduler_conf"]) | |
| elif configs["scheduler"] == "NoamHoldAnnealing": | |
| scheduler = NoamHoldAnnealing(optimizer, **configs["scheduler_conf"]) | |
| else: | |
| raise ValueError("unknown scheduler: " + configs["scheduler"]) | |
    final_epoch = None
    configs["rank"] = args.rank
    configs["is_distributed"] = distributed
    configs["use_amp"] = args.use_amp
    if start_epoch == 0 and args.rank == 0:
        save_model_path = os.path.join(model_dir, "init.pt")
        save_checkpoint(model, save_model_path)
    # Start training loop
    executor.step = step
    scheduler.set_step(step)
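    # Restoring the executor step and the scheduler step keeps the learning-
    # rate schedule aligned with the checkpoint instead of restarting warmup.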
    # used for pytorch amp mixed precision training
    scaler = None
    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()
    for epoch in range(start_epoch, num_epochs):
        train_dataset.set_epoch(epoch)
        configs["epoch"] = epoch
        lr = optimizer.param_groups[0]["lr"]
        logging.info("Epoch {} TRAIN info lr {}".format(epoch, lr))
        executor.train(
            model,
            optimizer,
            scheduler,
            train_data_loader,
            device,
            writer,
            configs,
            scaler,
        )
        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device, configs)
        cv_loss = total_loss / num_seen_utts
        logging.info("Epoch {} CV info cv_loss {}".format(epoch, cv_loss))
        if args.rank == 0:
            save_model_path = os.path.join(model_dir, "{}.pt".format(epoch))
            save_checkpoint(
                model,
                save_model_path,
                {"epoch": epoch, "lr": lr, "cv_loss": cv_loss, "step": executor.step},
            )
            writer.add_scalar("epoch/cv_loss", cv_loss, epoch)
            writer.add_scalar("epoch/lr", lr, epoch)
        final_epoch = epoch
    if final_epoch is not None and args.rank == 0:
        final_model_path = os.path.join(model_dir, "final.pt")
        if os.path.exists(final_model_path):
            os.remove(final_model_path)
        os.symlink("{}.pt".format(final_epoch), final_model_path)
        writer.close()


if __name__ == "__main__":
    main()