Spaces:
Runtime error
Runtime error
| # Copyright (c) 2023 Wenet Community. (authors: Xingchen Song) | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Requirements: | |
| ```bash | |
| pip install -U openai-whisper | |
| ``` | |
| Example: | |
| ```bash | |
| # Converts the model from OpenAI to WeNet format: | |
| python convert_whisper_to_wenet_config_and_ckpt.py \ | |
| --whisper_ckpt large-v3.pt \ | |
| --output_dir exp/whisper/large-v3 | |
| ``` | |
| """ | |
| import argparse | |
| import copy | |
| import os | |
| import sys | |
| import torch | |
| import yaml | |
| _cpath_ = sys.path[0] | |
| sys.path.remove(_cpath_) | |
| from whisper.tokenizer import get_tokenizer | |
| sys.path.insert(0, _cpath_) | |
def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
    """Build the WeNet training config for a Whisper model and dump it as YAML.

    Args:
        tokenizer: Whisper tokenizer; its ``encoding.n_vocab`` and the
            special-token ids (``sot``, ``eot``, ``sot_prev``, ``transcribe``,
            ``translate``, ``no_timestamps``, ``no_speech``,
            ``timestamp_begin``) are read.
        dims: mapping with the Whisper model dimensions (``n_mels``,
            ``n_vocab``, ``n_audio_*`` and ``n_text_*`` entries).
        wenet_yaml_path: destination path of the YAML file.

    Raises:
        AssertionError: if ``dims['n_vocab']`` disagrees with the vocabulary
            size reported by the tokenizer.
    """
    n_mels = dims['n_mels']
    n_vocab = dims['n_vocab']
    assert n_vocab == tokenizer.encoding.n_vocab, "{} v.s. {}".format(
        n_vocab, tokenizer.encoding.n_vocab)

    # Same thresholds/offsets as used when the tokenizer is instantiated.
    is_multilingual = n_vocab >= 51865
    num_languages = n_vocab - 51765 - int(is_multilingual)

    # NOTE: key order is kept stable on purpose, the dict is printed below.
    configs = {
        'input_dim': n_mels,
        'output_dim': n_vocab,
        'encoder': 'transformer',
        'encoder_conf': {
            'gradient_checkpointing': True,
            'input_layer': 'conv1d2',
            'output_size': dims['n_audio_state'],
            'attention_heads': dims['n_audio_head'],
            'linear_units': dims['n_audio_state'] * 4,
            'num_blocks': dims['n_audio_layer'],
            'dropout_rate': 0.1,
            'positional_dropout_rate': 0.1,
            'attention_dropout_rate': 0.0,
            'normalize_before': True,
            'use_dynamic_chunk': False,
            'use_dynamic_left_chunk': False,
            'pos_enc_layer_type': "abs_pos_whisper",
            'static_chunk_size': -1,
            'key_bias': False,
            'activation_type': "gelu",
        },
        'decoder': 'transformer',
        'decoder_conf': {
            'tie_word_embedding': True,
            'gradient_checkpointing': True,
            'attention_heads': dims['n_text_head'],
            'linear_units': dims['n_text_state'] * 4,
            'num_blocks': dims['n_text_layer'],
            'dropout_rate': 0.1,
            'positional_dropout_rate': 0.1,
            'self_attention_dropout_rate': 0.0,
            'src_attention_dropout_rate': 0.0,
            'input_layer': "embed_learnable_pe",
            'use_output_layer': True,
            'normalize_before': True,
            'src_attention': True,
            'key_bias': False,
            'activation_type': "gelu",
        },
        'tokenizer': 'whisper',
        'tokenizer_conf': {
            'is_multilingual': is_multilingual,
            'num_languages': num_languages,
            'split_with_space': False,
            'bpe_path': None,
            'symbol_table_path': None,
            'non_lang_syms_path': None,
            'special_tokens': {
                'sot': tokenizer.sot,
                'eot': tokenizer.eot,
                'sot_prev': tokenizer.sot_prev,
                'transcribe': tokenizer.transcribe,
                'translate': tokenizer.translate,
                'no_timestamps': tokenizer.no_timestamps,
                'no_speech': tokenizer.no_speech,
                'timestamp_begin': tokenizer.timestamp_begin,
            },
        },
        # The no-speech token id doubles as the CTC blank id.
        'ctc_conf': {
            'ctc_blank_id': tokenizer.no_speech,
        },
        'cmvn': None,
        'cmvn_conf': {
            'cmvn_file': None,
            'is_json_cmvn': None,
        },
        'model': "whisper",
        'model_conf': {
            'ctc_weight': 0.3,
            'lsm_weight': 0.1,
            'length_normalized_loss': False,
        },
        'dataset': "asr",
        'dataset_conf': {
            'filter_conf': {
                'max_length': dims['n_audio_ctx'] * 2,  # 1/2 subsample # noqa
                'min_length': 0,
                'token_max_length': dims['n_text_ctx'],
                'token_min_length': 1,
            },
            'resample_conf': {
                'resample_rate': 16000,
            },
            # NOTE: Disable speed_perturb, https://github.com/wenet-e2e/wenet/issues/2171
            'speed_perturb': False,
            'spec_aug': True,
            'spec_aug_conf': {
                'num_t_mask': 2,
                'num_f_mask': 2,
                'max_t': 50,
                'max_f': 10,
            },
            'spec_sub': True,
            'spec_sub_conf': {
                'num_t_sub': 3,
                'max_t': 30,
            },
            'spec_trim': False,
            'shuffle': True,
            'shuffle_conf': {
                'shuffle_size': 1500,
            },
            'sort': True,
            'sort_conf': {
                'sort_size': 500,
            },
            'feats_type': "log_mel_spectrogram",
            'log_mel_spectrogram_conf': {
                'n_fft': 400,
                'hop_length': 160,
                'num_mel_bins': n_mels,
                'padding': 0,
            },
            'batch_conf': {
                'batch_type': 'dynamic',
                'batch_size': 26,
                'max_frames_in_batch': 12000,
            },
            'language_conf': {
                'limited_langs': ['zh'],
            },
        },
        'grad_clip': 5,
        'accum_grad': 4,
        'max_epoch': 100,
        'log_interval': 100,
        'optim': "adam",
        'optim_conf': {
            'lr': 0.0005,
        },
        'scheduler': "warmuplr",
        'scheduler_conf': {
            'warmup_steps': 12000,
        },
    }

    yaml_text = yaml.dump(configs)
    with open(wenet_yaml_path, '+w') as yaml_file:
        yaml_file.write(yaml_text)
        yaml_file.flush()
    print(configs)
def convert_to_wenet_state_dict(whisper_state_dict, wenet_state_dict_path):
    """Rename Whisper parameters to WeNet names and save them as fp32.

    Every key of ``whisper_state_dict`` is rewritten through a fixed list of
    substring replacements. The two positional-embedding tensors additionally
    get a leading dimension of size 1. Keys whose name is not changed by any
    rule are reported and dropped from the output.

    The input mapping is left untouched: tensors are reshaped / cast on local
    references only.

    Args:
        whisper_state_dict: mapping of Whisper parameter name -> tensor.
        wenet_state_dict_path: file path passed to ``torch.save``.
    """
    # Plain substring rewrites, applied in this order to every key.
    renames = (
        ("encoder.conv1", "encoder.embed.conv.0"),
        ("encoder.conv2", "encoder.embed.conv.2"),
        ("decoder.token_embedding", "decoder.embed.0"),
        ("encoder.blocks", "encoder.encoders"),
        ("decoder.blocks", "decoder.decoders"),
        (".cross_attn.query", ".src_attn.linear_q"),
        (".cross_attn.key", ".src_attn.linear_k"),
        (".cross_attn.value", ".src_attn.linear_v"),
        (".cross_attn.out", ".src_attn.linear_out"),
        (".attn.query", ".self_attn.linear_q"),
        (".attn.key", ".self_attn.linear_k"),
        (".attn.value", ".self_attn.linear_v"),
        (".attn.out", ".self_attn.linear_out"),
        ("mlp.0", "feed_forward.w_1"),
        ("mlp.2", "feed_forward.w_2"),
    )
    # Positional embeddings get a fixed target name and a leading batch dim.
    positional = {
        "decoder.positional_embedding": "decoder.embed.1.pe",
        "encoder.positional_embedding": "encoder.embed.pos_enc.pe",
    }

    wenet_state_dict = {}
    unused = []
    print(
        "===================== start CKPT Conversion ========================="
    )
    for original_name, tensor in whisper_state_dict.items():
        name = original_name
        for old, new in renames:
            name = name.replace(old, new)

        # Layer norms: decoder layers have three norms, encoder layers two.
        # "cross_attn_ln" contains "attn_ln", so it must be rewritten first.
        if "decoder" in name:
            name = name.replace("cross_attn_ln", "norm2")
            name = name.replace("mlp_ln", "norm3")
        else:
            name = name.replace("mlp_ln", "norm2")
        name = name.replace("attn_ln", "norm1")
        name = name.replace("encoder.ln_post", "encoder.after_norm")
        name = name.replace("decoder.ln", "decoder.after_norm")

        if original_name in positional:
            # Reshape a local reference only; the caller's dict is not edited.
            tensor = tensor.unsqueeze(0)
            name = positional[original_name]

        print("name  {} ==> {}".format(original_name, name))
        print("type  {} ==> torch.float32".format(tensor.dtype))
        print("shape {}\n".format(tensor.shape))

        if original_name == name:
            # No rule matched: this parameter has no WeNet counterpart here.
            unused.append(name)
        else:
            wenet_state_dict[name] = tensor.float()

    for name in unused:
        print("NOTE!!! drop {}".format(name))
    print("Saving fp32 ckpt to {}...".format(wenet_state_dict_path))
    torch.save(wenet_state_dict, wenet_state_dict_path)
    print(
        "DONE\n===================== End CKPT Conversion =========================\n"
    )
def convert_to_wenet_units(tokenizer, units_txt_path):
    """Write one ``<unit> <id>`` line per token id to ``units_txt_path``.

    NOTE(xcsong):
        The "units.txt" file is solely for adapting to the training API of
        Wenet and for quickly checking the corresponding text of an ID when
        necessary. It does not play any role in the tokenization process,
        which is carried out by the tokenizer of openai-whisper.

    Args:
        tokenizer: object exposing ``encoding.n_vocab`` and
            ``encoding.decode_single_token_bytes(id)``.
        units_txt_path: destination path of the units file.
    """
    vocab_size = tokenizer.encoding.n_vocab
    with open(units_txt_path, "+w") as units_file:
        for token_id in range(vocab_size):
            # str() is applied to whatever the decoder returns; presumably
            # that is a bytes object, so the unit is its repr (b'...').
            token_text = str(
                tokenizer.encoding.decode_single_token_bytes(token_id))
            if not token_text:
                token_text = str(token_id)
                print("can not decode id {}, convert to str({})".format(
                    token_id, token_id))
            # Units must not contain blanks: the line is "<unit> <id>".
            token_text = token_text.replace(" ", "<space>")
            units_file.write("{} {}\n".format(token_text, token_id))
        units_file.flush()
def get_args():
    """Parse the command-line options of the conversion script.

    Options:
        --whisper_ckpt: required, path of the OpenAI Whisper checkpoint.
        --output_dir: directory receiving the converted files, default ``.``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()`` (reads
        ``sys.argv``).
    """
    parser = argparse.ArgumentParser(description='load and parse whisper')
    # yapf: disable
    parser.add_argument(
        '--whisper_ckpt',
        required=True,
        help='https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt'  # noqa
    )
    # yapf: enable
    # The checkpoint is written as "wenet_whisper.pt" by main(); the help
    # text lists the file names that are actually produced.
    parser.add_argument('--output_dir',
                        default='.',
                        help='output file in wenet\'s style: ' +
                        'units.txt, train.yaml, wenet_whisper.pt')
    args = parser.parse_args()
    return args
def main():
    """Convert a Whisper checkpoint into WeNet's checkpoint, units and config.

    Reads the checkpoint given by ``--whisper_ckpt`` and writes three files
    into ``--output_dir``: ``wenet_whisper.pt``, ``units.txt`` and
    ``train.yaml``. The output directory is created when it does not exist.
    """
    args = get_args()
    # NOTE(review): torch.load unpickles the file; only convert checkpoints
    # obtained from a trusted source.
    checkpoint = torch.load(args.whisper_ckpt, map_location="cpu")
    dims = checkpoint["dims"]
    multilingual = dims['n_vocab'] >= 51865
    num_languages = dims['n_vocab'] - 51765 - int(multilingual)
    tokenizer = get_tokenizer(multilingual=multilingual,
                              num_languages=num_languages)

    # Neither torch.save nor open() creates missing parent directories, so
    # make sure the destination exists (an empty path means the cwd).
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    convert_to_wenet_state_dict(
        checkpoint["model_state_dict"],
        os.path.join(args.output_dir, 'wenet_whisper.pt'))
    convert_to_wenet_units(tokenizer,
                           os.path.join(args.output_dir, 'units.txt'))
    convert_to_wenet_yaml(tokenizer, dims,
                          os.path.join(args.output_dir, 'train.yaml'))


if __name__ == "__main__":
    main()