# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

from __future__ import print_function

import copy
import logging
import math
import os
import sys
from typing import Tuple

import numpy as np
import torch
import yaml

from wenet.transformer.embedding import NoPositionalEncoding
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
from wenet.bin.export_onnx_cpu import get_args, to_numpy, print_input_output_info

try:
    import onnx
    import onnxruntime
except ImportError:
    print("Please install onnx and onnxruntime!")
    sys.exit(1)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

class BPULayerNorm(torch.nn.Module):
    """Refactor torch.nn.LayerNorm to meet 4-D dataflow."""

    def __init__(self, module, chunk_size=8, run_on_bpu=False):
        super().__init__()
        original = copy.deepcopy(module)
        self.hidden = module.weight.size(0)
        self.chunk_size = chunk_size
        self.run_on_bpu = run_on_bpu
        if self.run_on_bpu:
            self.weight = torch.nn.Parameter(
                module.weight.reshape(1, self.hidden, 1, 1).repeat(1, 1, 1, chunk_size)
            )
            self.bias = torch.nn.Parameter(
                module.bias.reshape(1, self.hidden, 1, 1).repeat(1, 1, 1, chunk_size)
            )
            self.negative = torch.nn.Parameter(
                torch.ones((1, self.hidden, 1, chunk_size)) * -1.0
            )
            self.eps = torch.nn.Parameter(
                torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps
            )
            # Each mean conv averages over all `hidden` input channels, so its
            # weight is (hidden, hidden, 1, 1) filled with 1/hidden and the
            # declared out_channels must be `hidden` to match.
            self.mean_conv_1 = torch.nn.Conv2d(self.hidden, self.hidden, 1, bias=False)
            self.mean_conv_1.weight = torch.nn.Parameter(
                torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden)
            )
            self.mean_conv_2 = torch.nn.Conv2d(self.hidden, self.hidden, 1, bias=False)
            self.mean_conv_2.weight = torch.nn.Parameter(
                torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden)
            )
        else:
            self.norm = module
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, self.chunk_size, self.hidden)
        orig_out = module(random_data)
        new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2))
        np.testing.assert_allclose(
            to_numpy(orig_out),
            to_numpy(new_out.squeeze(2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.run_on_bpu:
            u = self.mean_conv_1(x)  # (1, h, 1, c)
            numerator = x + u * self.negative  # (1, h, 1, c), i.e. x - u
            s = torch.pow(numerator, 2)  # (1, h, 1, c)
            s = self.mean_conv_2(s)  # (1, h, 1, c)
            denominator = torch.sqrt(s + self.eps)  # (1, h, 1, c)
            x = torch.div(numerator, denominator)  # (1, h, 1, c)
            x = x * self.weight + self.bias
        else:
            x = x.squeeze(2).transpose(1, 2).contiguous()
            x = self.norm(x)
            x = x.transpose(1, 2).contiguous().unsqueeze(2)
        return x
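
# Example (sketch): the BPU branch re-expresses LayerNorm entirely with ops
# that keep a 4-D NCHW layout. With the all-ones/hidden conv weights above,
#   u = mean_conv_1(x)              ~ E[x]           over the channel dim
#   s = mean_conv_2((x - u) ** 2)   ~ Var[x]         over the channel dim
#   y = (x - u) / sqrt(s + eps) * weight + bias
# so, e.g., BPULayerNorm(torch.nn.LayerNorm(256), chunk_size=8, run_on_bpu=True)
# consumes tensors of shape (1, 256, 1, 8). With run_on_bpu=False it falls
# back to the stock LayerNorm wrapped in squeeze/transpose.
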
class BPUIdentity(torch.nn.Module):
    """Refactor torch.nn.Identity().

    For inserting a BPU node whose input == output.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.identity_conv = torch.nn.Conv2d(
            channels, channels, 1, groups=channels, bias=False
        )
        torch.nn.init.dirac_(self.identity_conv.weight.data, groups=channels)
        self.check_equal()

    def check_equal(self):
        random_data = torch.randn(1, self.channels, 1, 10)
        result = self.forward(random_data)
        np.testing.assert_allclose(
            to_numpy(random_data), to_numpy(result), rtol=1e-02, atol=1e-03
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Identity with 4-D dataflow, input == output.

        Args:
            x (torch.Tensor): (batch, in_channel, 1, time)

        Returns:
            (torch.Tensor): (batch, in_channel, 1, time).
        """
        return self.identity_conv(x)
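
# Example (sketch): dirac_ initialization makes each depthwise 1x1 kernel an
# exact 1.0, so the conv is a numerical no-op:
#   ident = BPUIdentity(4)
#   x = torch.randn(1, 4, 1, 10)
#   assert torch.allclose(ident(x), x, atol=1e-3)
# Its only purpose is to materialize an on-BPU node around ops such as the
# cache `concat` in BPUConvolution and BPUConformerEncoder below.
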
class BPULinear(torch.nn.Module):
    """Refactor torch.nn.Linear or pointwise_conv."""

    def __init__(self, module, is_pointwise_conv=False):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.idim = module.weight.size(1)
        self.odim = module.weight.size(0)
        self.is_pointwise_conv = is_pointwise_conv
        # Modify weight & bias
        self.linear = torch.nn.Conv2d(self.idim, self.odim, 1, 1)
        if is_pointwise_conv:
            # (odim, idim, kernel=1) -> (odim, idim, 1, 1)
            self.linear.weight = torch.nn.Parameter(module.weight.unsqueeze(-1))
        else:
            # (odim, idim) -> (odim, idim, 1, 1)
            self.linear.weight = torch.nn.Parameter(
                module.weight.unsqueeze(2).unsqueeze(3)
            )
        self.linear.bias = module.bias
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 8, self.idim)
        if self.is_pointwise_conv:
            random_data = random_data.transpose(1, 2)
        original_result = module(random_data)
        if self.is_pointwise_conv:
            random_data = random_data.transpose(1, 2)
            original_result = original_result.transpose(1, 2)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        new_result = self.forward(random_data)
        np.testing.assert_allclose(
            to_numpy(original_result),
            to_numpy(new_result.squeeze(2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Linear with 4-D dataflow.

        Args:
            x (torch.Tensor): (batch, in_channel, 1, time)

        Returns:
            (torch.Tensor): (batch, out_channel, 1, time).
        """
        return self.linear(x)
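
# Example (sketch): a Linear layer and a 1x1 Conv2d compute the same affine
# map, so reshaping the weight is enough:
#   Linear:  y = W @ x + b,    W: (odim, idim)
#   Conv2d:  y = conv(x, W'),  W': (odim, idim, 1, 1)
#   lin = torch.nn.Linear(256, 1024)
#   bpu = BPULinear(lin)  # consumes (1, 256, 1, t), returns (1, 1024, 1, t)
# The same trick handles WeNet's kernel-size-1 Conv1d pointwise convs via
# is_pointwise_conv=True.
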
class BPUGlobalCMVN(torch.nn.Module):
    """Refactor wenet/transformer/cmvn.py::GlobalCMVN"""

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        self.norm_var = module.norm_var
        # NOTE(xcsong): Expand to 4-D tensor, (mel_dim) -> (1, 1, mel_dim, 1)
        self.mean = module.mean.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
        self.istd = module.istd.unsqueeze(-1).unsqueeze(0).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """CMVN with 4-D dataflow.

        Args:
            x (torch.Tensor): (batch, 1, mel_dim, time)

        Returns:
            (torch.Tensor): normalized feature with same shape.
        """
        x = x - self.mean
        if self.norm_var:
            x = x * self.istd
        return x
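
# Example (sketch): with norm_var=True this is simply (x - mean) * istd per
# mel bin; broadcasting the (1, 1, mel_dim, 1) statistics over the
# (batch, 1, mel_dim, time) feature keeps the whole op elementwise and 4-D.
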
class BPUConv2dSubsampling8(torch.nn.Module):
    """Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8

    NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.right_context = module.right_context
        self.subsampling_rate = module.subsampling_rate
        assert isinstance(module.pos_enc, NoPositionalEncoding)
        # 1. Modify self.conv
        # NOTE(xcsong): We change input shape from (1, 1, frames, mel_dim)
        #   to (1, 1, mel_dim, frames) for more efficient computation.
        self.conv = module.conv
        for idx in [0, 2, 4]:
            self.conv[idx].weight = torch.nn.Parameter(
                module.conv[idx].weight.transpose(2, 3)
            )
        # 2. Modify self.linear
        # NOTE(xcsong): Split the final projection to meet the requirement of
        #   the maximum kernel_size (7 for XJ3).
        self.linear = torch.nn.ModuleList()
        odim = module.linear.weight.size(0)  # 512, in this case
        freq = module.linear.weight.size(1) // odim  # 4608 // 512 == 9
        self.odim, self.freq = odim, freq
        weight = module.linear.weight.reshape(
            odim, odim, freq, 1
        )  # (odim, odim * freq) -> (odim, odim, freq, 1)
        self.split_size = []
        num_split = (freq - 1) // 7 + 1  # XJ3 requires kernel_size <= 7
        slice_begin = 0
        for idx in range(num_split):
            kernel_size = min(freq, (idx + 1) * 7) - idx * 7
            conv_ele = torch.nn.Conv2d(odim, odim, (kernel_size, 1), (kernel_size, 1))
            conv_ele.weight = torch.nn.Parameter(
                weight[:, :, slice_begin:slice_begin + kernel_size, :]
            )
            conv_ele.bias = torch.nn.Parameter(torch.zeros_like(conv_ele.bias))
            self.linear.append(conv_ele)
            self.split_size.append(kernel_size)
            slice_begin += kernel_size
        self.linear[0].bias = torch.nn.Parameter(module.linear.bias)
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 67, 80)
        mask = torch.zeros(1, 1, 67)
        original_result, _, _ = module(random_data, mask)  # (1, 8, 512)
        random_data = random_data.transpose(1, 2).unsqueeze(0)  # (1, 1, 80, 67)
        new_result = self.forward(random_data)  # (1, 512, 1, 8)
        np.testing.assert_allclose(
            to_numpy(original_result),
            to_numpy(new_result.squeeze(2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x with 4-D dataflow.

        Args:
            x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, odim, 1, time'),
                where time' = time // 8.
        """
        x = self.conv(x)  # (1, odim, freq, time')
        x_out = torch.zeros(x.size(0), self.odim, 1, x.size(3))
        x = torch.split(x, self.split_size, dim=2)
        for x_part, layer in zip(x, self.linear):
            x_out += layer(x_part)
        return x_out
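
# Example (sketch): with odim == 512 and freq == 9, the 512x4608 projection
# becomes two stacked convs, split_size == [7, 2]:
#   y = conv_7x1(x[:, :, 0:7, :]) + conv_2x1(x[:, :, 7:9, :])
# Only the first conv keeps the original bias (the others are zeroed), so the
# sum reproduces the single Linear exactly while every kernel height stays
# within the XJ3 limit of 7.
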
class BPUMultiHeadedAttention(torch.nn.Module):
    """Refactor wenet/transformer/attention.py::MultiHeadedAttention

    NOTE(xcsong): Only support attention_class == MultiHeadedAttention,
        we do not consider RelPositionMultiHeadedAttention currently.
    """

    def __init__(self, module, chunk_size, left_chunks):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.d_k = module.d_k
        self.h = module.h
        n_feat = self.d_k * self.h
        self.chunk_size = chunk_size
        self.left_chunks = left_chunks
        self.time = chunk_size * (left_chunks + 1)
        self.activation = torch.nn.Softmax(dim=-1)
        # 1. Modify self.linear_{q,k,v,out}
        self.linear_q = BPULinear(module.linear_q)
        self.linear_k = BPULinear(module.linear_k)
        self.linear_v = BPULinear(module.linear_v)
        self.linear_out = BPULinear(module.linear_out)
        # 2. denom
        self.register_buffer(
            "denom", torch.full((1, self.h, 1, 1), 1.0 / math.sqrt(self.d_k))
        )
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, self.chunk_size, self.d_k * self.h)
        mask = torch.ones((1, self.h, self.chunk_size, self.time), dtype=torch.bool)
        cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks, self.d_k * 2)
        original_out, original_cache = module(
            random_data,
            random_data,
            random_data,
            mask[:, 0, :, :],
            torch.empty(0),
            cache,
        )
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        cache = cache.reshape(
            1, self.h, self.d_k * 2, self.chunk_size * self.left_chunks
        )
        new_out, new_cache = self.forward(
            random_data, random_data, random_data, mask, cache
        )
        np.testing.assert_allclose(
            to_numpy(original_out),
            to_numpy(new_out.squeeze(2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )
        np.testing.assert_allclose(
            to_numpy(original_cache),
            to_numpy(new_cache.transpose(2, 3)),
            rtol=1e-02,
            atol=1e-03,
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor,
        cache: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size).
            k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size).
            v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size).
            mask (torch.Tensor): Mask tensor,
                (#batch, head, chunk_size, cache_t + chunk_size).
            cache (torch.Tensor): Cache tensor
                (1, head, d_k * 2, cache_t),
                where `cache_t == chunk_size * left_chunks`.

        Returns:
            torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
            torch.Tensor: Cache tensor
                (1, head, d_k * 2, cache_t + chunk_size),
                where `cache_t == chunk_size * left_chunks`.
        """
        # 1. Forward QKV
        q = self.linear_q(q)  # (1, d, 1, c)  d == size, c == chunk_size
        k = self.linear_k(k)  # (1, d, 1, c)
        v = self.linear_v(v)  # (1, d, 1, c)
        q = q.view(1, self.h, self.d_k, self.chunk_size)
        k = k.view(1, self.h, self.d_k, self.chunk_size)
        v = v.view(1, self.h, self.d_k, self.chunk_size)
        q = q.transpose(2, 3)  # (batch, head, time1, d_k)
        k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2)
        k = torch.cat((k_cache, k), dim=3)  # (1, head, d_k, time2)
        v = torch.cat((v_cache, v), dim=3)  # (1, head, d_k, time2)
        new_cache = torch.cat((k, v), dim=2)
        # 2. (Q^T)K
        scores = torch.matmul(q, k) * self.denom  # (#b, n_head, time1, time2)
        # 3. Forward attention
        mask = mask.eq(0)
        scores = scores.masked_fill(mask, -float("inf"))
        attn = self.activation(scores).masked_fill(mask, 0.0)
        attn = attn.transpose(2, 3)  # (#b, n_head, time2, time1)
        x = torch.matmul(v, attn)  # (#b, n_head, d_k, time1)
        x = x.view(1, self.d_k * self.h, 1, self.chunk_size)
        x_out = self.linear_out(x)
        return x_out, new_cache
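
# Example (sketch): every tensor above stays 4-D NCHW, the layout the BPU
# toolchain expects. Keys/values are kept as (1, head, d_k, time), so
#   scores = matmul(q.transpose(2, 3), k) / sqrt(d_k)
#   output = matmul(v, attn.transpose(2, 3))
# i.e. the attention weights are transposed instead of v. The KV cache is one
# blob with k stacked on v along dim=2: new_cache[:, :, :d_k, :] holds the
# keys and new_cache[:, :, d_k:, :] the values.
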
class BPUConvolution(torch.nn.Module):
    """Refactor wenet/transformer/convolution.py::ConvolutionModule

    NOTE(xcsong): Only support use_layer_norm == False
    """

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.lorder = module.lorder
        self.use_layer_norm = False
        self.activation = module.activation
        channels = module.pointwise_conv1.weight.size(1)
        self.channels = channels
        kernel_size = module.depthwise_conv.weight.size(2)
        assert module.use_layer_norm is False
        # 1. Modify self.pointwise_conv1
        self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True)
        # 2. Modify self.depthwise_conv
        self.depthwise_conv = torch.nn.Conv2d(
            channels, channels, (1, kernel_size), stride=1, groups=channels
        )
        self.depthwise_conv.weight = torch.nn.Parameter(
            module.depthwise_conv.weight.unsqueeze(-2)
        )
        self.depthwise_conv.bias = torch.nn.Parameter(module.depthwise_conv.bias)
        # 3. Modify self.norm, only support BatchNorm2d
        self.norm = torch.nn.BatchNorm2d(channels)
        self.norm.training = False
        self.norm.num_features = module.norm.num_features
        self.norm.eps = module.norm.eps
        self.norm.momentum = module.norm.momentum
        self.norm.weight = torch.nn.Parameter(module.norm.weight)
        self.norm.bias = torch.nn.Parameter(module.norm.bias)
        self.norm.running_mean = module.norm.running_mean
        self.norm.running_var = module.norm.running_var
        # 4. Modify self.pointwise_conv2
        self.pointwise_conv2 = BPULinear(module.pointwise_conv2, True)
        # 5. Identity conv, for running `concat` on BPU
        self.identity = BPUIdentity(channels)
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 8, self.channels)
        cache = torch.zeros((1, self.channels, self.lorder))
        original_out, original_cache = module(random_data, cache=cache)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        cache = cache.unsqueeze(2)
        new_out, new_cache = self.forward(random_data, cache)
        np.testing.assert_allclose(
            to_numpy(original_out),
            to_numpy(new_out.squeeze(2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )
        np.testing.assert_allclose(
            to_numpy(original_cache),
            to_numpy(new_cache.squeeze(2)),
            rtol=1e-02,
            atol=1e-03,
        )

    def forward(
        self, x: torch.Tensor, cache: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size).
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, 1, cache_t).

        Returns:
            torch.Tensor: Output tensor (#batch, channels, 1, chunk_size).
            torch.Tensor: Cache tensor (#batch, channels, 1, cache_t).
        """
        # Concat cache
        x = torch.cat((self.identity(cache), self.identity(x)), dim=3)
        new_cache = x[:, :, :, -self.lorder:]
        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, 1, dim)
        x = torch.nn.functional.glu(x, dim=1)  # (b, channel, 1, dim)
        # Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x)
        return x, new_cache
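
# Example (sketch): the causal depthwise conv needs `lorder` frames of left
# context, so each chunk is processed as
#   x = cat([cache, x], dim=3)        # (1, C, 1, lorder + chunk_size)
#   new_cache = x[:, :, :, -lorder:]  # carried over to the next chunk
# BPUIdentity wraps both concat operands so that the concat itself becomes a
# BPU-resident node (see the note on BPUIdentity above).
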
class BPUFFN(torch.nn.Module):
    """Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward"""

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.activation = module.activation
        # 1. Modify self.w_1 and self.w_2
        self.w_1 = BPULinear(module.w_1)
        self.w_2 = BPULinear(module.w_2)
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 8, self.w_1.idim)
        original_out = module(random_data)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        new_out = self.forward(random_data)
        np.testing.assert_allclose(
            to_numpy(original_out),
            to_numpy(new_out.squeeze(2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): input tensor (B, D, 1, L)

        Returns:
            torch.Tensor: output tensor (B, D, 1, L)
        """
        return self.w_2(self.activation(self.w_1(x)))
class BPUConformerEncoderLayer(torch.nn.Module):
    """Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer"""

    def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.size = module.size
        assert module.normalize_before is True
        assert module.concat_after is False
        # 1. Modify submodules
        self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron)
        self.self_attn = BPUMultiHeadedAttention(
            module.self_attn, chunk_size, left_chunks
        )
        self.conv_module = BPUConvolution(module.conv_module)
        self.feed_forward = BPUFFN(module.feed_forward)
        # 2. Modify norms
        self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu)
        self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size, ln_run_on_bpu)
        self.norm_ff_macaron = BPULayerNorm(
            module.norm_ff_macaron, chunk_size, ln_run_on_bpu
        )
        self.norm_conv = BPULayerNorm(module.norm_conv, chunk_size, ln_run_on_bpu)
        self.norm_final = BPULayerNorm(module.norm_final, chunk_size, ln_run_on_bpu)
        # 3. 4-D ff_scale
        self.register_buffer(
            "ff_scale", torch.full((1, self.size, 1, 1), module.ff_scale)
        )
        self.check_equal(original)

    def check_equal(self, module):
        time1 = self.self_attn.chunk_size
        time2 = self.self_attn.time
        h, d_k = self.self_attn.h, self.self_attn.d_k
        random_x = torch.randn(1, time1, self.size)
        att_mask = torch.ones(1, h, time1, time2)
        att_cache = torch.zeros(1, h, time2 - time1, d_k * 2)
        cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder)
        original_x, _, original_att_cache, original_cnn_cache = module(
            random_x,
            att_mask[:, 0, :, :],
            torch.empty(0),
            att_cache=att_cache,
            cnn_cache=cnn_cache,
        )
        random_x = random_x.transpose(1, 2).unsqueeze(2)
        att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1)
        cnn_cache = cnn_cache.unsqueeze(2)
        new_x, new_att_cache, new_cnn_cache = self.forward(
            random_x, att_mask, att_cache, cnn_cache
        )
        np.testing.assert_allclose(
            to_numpy(original_att_cache),
            to_numpy(new_att_cache.transpose(2, 3)),
            rtol=1e-02,
            atol=1e-03,
        )
        np.testing.assert_allclose(
            to_numpy(original_x),
            to_numpy(new_x.squeeze(2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )
        np.testing.assert_allclose(
            to_numpy(original_cnn_cache),
            to_numpy(new_cnn_cache.squeeze(2)),
            rtol=1e-02,
            atol=1e-03,
        )

    def forward(
        self,
        x: torch.Tensor,
        att_mask: torch.Tensor,
        att_cache: torch.Tensor,
        cnn_cache: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, size, 1, chunk_size)
            att_mask (torch.Tensor): Mask tensor for the input
                (#batch, head, chunk_size, cache_t1 + chunk_size).
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, d_k * 2, cache_t1), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, 1, cache_t2)

        Returns:
            torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
            torch.Tensor: att_cache tensor,
                (1, head, d_k * 2, cache_t1 + chunk_size).
            torch.Tensor: cnn_cache tensor (#batch, size, 1, cache_t2).
        """
        # 1. ffn_macaron
        residual = x
        x = self.norm_ff_macaron(x)
        x = residual + self.ff_scale * self.feed_forward_macaron(x)
        # 2. attention
        residual = x
        x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, att_mask, att_cache)
        x = residual + x_att
        # 3. convolution
        residual = x
        x = self.norm_conv(x)
        x, new_cnn_cache = self.conv_module(x, cnn_cache)
        x = residual + x
        # 4. ffn
        residual = x
        x = self.norm_ff(x)
        x = residual + self.ff_scale * self.feed_forward(x)
        # 5. final post-norm
        x = self.norm_final(x)
        return x, new_att_cache, new_cnn_cache
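
# Example (sketch): the layer keeps WeNet's macaron ordering,
#   x += ff_scale * ffn_macaron(norm(x)); x += attn(norm(x));
#   x += conv(norm(x));                   x += ff_scale * ffn(norm(x));
#   x = norm_final(x)
# where ff_scale is module.ff_scale (0.5 in the macaron variant), broadcast
# here as a (1, size, 1, 1) buffer so the scaling stays a 4-D elementwise op.
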
class BPUConformerEncoder(torch.nn.Module):
    """Refactor wenet/transformer/encoder.py::ConformerEncoder"""

    def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        output_size = module.output_size()
        self._output_size = module.output_size()
        self.after_norm = module.after_norm
        self.chunk_size = chunk_size
        self.left_chunks = left_chunks
        self.head = module.encoders[0].self_attn.h
        self.layers = len(module.encoders)
        # 1. Modify submodules
        self.global_cmvn = BPUGlobalCMVN(module.global_cmvn)
        self.embed = BPUConv2dSubsampling8(module.embed)
        self.encoders = torch.nn.ModuleList()
        for layer in module.encoders:
            self.encoders.append(
                BPUConformerEncoderLayer(layer, chunk_size, left_chunks, ln_run_on_bpu)
            )
        # 2. Auxiliary conv
        self.identity_cnncache = BPUIdentity(output_size)
        self.check_equal(original)

    def check_equal(self, module):
        time1 = self.encoders[0].self_attn.chunk_size
        time2 = self.encoders[0].self_attn.time
        layers = self.layers
        h, d_k = self.head, self.encoders[0].self_attn.d_k
        decoding_window = (
            (self.chunk_size - 1) * module.embed.subsampling_rate
            + module.embed.right_context
            + 1
        )
        lorder = self.encoders[0].conv_module.lorder
        random_x = torch.randn(1, decoding_window, 80)
        att_mask = torch.ones(1, h, time1, time2)
        att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2)
        cnn_cache = torch.zeros(layers, 1, self._output_size, lorder)
        orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk(
            random_x,
            0,
            time2 - time1,
            att_mask=att_mask[:, 0, :, :],
            att_cache=att_cache,
            cnn_cache=cnn_cache,
        )
        random_x = random_x.unsqueeze(0)
        att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1)
        cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder)
        new_x, new_att_cache, new_cnn_cache = self.forward(
            random_x, att_cache, cnn_cache, att_mask
        )
        caches = torch.split(new_att_cache, h, dim=1)
        caches = [c.transpose(2, 3) for c in caches]
        np.testing.assert_allclose(
            to_numpy(orig_att_cache),
            to_numpy(torch.cat(caches, dim=0)),
            rtol=1e-02,
            atol=1e-03,
        )
        np.testing.assert_allclose(
            to_numpy(orig_x),
            to_numpy(new_x.squeeze(2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )
        np.testing.assert_allclose(
            to_numpy(orig_cnn_cache),
            to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )

    def forward(
        self,
        xs: torch.Tensor,
        att_cache: torch.Tensor,
        cnn_cache: torch.Tensor,
        att_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward just one chunk.

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate
                + subsample.right_context + 1`.
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (1, head * elayers, d_k * 2, cache_t1), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (1, hidden-dim, elayers, cache_t2), where
                `cache_t2 == cnn.lorder`.
            att_mask (torch.Tensor): Mask tensor for the input
                (#batch, head, chunk_size, cache_t1 + chunk_size).

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, hidden-dim, 1, chunk_size).
            torch.Tensor: new attention cache required for next chunk, with
                same shape as the original att_cache.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        # xs: (B, 1, time, mel_dim) -> (B, 1, mel_dim, time)
        xs = xs.transpose(2, 3)
        xs = self.global_cmvn(xs)
        # xs: (B, 1, mel_dim, time) -> (B, hidden_dim, 1, chunk_size)
        xs = self.embed(xs)
        att_cache = torch.split(att_cache, self.head, dim=1)
        cnn_cache = self.identity_cnncache(cnn_cache)
        cnn_cache = torch.split(cnn_cache, 1, dim=2)
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            xs, new_att_cache, new_cnn_cache = layer(
                xs, att_mask, att_cache=att_cache[i], cnn_cache=cnn_cache[i]
            )
            r_att_cache.append(new_att_cache[:, :, :, self.chunk_size:])
            r_cnn_cache.append(new_cnn_cache)
        r_att_cache = torch.cat(r_att_cache, dim=1)
        r_cnn_cache = self.identity_cnncache(torch.cat(r_cnn_cache, dim=2))
        xs = xs.squeeze(2).transpose(1, 2).contiguous()
        xs = self.after_norm(xs)
        # NOTE(xcsong): 4-D in, 4-D out to meet the requirement of CTC input.
        xs = xs.transpose(1, 2).contiguous().unsqueeze(2)  # (B, C, 1, T)
        return (xs, r_att_cache, r_cnn_cache)
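
# Example (sketch): the per-layer caches are packed into two 4-D blobs so the
# entire streaming state crosses the model boundary as plain tensors:
#   att_cache: (1, head * elayers, d_k * 2, chunk_size * left_chunks)
#   cnn_cache: (1, hidden, elayers, lorder)
# forward() splits them back per layer (dim=1 for attention, dim=2 for the
# conv caches), then re-concatenates the updated slices, dropping the oldest
# chunk_size frames of attention cache at each step.
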
class BPUCTC(torch.nn.Module):
    """Refactor wenet/transformer/ctc.py::CTC"""

    def __init__(self, module):
        super().__init__()
        # Unchanged submodules and attributes
        original = copy.deepcopy(module)
        self.idim = module.ctc_lo.weight.size(1)
        num_class = module.ctc_lo.weight.size(0)
        # 1. Modify self.ctc_lo: split the final projection to meet the
        #   requirement of maximum in/out channels (2048 for XJ3).
        self.ctc_lo = torch.nn.ModuleList()
        self.split_size = []
        num_split = (num_class - 1) // 2048 + 1
        for idx in range(num_split):
            out_channel = min(num_class, (idx + 1) * 2048) - idx * 2048
            conv_ele = torch.nn.Conv2d(self.idim, out_channel, 1, 1)
            self.ctc_lo.append(conv_ele)
            self.split_size.append(out_channel)
        orig_weight = torch.split(module.ctc_lo.weight, self.split_size, dim=0)
        orig_bias = torch.split(module.ctc_lo.bias, self.split_size, dim=0)
        for i, (w, b) in enumerate(zip(orig_weight, orig_bias)):
            w = w.unsqueeze(2).unsqueeze(3)
            self.ctc_lo[i].weight = torch.nn.Parameter(w)
            self.ctc_lo[i].bias = torch.nn.Parameter(b)
        self.check_equal(original)

    def check_equal(self, module):
        random_data = torch.randn(1, 100, self.idim)
        original_result = module.ctc_lo(random_data)
        random_data = random_data.transpose(1, 2).unsqueeze(2)
        new_result = self.forward(random_data)
        np.testing.assert_allclose(
            to_numpy(original_result),
            to_numpy(new_result.squeeze(2).transpose(1, 2)),
            rtol=1e-02,
            atol=1e-03,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Frame activations, without softmax.

        Args:
            x (torch.Tensor): 4-D tensor (B, hidden_dim, 1, chunk_size)

        Returns:
            torch.Tensor: (B, num_class, 1, chunk_size)
        """
        out = [layer(x) for layer in self.ctc_lo]
        return torch.cat(out, dim=1)
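
# Example (sketch, with an illustrative vocabulary size): for idim == 512 and
# a hypothetical 5538-unit vocabulary, the projection exceeds the XJ3
# per-conv channel limit of 2048, so it is split as
#   split_size == [2048, 2048, 1442]
# and the three 1x1 convs run independently, their outputs concatenated along
# the class dimension to reproduce the original Linear exactly.
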
def export_encoder(asr_model, args):
    logger.info("Stage-1: export encoder")
    decode_window, mel_dim = args.decoding_window, args.feature_size
    encoder = BPUConformerEncoder(
        asr_model.encoder,
        args.chunk_size,
        args.num_decoding_left_chunks,
        args.ln_run_on_bpu,
    )
    encoder.eval()
    encoder_outpath = os.path.join(args.output_dir, "encoder.onnx")

    logger.info("Stage-1.1: prepare inputs for encoder")
    chunk = torch.randn((1, 1, decode_window, mel_dim))
    required_cache_size = encoder.chunk_size * encoder.left_chunks
    kv_time = required_cache_size + encoder.chunk_size
    hidden, layers = encoder._output_size, len(encoder.encoders)
    head = encoder.encoders[0].self_attn.h
    d_k = hidden // head
    lorder = encoder.encoders[0].conv_module.lorder
    att_cache = torch.zeros(1, layers * head, d_k * 2, required_cache_size)
    att_mask = torch.ones((1, head, encoder.chunk_size, kv_time))
    att_mask[:, :, :, :required_cache_size] = 0
    cnn_cache = torch.zeros((1, hidden, layers, lorder))
    inputs = (chunk, att_cache, cnn_cache, att_mask)
    logger.info(
        "chunk.size(): {} att_cache.size(): {} "
        "cnn_cache.size(): {} att_mask.size(): {}".format(
            list(chunk.size()),
            list(att_cache.size()),
            list(cnn_cache.size()),
            list(att_mask.size()),
        )
    )

    logger.info("Stage-1.2: torch.onnx.export")
    # NOTE(xcsong): Below attributes will be used in
    #   onnx2horizonbin.py::generate_config()
    attributes = {}
    attributes["input_name"] = "chunk;att_cache;cnn_cache;att_mask"
    attributes["output_name"] = "output;r_att_cache;r_cnn_cache"
    attributes["input_type"] = "featuremap;featuremap;featuremap;featuremap"
    attributes["norm_type"] = "no_preprocess;no_preprocess;no_preprocess;no_preprocess"
    attributes["input_layout_train"] = "NCHW;NCHW;NCHW;NCHW"
    attributes["input_layout_rt"] = "NCHW;NCHW;NCHW;NCHW"
    attributes["input_shape"] = "{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}".format(
        chunk.size(0), chunk.size(1), chunk.size(2), chunk.size(3),
        att_cache.size(0), att_cache.size(1), att_cache.size(2), att_cache.size(3),
        cnn_cache.size(0), cnn_cache.size(1), cnn_cache.size(2), cnn_cache.size(3),
        att_mask.size(0), att_mask.size(1), att_mask.size(2), att_mask.size(3),
    )
    torch.onnx.export(  # NOTE(xcsong): only support opset==11
        encoder,
        inputs,
        encoder_outpath,
        opset_version=11,
        export_params=True,
        do_constant_folding=True,
        input_names=attributes["input_name"].split(";"),
        output_names=attributes["output_name"].split(";"),
        dynamic_axes=None,
        verbose=False,
    )
    onnx_encoder = onnx.load(encoder_outpath)
    for k in vars(args):
        meta = onnx_encoder.metadata_props.add()
        meta.key, meta.value = str(k), str(getattr(args, k))
    for k in attributes:
        meta = onnx_encoder.metadata_props.add()
        meta.key, meta.value = str(k), str(attributes[k])
    onnx.checker.check_model(onnx_encoder)
    onnx.helper.printable_graph(onnx_encoder.graph)
    onnx.save(onnx_encoder, encoder_outpath)
    print_input_output_info(onnx_encoder, "onnx_encoder")
    logger.info("Export onnx_encoder, done! see {}".format(encoder_outpath))

    logger.info("Stage-1.3: check onnx_encoder and torch_encoder")
    torch_output = []
    torch_chunk, torch_att_mask = copy.deepcopy(chunk), copy.deepcopy(att_mask)
    torch_att_cache = copy.deepcopy(att_cache)
    torch_cnn_cache = copy.deepcopy(cnn_cache)
    for i in range(10):
        logger.info(
            "torch chunk-{}: {}, att_cache: {}, cnn_cache: {}"
            ", att_mask: {}".format(
                i,
                list(torch_chunk.size()),
                list(torch_att_cache.size()),
                list(torch_cnn_cache.size()),
                list(torch_att_mask.size()),
            )
        )
        torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
        out, torch_att_cache, torch_cnn_cache = encoder(
            torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask
        )
        torch_output.append(out)
    torch_output = torch.cat(torch_output, dim=-1)
    onnx_output = []
    onnx_chunk, onnx_att_mask = to_numpy(chunk), to_numpy(att_mask)
    onnx_att_cache = to_numpy(att_cache)
    onnx_cnn_cache = to_numpy(cnn_cache)
    ort_session = onnxruntime.InferenceSession(encoder_outpath)
    input_names = [node.name for node in onnx_encoder.graph.input]
    for i in range(10):
        logger.info(
            "onnx chunk-{}: {}, att_cache: {}, cnn_cache: {},"
            " att_mask: {}".format(
                i,
                onnx_chunk.shape,
                onnx_att_cache.shape,
                onnx_cnn_cache.shape,
                onnx_att_mask.shape,
            )
        )
        onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1
        ort_inputs = {
            "chunk": onnx_chunk,
            "att_cache": onnx_att_cache,
            "cnn_cache": onnx_cnn_cache,
            "att_mask": onnx_att_mask,
        }
        ort_outs = ort_session.run(None, ort_inputs)
        onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
        onnx_output.append(ort_outs[0])
    onnx_output = np.concatenate(onnx_output, axis=-1)
    np.testing.assert_allclose(
        to_numpy(torch_output), onnx_output, rtol=1e-03, atol=1e-04
    )
    meta = ort_session.get_modelmeta()
    logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
    logger.info("Check onnx_encoder, pass!")
    return encoder, ort_session
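
# Example (sketch): the 10-step loop above simulates streaming decoding. The
# mask starts with the cache region zeroed (no history yet) and is widened by
# chunk_size columns each step, so chunk i may attend to
# min(i + 1, left_chunks + 1) chunks in total; the ONNX Runtime outputs must
# match the PyTorch module within rtol=1e-3 / atol=1e-4.
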
def export_ctc(asr_model, args):
    logger.info("Stage-2: export ctc")
    ctc = BPUCTC(asr_model.ctc).eval()
    ctc_outpath = os.path.join(args.output_dir, "ctc.onnx")

    logger.info("Stage-2.1: prepare inputs for ctc")
    hidden = torch.randn((1, args.output_size, 1, args.chunk_size))

    logger.info("Stage-2.2: torch.onnx.export")
    # NOTE(xcsong): Below attributes will be used in
    #   onnx2horizonbin.py::generate_config()
    attributes = {}
    attributes["input_name"], attributes["input_type"] = "hidden", "featuremap"
    attributes["norm_type"] = "no_preprocess"
    attributes["input_layout_train"] = "NCHW"
    attributes["input_layout_rt"] = "NCHW"
    attributes["input_shape"] = "{}x{}x{}x{}".format(
        hidden.size(0), hidden.size(1), hidden.size(2), hidden.size(3)
    )
    torch.onnx.export(
        ctc,
        hidden,
        ctc_outpath,
        opset_version=11,
        export_params=True,
        do_constant_folding=True,
        input_names=["hidden"],
        output_names=["probs"],
        dynamic_axes=None,
        verbose=False,
    )
    onnx_ctc = onnx.load(ctc_outpath)
    for k in vars(args):
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(getattr(args, k))
    for k in attributes:
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(attributes[k])
    onnx.checker.check_model(onnx_ctc)
    onnx.helper.printable_graph(onnx_ctc.graph)
    onnx.save(onnx_ctc, ctc_outpath)
    print_input_output_info(onnx_ctc, "onnx_ctc")
    logger.info("Export onnx_ctc, done! see {}".format(ctc_outpath))

    logger.info("Stage-2.3: check onnx_ctc and torch_ctc")
    torch_output = ctc(hidden)
    ort_session = onnxruntime.InferenceSession(ctc_outpath)
    onnx_output = ort_session.run(None, {"hidden": to_numpy(hidden)})
    np.testing.assert_allclose(
        to_numpy(torch_output), onnx_output[0], rtol=1e-03, atol=1e-04
    )
    meta = ort_session.get_modelmeta()
    logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
    logger.info("Check onnx_ctc, pass!")
    return ctc, ort_session
def export_decoder(asr_model, args):
    logger.info("Currently, Decoder is not supported.")


if __name__ == "__main__":
    torch.manual_seed(777)
    args = get_args()
    args.ln_run_on_bpu = False
    # NOTE(xcsong): The XJ3 BPU only supports static shapes.
    assert args.chunk_size > 0
    assert args.num_decoding_left_chunks > 0
    os.makedirs(args.output_dir, exist_ok=True)
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    model = init_model(configs)
    load_checkpoint(model, args.checkpoint)
    model.eval()
    print(model)
    args.feature_size = configs["input_dim"]
    args.output_size = model.encoder.output_size()
    args.decoding_window = (
        (args.chunk_size - 1) * model.encoder.embed.subsampling_rate
        + model.encoder.embed.right_context
        + 1
    )
    export_encoder(model, args)
    export_ctc(model, args)
    export_decoder(model, args)
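
# Usage (sketch, assuming get_args() exposes the same flags as
# wenet/bin/export_onnx_cpu.py):
#   python export_onnx_bpu.py \
#       --config exp/train.yaml --checkpoint exp/final.pt \
#       --output_dir exp/onnx --chunk_size 8 --num_decoding_left_chunks 16
# chunk_size and num_decoding_left_chunks must be positive because the XJ3
# BPU only supports static shapes.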