# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch

from wenet.finetune.lora.utils import (inject_lora_to_model,
                                       mark_only_lora_as_trainable)
from wenet.k2.model import K2Model
from wenet.llm_asr.init_llmasr import init_llmasr
from wenet.paraformer.cif import Cif
from wenet.paraformer.layers import SanmDecoder, SanmEncoder
from wenet.paraformer.paraformer import Paraformer, Predictor
from wenet.LLM.causallm_model import CausalLM
from wenet.LLM.decoder import DecoderOnly
from wenet.ssl.init_model import WENET_SSL_MODEL_CLASS
from wenet.transducer.joint import TransducerJoint
from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
                                        RNNPredictor)
from wenet.transducer.transducer import Transducer
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.branchformer.encoder import BranchformerEncoder
from wenet.e_branchformer.encoder import EBranchformerEncoder
from wenet.squeezeformer.encoder import SqueezeformerEncoder
from wenet.efficient_conformer.encoder import EfficientConformerEncoder
from wenet.ctl_model.encoder import DualTransformerEncoder, DualConformerEncoder
from wenet.ctl_model.asr_model_ctl import CTLModel
from wenet.whisper.whisper import Whisper
from wenet.utils.cmvn import load_cmvn
from wenet.utils.checkpoint import load_checkpoint, load_trained_modules
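
# Registries mapping the type strings used in a WeNet config (e.g.
# "encoder: conformer", "model: transducer") to their implementation
# classes, so init_speech_model() can build components by name.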
WENET_ENCODER_CLASSES = {
    "transformer": TransformerEncoder,
    "conformer": ConformerEncoder,
    "squeezeformer": SqueezeformerEncoder,
    "efficientConformer": EfficientConformerEncoder,
    "branchformer": BranchformerEncoder,
    "e_branchformer": EBranchformerEncoder,
    "dual_transformer": DualTransformerEncoder,
    "dual_conformer": DualConformerEncoder,
    'sanm_encoder': SanmEncoder,
}

WENET_DECODER_CLASSES = {
    "transformer": TransformerDecoder,
    "bitransformer": BiTransformerDecoder,
    "sanm_decoder": SanmDecoder,
}

WENET_CTC_CLASSES = {
    "ctc": CTC,
}

WENET_PREDICTOR_CLASSES = {
    "rnn": RNNPredictor,
    "embedding": EmbeddingPredictor,
    "conv": ConvPredictor,
    "cif_predictor": Cif,
    "paraformer_predictor": Predictor,
}

WENET_JOINT_CLASSES = {
    "transducer_joint": TransducerJoint,
}

WENET_MODEL_CLASSES = {
    "asr_model": ASRModel,
    "ctl_model": CTLModel,
    "whisper": Whisper,
    "k2_model": K2Model,
    "transducer": Transducer,
    'paraformer': Paraformer,
    'causal_llm': CausalLM,
}


def init_speech_model(args, configs):
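    """Build a speech model (encoder/decoder/CTC plus optional extras).

    Component types are looked up in the registries above; the assembled
    model class is chosen by configs['model']. Returns (model, configs).
    """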
    # TODO(xcsong): Forcefully read the 'cmvn' attribute.
    if configs.get('cmvn', None) == 'global_cmvn':
        mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'],
                               configs['cmvn_conf']['is_json_cmvn'])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float())
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    vocab_size = configs['output_dim']

    encoder_type = configs.get('encoder', 'conformer')
    decoder_type = configs.get('decoder', 'bitransformer')
    ctc_type = configs.get('ctc', 'ctc')

    encoder = WENET_ENCODER_CLASSES[encoder_type](
        input_dim,
        global_cmvn=global_cmvn,
        **configs['encoder_conf'],
        **configs['encoder_conf']['efficient_conf']
        if 'efficient_conf' in configs['encoder_conf'] else {})
    decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size,
                                                  encoder.output_size(),
                                                  **configs['decoder_conf'])
    ctc = WENET_CTC_CLASSES[ctc_type](
        vocab_size,
        encoder.output_size(),
        blank_id=configs['ctc_conf']['ctc_blank_id']
        if 'ctc_conf' in configs else 0)
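
    # Assemble the model selected by configs['model']. Transducer and
    # Paraformer additionally need a predictor (and, for Transducer, a
    # joint network); SSL models are delegated to wenet.ssl.init_model.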
    model_type = configs.get('model', 'asr_model')
    if model_type == "transducer":
        predictor_type = configs.get('predictor', 'rnn')
        joint_type = configs.get('joint', 'transducer_joint')
        predictor = WENET_PREDICTOR_CLASSES[predictor_type](
            vocab_size, **configs['predictor_conf'])
        joint = WENET_JOINT_CLASSES[joint_type](vocab_size,
                                                **configs['joint_conf'])
        model = WENET_MODEL_CLASSES[model_type](
            vocab_size=vocab_size,
            blank=0,
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            special_tokens=configs.get('tokenizer_conf',
                                       {}).get('special_tokens', None),
            **configs['model_conf'])
    elif model_type == 'paraformer':
        predictor_type = configs.get('predictor', 'cif_predictor')
        predictor = WENET_PREDICTOR_CLASSES[predictor_type](
            **configs['predictor_conf'])
        model = WENET_MODEL_CLASSES[model_type](
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            predictor=predictor,
            ctc=ctc,
            **configs['model_conf'],
            special_tokens=configs.get('tokenizer_conf',
                                       {}).get('special_tokens', None),
        )
    elif model_type in WENET_SSL_MODEL_CLASS.keys():
        from wenet.ssl.init_model import init_model as init_ssl_model
        model = init_ssl_model(configs, encoder)
    else:
        model = WENET_MODEL_CLASSES[model_type](
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            special_tokens=configs.get('tokenizer_conf',
                                       {}).get('special_tokens', None),
            **configs['model_conf'])
    return model, configs


def init_causal_llm(configs):
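    """Build a decoder-only causal LM from ``configs``.

    Requires ``decoder: decoder_only`` and ``model: causal_lm`` in the
    config; returns (model, configs).
    """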
    vocab_size = configs['output_dim']
    assert configs['decoder'] == 'decoder_only'
    assert configs['model'] == 'causal_lm'
    decoder_only = DecoderOnly(**configs['decoder_conf'])
    model = CausalLM(
        vocab_size,
        decoder_only,
        **configs['model_conf'],
        special_tokens=configs.get('tokenizer_conf',
                                   {}).get('special_tokens', None),
    )
    return model, configs


def init_model(args, configs):
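    """Entry point for model construction and weight initialization.

    Dispatches on configs['model'] to the matching builder, then, depending
    on the attributes present on ``args``, injects LoRA adapters, loads a
    checkpoint or pre-trained encoder modules, ties weights, and restricts
    training to LoRA parameters.
    """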
    model_type = configs.get('model', 'asr_model')
    configs['model'] = model_type

    if model_type == 'causal_lm':
        model, configs = init_causal_llm(configs)
    elif model_type == "llmasr":
        model = init_llmasr(args, configs)
        return model
    else:
        model, configs = init_speech_model(args, configs)

    if hasattr(args, 'use_lora') and args.use_lora:
        inject_lora_to_model(model, configs['lora_conf'])

    # If a checkpoint is specified, load some info from it
    if hasattr(args, 'checkpoint') and args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    elif hasattr(args, 'enc_init') and args.enc_init is not None:
        infos = load_trained_modules(model, args)
    else:
        infos = {}
    if configs.get('init_step', False):
        infos = {}
    configs["init_infos"] = infos

    if hasattr(args, 'use_lora') and args.use_lora:
        if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
            load_checkpoint(model, args.lora_ckpt_path)
    print(configs)

    # Try to tie some weights
    if hasattr(model, 'tie_or_clone_weights'):
        if not hasattr(args, 'jit'):
            args.jit = True  # i.e. export onnx/jit/ipex
        model.tie_or_clone_weights(args.jit)

    if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora:
        mark_only_lora_as_trainable(model, bias='lora_only')

    if int(os.environ.get('RANK', 0)) == 0:
        print(configs)

    return model, configs
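

# Minimal usage sketch (illustrative only, not part of this module). It
# assumes `configs` already contains input_dim/output_dim and the *_conf
# sections (as prepared by the training script from a WeNet YAML); the
# file name and Namespace fields below are made up for the example.
#
#     import argparse
#     import yaml
#
#     with open('train.yaml') as fin:
#         configs = yaml.safe_load(fin)
#     args = argparse.Namespace(checkpoint=None, enc_init=None,
#                               use_lora=False, jit=False)
#     model, configs = init_model(args, configs)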