# Copyright (C) 2021-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import print_function

import argparse
import logging
import os

import torch
import yaml

from wenet.utils.init_model import init_model

import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert


def get_args():
    parser = argparse.ArgumentParser(
        description='export a scripted (TorchScript) model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--output_file', default=None, help='output file')
    parser.add_argument('--dtype',
                        default="fp32",
                        choices=['fp32', 'bf16'],
                        help='dtype to run the model in: fp32 or bf16')
    parser.add_argument('--output_quant_file',
                        default=None,
                        help='output quantized model file')
    args = parser.parse_args()
    return args


def scripting(model):
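    # Script the model, then freeze it for inference. Freezing inlines
    # weights and drops attributes not reachable from forward(), so the
    # extra entry points used by the WeNet runtime must be preserved
    # explicitly via preserved_attrs.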
    with torch.inference_mode():
        script_model = torch.jit.script(model)
        script_model = torch.jit.freeze(
            script_model,
            preserved_attrs=[
                "forward_encoder_chunk", "ctc_activation",
                "forward_attention_decoder", "subsampling_rate",
                "right_context", "sos_symbol", "eos_symbol",
                "is_bidirectional_decoder"
            ])
    return script_model


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # No GPU is needed for model export
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    model, configs = init_model(args, configs)
    print(model)

    # Apply IPEX optimization
    model.eval()
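    # Disable TorchScript's TensorExpr fuser (it can interfere with IPEX's
    # own graph optimizations) and use channels-last memory format, which
    # IPEX prefers on CPU.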
    torch._C._jit_set_texpr_fuser_enabled(False)
    model.to(memory_format=torch.channels_last)
    if args.dtype == "fp32":
        ipex_model = ipex.optimize(model)
    elif args.dtype == "bf16":  # For Intel 4th generation Xeon (SPR)
        ipex_model = ipex.optimize(model,
                                   dtype=torch.bfloat16,
                                   weights_prepack=False)

    # Export the TorchScript model
    if args.output_file:
        if args.dtype == "fp32":
            script_model = scripting(ipex_model)
        elif args.dtype == "bf16":
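            # Enable autocast support in the JIT so the bf16 autocast
            # context below is captured when the model is scripted.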
            torch._C._jit_set_autocast_mode(True)
            with torch.cpu.amp.autocast():
                script_model = scripting(ipex_model)
        script_model.save(args.output_file)
        print('Exported model successfully, see {}'.format(args.output_file))

    # Export the dynamically quantized TorchScript model
    if args.output_quant_file:
        dynamic_qconfig = ipex.quantization.default_dynamic_qconfig
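        # Example inputs for prepare(); the shapes appear to follow
        # forward_encoder_chunk's streaming signature (feats, offset,
        # required_cache_size, att_cache, cnn_cache).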
        dummy_data = (torch.zeros(1, 67, 80), 16, -16,
                      torch.zeros(12, 4, 32, 128), torch.zeros(12, 1, 256, 7))
        model = prepare(model, dynamic_qconfig, dummy_data)
        model = convert(model)
        script_quant_model = scripting(model)
        script_quant_model.save(args.output_quant_file)
        print('Exported quantized model successfully, '
              'see {}'.format(args.output_quant_file))


if __name__ == '__main__':
    main()
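
# Example invocation (script and file names are illustrative, not from the
# source):
#   python export_ipex.py --config train.yaml \
#       --checkpoint final.pt \
#       --dtype bf16 \
#       --output_file final_ipex.zip \
#       --output_quant_file final_ipex_quant.zip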