# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

from __future__ import print_function

import argparse
import logging
import os
import sys
from typing import List, Tuple  # used by StreamingSqueezeformerEncoder.forward

import torch
import torch.nn.functional as F
import yaml

from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
from wenet.utils.mask import make_pad_mask

try:
    import onnxruntime
except ImportError:
    print("Please install onnxruntime-gpu!")
    sys.exit(1)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


class Encoder(torch.nn.Module):

    def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10):
        super().__init__()
        self.encoder = encoder
        self.ctc = ctc
        self.beam_size = beam_size

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Encoder
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        Returns:
            encoder_out: B x T x F
            encoder_out_lens: B
            ctc_log_probs: B x T x V
            beam_log_probs: B x T x beam_size
            beam_log_probs_idx: B x T x beam_size
        """
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1, -1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        ctc_log_probs = self.ctc.log_softmax(encoder_out)
        encoder_out_lens = encoder_out_lens.int()
        beam_log_probs, beam_log_probs_idx = torch.topk(
            ctc_log_probs, self.beam_size, dim=2
        )
        return (
            encoder_out,
            encoder_out_lens,
            ctc_log_probs,
            beam_log_probs,
            beam_log_probs_idx,
        )


class StreamingEncoder(torch.nn.Module):

    def __init__(self, model, required_cache_size, beam_size, transformer=False):
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.transformer = transformer

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            chunk_lens (torch.Tensor): number of valid frames per chunk,
                with shape (b, )
            offset (torch.Tensor): offset with shape (b, 1),
                the second dimension is retained for triton deployment
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
                The required cache size must be > 0; a cache size <= 0
                is not allowed in the streaming GPU encoder.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                with shape (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape
                (b, 1, required_cache_size). In a batch of requests, each
                request may have a different history cache, so the cache mask
                is used to indicate the effective cache for each request.
        Returns:
            torch.Tensor: log probabilities of the ctc output, cut off by
                beam size, with shape (b, chunk_size, beam)
            torch.Tensor: index of the top beam-size probabilities for each
                timestep, with shape (b, chunk_size, beam)
            torch.Tensor: output of the current input xs,
                with shape (b, chunk_size, hidden-dim)
            torch.Tensor: output lengths of the current chunk, with shape (b, )
            torch.Tensor: new offset for the next chunk, with shape (b, 1)
            torch.Tensor: new attention cache required for the next chunk, with
                the same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for the next chunk,
                with the same shape as the original cnn_cache
            torch.Tensor: new cache mask, with the same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)
        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # the chunk mask is important for batch inference since
        # different sequences in a batch have different lengths
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)  # required cache size
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size
        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)
        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoder.encoders):
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, masks, pos_emb, att_cache=att_cache[i], cnn_cache=cnn_cache[i]
            )
            # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
            # shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
            if not self.transformer:
                r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs
        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        if not self.transformer:
            r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers
        # <---------forward_chunk END--------->
        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs, self.beam_size, dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)
        r_offset = offset + chunk_out.shape[1]
        # the ops below are not supported in TensorRT
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                            rounding_mode='floor')
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)
        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )
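
# A minimal sketch (not part of the original script; all names below are
# placeholders) of how a streaming encoder is meant to be driven chunk by
# chunk: the r_* outputs of one call become the *_cache inputs of the next
# call, with shapes as documented in the forward docstring above.
#
#   att_cache = torch.zeros(b, elayers, head, required_cache_size, d_k * 2)
#   cnn_cache = torch.zeros(b, elayers, hidden_dim, cache_t2)
#   cache_mask = torch.zeros(b, 1, required_cache_size)
#   offset = torch.zeros(b, 1, dtype=torch.int64)
#   for chunk_xs, chunk_lens in chunks:
#       (log_probs, log_probs_idx, chunk_out, chunk_out_lens,
#        offset, att_cache, cnn_cache, cache_mask) = streaming_encoder(
#           chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask)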


class StreamingSqueezeformerEncoder(torch.nn.Module):

    def __init__(self, model, required_cache_size, beam_size):
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.reduce_idx = model.encoder.reduce_idx
        self.recover_idx = model.encoder.recover_idx
        if self.reduce_idx is None:
            self.time_reduce = None
        else:
            if self.recover_idx is None:
                self.time_reduce = "normal"  # no recovery at the end
            else:
                self.time_reduce = "recover"  # recovery at the end
                assert len(self.reduce_idx) == len(self.recover_idx)

    def calculate_downsampling_factor(self, i: int) -> int:
        if self.reduce_idx is None:
            return 1
        else:
            reduce_exp, recover_exp = 0, 0
            for exp, rd_idx in enumerate(self.reduce_idx):
                if i >= rd_idx:
                    reduce_exp = exp + 1
            if self.recover_idx is not None:
                for exp, rc_idx in enumerate(self.recover_idx):
                    if i >= rc_idx:
                        recover_exp = exp + 1
            return int(2 ** (reduce_exp - recover_exp))
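
    # Worked example (hypothetical config, not taken from any shipped model):
    # with reduce_idx=[3] and recover_idx=[9], calculate_downsampling_factor
    # returns 2 ** 0 = 1 for layers 0-2, 2 ** 1 = 2 for layers 3-8, and
    # 2 ** (1 - 1) = 1 again for layers 9 and above, i.e. the time resolution
    # is halved after the reduce layer and restored after the recover layer.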

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            chunk_lens (torch.Tensor): number of valid frames per chunk,
                with shape (b, )
            offset (torch.Tensor): offset with shape (b, 1),
                the second dimension is retained for triton deployment
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
                The required cache size must be > 0; a cache size <= 0
                is not allowed in the streaming GPU encoder.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                with shape (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape
                (b, 1, required_cache_size). In a batch of requests, each
                request may have a different history cache, so the cache mask
                is used to indicate the effective cache for each request.
        Returns:
            torch.Tensor: log probabilities of the ctc output, cut off by
                beam size, with shape (b, chunk_size, beam)
            torch.Tensor: index of the top beam-size probabilities for each
                timestep, with shape (b, chunk_size, beam)
            torch.Tensor: output of the current input xs,
                with shape (b, chunk_size, hidden-dim)
            torch.Tensor: output lengths of the current chunk, with shape (b, )
            torch.Tensor: new offset for the next chunk, with shape (b, 1)
            torch.Tensor: new attention cache required for the next chunk, with
                the same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for the next chunk,
                with the same shape as the original cnn_cache
            torch.Tensor: new cache mask, with the same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)
        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # the chunk mask is important for batch inference since
        # different sequences in a batch have different lengths
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        elayers, cache_size = att_cache.size(0), att_cache.size(3)
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size
        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)
        next_cache_start = -self.required_cache_size
        r_cache_mask = att_mask[:, :, next_cache_start:]
        r_att_cache = []
        r_cnn_cache = []
        mask_pad = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)
        max_att_len: int = 0
        recover_activations: List[
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        ] = []
        index = 0
        xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
        xs = self.encoder.preln(xs)
        for i, layer in enumerate(self.encoder.encoders):
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append((xs, att_mask, pos_emb, mask_pad))
                    xs, xs_lens, att_mask, mask_pad = self.encoder.time_reduction_layer(
                        xs, xs_lens, att_mask, mask_pad
                    )
                    pos_emb = pos_emb[:, ::2, :]
                    if self.encoder.pos_enc_layer_type == "rel_pos_repaired":
                        pos_emb = pos_emb[:, : xs.size(1) * 2 - 1, :]
                    index += 1
            if self.recover_idx is not None:
                if self.time_reduce == "recover" and i in self.recover_idx:
                    index -= 1
                    (
                        recover_tensor,
                        recover_att_mask,
                        recover_pos_emb,
                        recover_mask_pad,
                    ) = recover_activations[index]
                    # recover output length for ctc decode
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.encoder.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
                    att_mask = recover_att_mask
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad
            factor = self.calculate_downsampling_factor(i)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i][:, :, ::factor, :][
                    :, :, : pos_emb.size(1) - xs.size(1), :
                ]
                if elayers > 0
                else att_cache[:, :, ::factor, :],
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
            )
            cached_att = new_att_cache[:, :, next_cache_start // factor :, :]
            cached_cnn = new_cnn_cache.unsqueeze(1)
            cached_att = (
                cached_att.unsqueeze(3).repeat(1, 1, 1, factor, 1).flatten(2, 3)
            )
            if i == 0:
                # record length for the first block as max length
                max_att_len = cached_att.size(2)
            r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1))
            r_cnn_cache.append(cached_cnn)
        chunk_out = xs
        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers
        # <---------forward_chunk END--------->
        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs, self.beam_size, dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)
        r_offset = offset + chunk_out.shape[1]
        # the ops below are not supported in TensorRT
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                            rounding_mode='floor')
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)
        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )


class StreamingEfficientConformerEncoder(torch.nn.Module):

    def __init__(self, model, required_cache_size, beam_size):
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder

        # Efficient Conformer
        self.stride_layer_idx = model.encoder.stride_layer_idx
        self.stride = model.encoder.stride
        self.num_blocks = model.encoder.num_blocks
        self.cnn_module_kernel = model.encoder.cnn_module_kernel

    def calculate_downsampling_factor(self, i: int) -> int:
        factor = 1
        for idx, stride_idx in enumerate(self.stride_layer_idx):
            if i > stride_idx:
                factor *= self.stride[idx]
        return factor
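
    # Worked example (hypothetical config, not taken from any shipped model):
    # with stride_layer_idx=[3, 7] and stride=[2, 2], layers 0-3 see factor 1,
    # layers 4-7 see factor 2, and layers 8 and above see factor 4, because the
    # comparison `i > stride_idx` only applies a stride to the layers after the
    # striding layer itself.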

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            chunk_lens (torch.Tensor): number of valid frames per chunk,
                with shape (b, )
            offset (torch.Tensor): offset with shape (b, 1),
                the second dimension is retained for triton deployment
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                with shape (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape
                (b, 1, required_cache_size). In a batch of requests, each
                request may have a different history cache, so the cache mask
                is used to indicate the effective cache for each request.
        Returns:
            torch.Tensor: log probabilities of the ctc output, cut off by
                beam size, with shape (b, chunk_size, beam)
            torch.Tensor: index of the top beam-size probabilities for each
                timestep, with shape (b, chunk_size, beam)
            torch.Tensor: output of the current input xs,
                with shape (b, chunk_size, hidden-dim)
            torch.Tensor: output lengths of the current chunk, with shape (b, )
            torch.Tensor: new offset for the next chunk, with shape (b, 1)
            torch.Tensor: new attention cache required for the next chunk, with
                the same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for the next chunk,
                with the same shape as the original cnn_cache
            torch.Tensor: new cache mask, with the same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)  # (b, )
        offset *= self.calculate_downsampling_factor(self.num_blocks + 1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)  # (b, 1, T)
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        # Shape(att_cache): (elayers, b, head, cache_t1, d_k * 2)
        # Shape(cnn_cache): (elayers, b, outsize, cnn_kernel)
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)
        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # the chunk mask is important for batch inference since
        # different sequences in a batch have different lengths
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)  # required cache size
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size
        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)
        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]
        r_att_cache = []
        r_cnn_cache = []
        mask_pad = chunk_mask.to(torch.bool)
        max_att_len, max_cnn_len = 0, 0  # for repeat_interleave of new_att_cache
        for i, layer in enumerate(self.encoder.encoders):
            factor = self.calculate_downsampling_factor(i)
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (b, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
            #   shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
            att_cache_trunc = 0
            if xs.size(1) + att_cache.size(3) / factor > pos_emb.size(1):
                # the time steps are not divisible by the downsampling factor;
                # we propose to double the chunk_size
                att_cache_trunc = (
                    xs.size(1) + att_cache.size(3) // factor - pos_emb.size(1) + 1
                )
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                mask_pad=mask_pad,
                att_cache=att_cache[i][:, :, ::factor, :][:, :, att_cache_trunc:, :],
                cnn_cache=cnn_cache[i, :, :, :] if cnn_cache.size(0) > 0 else cnn_cache,
            )
            if i in self.stride_layer_idx:
                # compute time dimension for next block
                efficient_index = self.stride_layer_idx.index(i)
                att_mask = att_mask[
                    :, :: self.stride[efficient_index], :: self.stride[efficient_index]
                ]
                mask_pad = mask_pad[
                    :, :: self.stride[efficient_index], :: self.stride[efficient_index]
                ]
                pos_emb = pos_emb[:, :: self.stride[efficient_index], :]
            # shape(new_att_cache) = [batch, head, time2, outdim]
            new_att_cache = new_att_cache[:, :, next_cache_start // factor :, :]
            # shape(new_cnn_cache) = [batch, 1, outdim, cache_t2]
            new_cnn_cache = new_cnn_cache.unsqueeze(1)  # shape(1): layerID
            # use repeat_interleave on new_att_cache
            # new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2)
            new_att_cache = (
                new_att_cache.unsqueeze(3).repeat(1, 1, 1, factor, 1).flatten(2, 3)
            )
            # pad new_cnn_cache to cnn.lorder for causal convolution
            new_cnn_cache = F.pad(
                new_cnn_cache, (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0)
            )
            if i == 0:
                # record length for the first block as max length
                max_att_len = new_att_cache.size(2)
                max_cnn_len = new_cnn_cache.size(3)
            # update real shape of att_cache and cnn_cache
            r_att_cache.append(new_att_cache[:, :, -max_att_len:, :].unsqueeze(1))
            r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])
        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs
        # shape of r_att_cache: (b, elayers, head, time2, outdim)
        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        # shape of r_cnn_cache: (b, elayers, outdim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers
        # <---------forward_chunk END--------->
        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs, self.beam_size, dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)
        r_offset = offset + chunk_out.shape[1]
        # the ops below are not supported in TensorRT
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                            rounding_mode='floor')
        chunk_out_lens = (
            chunk_lens
            // self.subsampling_rate
            // self.calculate_downsampling_factor(self.num_blocks + 1)
        )
        chunk_out_lens += 1
        r_offset = r_offset.unsqueeze(1)
        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )


class Decoder(torch.nn.Module):

    def __init__(
        self,
        decoder: TransformerDecoder,
        ctc_weight: float = 0.5,
        reverse_weight: float = 0.0,
        beam_size: int = 10,
        decoder_fastertransformer: bool = False,
    ):
        super().__init__()
        self.decoder = decoder
        self.ctc_weight = ctc_weight
        self.reverse_weight = reverse_weight
        self.beam_size = beam_size
        self.decoder_fastertransformer = decoder_fastertransformer

    def forward(
        self,
        encoder_out: torch.Tensor,
        encoder_lens: torch.Tensor,
        hyps_pad_sos_eos: torch.Tensor,
        hyps_lens_sos: torch.Tensor,
        r_hyps_pad_sos_eos: torch.Tensor,
        ctc_score: torch.Tensor,
    ):
        """Decoder
        Args:
            encoder_out: B x T x F
            encoder_lens: B
            hyps_pad_sos_eos: B x beam x (T2+1),
                hyps with sos & eos and padded by ignore id
            hyps_lens_sos: B x beam, length for each hyp with sos
            r_hyps_pad_sos_eos: B x beam x (T2+1),
                reversed hyps with sos & eos and padded by ignore id
            ctc_score: B x beam, ctc score for each hyp
        Returns:
            decoder_out: B x beam x T2 x V (only returned when
                decoder_fastertransformer is enabled)
            best_index: B
        """
        B, T, F = encoder_out.shape
        bz = self.beam_size
        B2 = B * bz
        encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
        encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
        encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
        T2 = hyps_pad_sos_eos.shape[2] - 1
        hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
        hyps_lens = hyps_lens_sos.view(B2)
        hyps_pad_sos = hyps_pad[:, :-1].contiguous()
        hyps_pad_eos = hyps_pad[:, 1:].contiguous()
        r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
        r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
        r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out,
            encoder_mask,
            hyps_pad_sos,
            hyps_lens,
            r_hyps_pad_sos,
            self.reverse_weight,
        )
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        V = decoder_out.shape[-1]
        decoder_out = decoder_out.view(B2, T2, V)
        mask = ~make_pad_mask(hyps_lens, T2)  # B2 x T2
        # mask index, remove ignore id
        index = torch.unsqueeze(hyps_pad_eos * mask, 2)
        score = decoder_out.gather(2, index).squeeze(2)  # B2 X T2
        # mask padded part
        score = score * mask
        decoder_out = decoder_out.view(B, bz, T2, V)
        if self.reverse_weight > 0:
            r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
            r_decoder_out = r_decoder_out.view(B2, T2, V)
            index = torch.unsqueeze(r_hyps_pad_eos * mask, 2)
            r_score = r_decoder_out.gather(2, index).squeeze(2)
            r_score = r_score * mask
            score = score * (1 - self.reverse_weight) + self.reverse_weight * r_score
            r_decoder_out = r_decoder_out.view(B, bz, T2, V)
        score = torch.sum(score, axis=1)  # B2
        score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score
        best_index = torch.argmax(score, dim=1)
        if self.decoder_fastertransformer:
            return decoder_out, best_index
        else:
            return best_index


def to_numpy(tensors):
    out = []
    # wrap a single tensor so the loop below also handles that case
    # (note: the type to test against is torch.Tensor, not torch.tensor)
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    for tensor in tensors:
        if tensor.requires_grad:
            tensor = tensor.detach().cpu().numpy()
        else:
            tensor = tensor.cpu().numpy()
        out.append(tensor)
    return out


def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
    for a, b in zip(xlist, blist):
        try:
            torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
        except AssertionError as error:
            if tolerate_small_mismatch:
                print(error)
            else:
                raise


def export_offline_encoder(model, configs, args, logger, encoder_onnx_path):
    bz = 32
    seq_len = 100
    beam_size = args.beam_size
    feature_size = configs["input_dim"]

    speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
    speech_lens = torch.randint(low=10, high=seq_len, size=(bz,), dtype=torch.int32)
    encoder = Encoder(model.encoder, model.ctc, beam_size)
    encoder.eval()

    torch.onnx.export(
        encoder,
        (speech, speech_lens),
        encoder_onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=["speech", "speech_lengths"],
        output_names=[
            "encoder_out",
            "encoder_out_lens",
            "ctc_log_probs",
            "beam_log_probs",
            "beam_log_probs_idx",
        ],
        dynamic_axes={
            "speech": {0: "B", 1: "T"},
            "speech_lengths": {0: "B"},
            "encoder_out": {0: "B", 1: "T_OUT"},
            "encoder_out_lens": {0: "B"},
            "ctc_log_probs": {0: "B", 1: "T_OUT"},
            "beam_log_probs": {0: "B", 1: "T_OUT"},
            "beam_log_probs_idx": {0: "B", 1: "T_OUT"},
        },
        verbose=False,
    )

    with torch.no_grad():
        o0, o1, o2, o3, o4 = encoder(speech, speech_lens)

    providers = ["CUDAExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(encoder_onnx_path, providers=providers)
    # to_numpy returns a one-element list for a single tensor, so take item 0
    ort_inputs = {
        "speech": to_numpy(speech)[0],
        "speech_lengths": to_numpy(speech_lens)[0],
    }
    ort_outs = ort_session.run(None, ort_inputs)

    # check encoder output
    test(to_numpy([o0, o1, o2, o3, o4]), ort_outs)
    logger.info("export offline onnx encoder succeed!")
    onnx_config = {
        "beam_size": args.beam_size,
        "reverse_weight": args.reverse_weight,
        "ctc_weight": args.ctc_weight,
        "fp16": args.fp16,
    }
    return onnx_config


def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
    decoding_chunk_size = args.decoding_chunk_size
    subsampling = model.encoder.embed.subsampling_rate
    context = model.encoder.embed.right_context + 1
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    batch_size = 32
    audio_len = decoding_window
    feature_size = configs["input_dim"]
    output_size = configs["encoder_conf"]["output_size"]
    num_layers = configs["encoder_conf"]["num_blocks"]
    # in a transformer encoder the cnn module is not available
    transformer = False
    cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
    if not cnn_module_kernel:
        transformer = True
    num_decoding_left_chunks = args.num_decoding_left_chunks
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks
    if configs["encoder"] == "squeezeformer":
        encoder = StreamingSqueezeformerEncoder(
            model, required_cache_size, args.beam_size
        )
    elif configs["encoder"] == "efficientConformer":
        encoder = StreamingEfficientConformerEncoder(
            model, required_cache_size, args.beam_size
        )
    else:
        encoder = StreamingEncoder(
            model, required_cache_size, args.beam_size, transformer
        )
    encoder.eval()

    # begin to export encoder
    chunk_xs = torch.randn(batch_size, audio_len, feature_size, dtype=torch.float32)
    chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len
    offset = torch.arange(0, batch_size).unsqueeze(1)
    # (b, elayers, head, cache_t1, d_k * 2)
    head = configs["encoder_conf"]["attention_heads"]
    d_k = configs["encoder_conf"]["output_size"] // head
    att_cache = torch.randn(
        batch_size, num_layers, head, required_cache_size, d_k * 2, dtype=torch.float32
    )
    cnn_cache = torch.randn(
        batch_size, num_layers, output_size, cnn_module_kernel, dtype=torch.float32
    )
    cache_mask = torch.ones(batch_size, 1, required_cache_size, dtype=torch.float32)

    input_names = [
        "chunk_xs",
        "chunk_lens",
        "offset",
        "att_cache",
        "cnn_cache",
        "cache_mask",
    ]
    output_names = [
        "log_probs",
        "log_probs_idx",
        "chunk_out",
        "chunk_out_lens",
        "r_offset",
        "r_att_cache",
        "r_cnn_cache",
        "r_cache_mask",
    ]
    input_tensors = (chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask)
    if transformer:
        output_names.pop(6)  # a transformer encoder has no r_cnn_cache output

    all_names = input_names + output_names
    dynamic_axes = {}
    for name in all_names:
        # only the first (batch) dimension is dynamic,
        # all other dimensions are fixed
        dynamic_axes[name] = {0: "B"}
    torch.onnx.export(
        encoder,
        input_tensors,
        encoder_onnx_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        verbose=False,
    )

    with torch.no_grad():
        torch_outs = encoder(
            chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask
        )
    if transformer:
        # drop the (empty) r_cnn_cache entry so the torch outputs line up with
        # the onnxruntime outputs; list.pop returns the removed element, so its
        # result must not be assigned back to torch_outs
        torch_outs = list(torch_outs)
        torch_outs.pop(6)
    ort_session = onnxruntime.InferenceSession(
        encoder_onnx_path, providers=["CUDAExecutionProvider"]
    )
    ort_inputs = {}
    input_tensors = to_numpy(input_tensors)
    for idx, name in enumerate(input_names):
        ort_inputs[name] = input_tensors[idx]
    if transformer:
        del ort_inputs["cnn_cache"]
    ort_outs = ort_session.run(None, ort_inputs)
    test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx streaming encoder succeed!")
    onnx_config = {
        "subsampling_rate": subsampling,
        "context": context,
        "decoding_chunk_size": decoding_chunk_size,
        "num_decoding_left_chunks": num_decoding_left_chunks,
        "beam_size": args.beam_size,
        "fp16": args.fp16,
        "feat_size": feature_size,
        "decoding_window": decoding_window,
        "cnn_module_kernel_cache": cnn_module_kernel,
    }
    return onnx_config


def export_rescoring_decoder(
    model, configs, args, logger, decoder_onnx_path, decoder_fastertransformer
):
    bz, seq_len = 32, 100
    beam_size = args.beam_size
    decoder = Decoder(
        model.decoder,
        model.ctc_weight,
        model.reverse_weight,
        beam_size,
        decoder_fastertransformer,
    )
    decoder.eval()

    hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len))
    hyps_lens_sos = torch.randint(
        low=3, high=seq_len, size=(bz, beam_size), dtype=torch.int32
    )
    r_hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len))
    output_size = configs["encoder_conf"]["output_size"]
    encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
    encoder_out_lens = torch.randint(low=3, high=seq_len, size=(bz,), dtype=torch.int32)
    ctc_score = torch.randn(bz, beam_size, dtype=torch.float32)

    input_names = [
        "encoder_out",
        "encoder_out_lens",
        "hyps_pad_sos_eos",
        "hyps_lens_sos",
        "r_hyps_pad_sos_eos",
        "ctc_score",
    ]
    output_names = ["best_index"]
    if decoder_fastertransformer:
        output_names.insert(0, "decoder_out")
    torch.onnx.export(
        decoder,
        (
            encoder_out,
            encoder_out_lens,
            hyps_pad_sos_eos,
            hyps_lens_sos,
            r_hyps_pad_sos_eos,
            ctc_score,
        ),
        decoder_onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            "encoder_out": {0: "B", 1: "T"},
            "encoder_out_lens": {0: "B"},
            "hyps_pad_sos_eos": {0: "B", 2: "T2"},
            "hyps_lens_sos": {0: "B"},
            "r_hyps_pad_sos_eos": {0: "B", 2: "T2"},
            "ctc_score": {0: "B"},
            "best_index": {0: "B"},
        },
        verbose=False,
    )

    with torch.no_grad():
        o0 = decoder(
            encoder_out,
            encoder_out_lens,
            hyps_pad_sos_eos,
            hyps_lens_sos,
            r_hyps_pad_sos_eos,
            ctc_score,
        )
    providers = ["CUDAExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(decoder_onnx_path, providers=providers)

    input_tensors = [
        encoder_out,
        encoder_out_lens,
        hyps_pad_sos_eos,
        hyps_lens_sos,
        r_hyps_pad_sos_eos,
        ctc_score,
    ]
    ort_inputs = {}
    input_tensors = to_numpy(input_tensors)
    for idx, name in enumerate(input_names):
        ort_inputs[name] = input_tensors[idx]

    # if model.reverse_weight == 0, r_hyps_pad_sos_eos is removed
    # from the onnx decoder since it doesn't play any role
    if model.reverse_weight == 0:
        del ort_inputs["r_hyps_pad_sos_eos"]
    ort_outs = ort_session.run(None, ort_inputs)

    # check decoder output
    if decoder_fastertransformer:
        test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05)
    else:
        test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx decoder succeed!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="export x86_gpu model")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--checkpoint", required=True, help="checkpoint model")
    parser.add_argument(
        "--cmvn_file",
        required=False,
        default="",
        type=str,
        help="global_cmvn file, default path is in config file",
    )
    parser.add_argument(
        "--reverse_weight",
        default=-1.0,
        type=float,
        required=False,
        help="reverse weight for bitransformer, default value is in config file",
    )
    parser.add_argument(
        "--ctc_weight",
        default=-1.0,
        type=float,
        required=False,
        help="ctc weight, default value is in config file",
    )
    parser.add_argument(
        "--beam_size",
        default=10,
        type=int,
        required=False,
        help="beam size used as the ctc output size (top-k)",
    )
    parser.add_argument(
        "--output_onnx_dir",
        default="onnx_model",
        help="output onnx encoder and decoder directory",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="whether to export fp16 model, default false",
    )
    # arguments for streaming encoder
    parser.add_argument(
        "--streaming",
        action="store_true",
        help="whether to export streaming encoder, default false",
    )
    parser.add_argument(
        "--decoding_chunk_size",
        default=16,
        type=int,
        required=False,
        help="the decoding chunk size, <= 0 is not supported",
    )
    parser.add_argument(
        "--num_decoding_left_chunks",
        default=5,
        type=int,
        required=False,
        help="number of left chunks, <= 0 is not supported",
    )
    parser.add_argument(
        "--decoder_fastertransformer",
        action="store_true",
        help="return decoder_out and best_index for ft",
    )
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.set_printoptions(precision=10)

    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if args.cmvn_file and os.path.exists(args.cmvn_file):
        configs["cmvn_file"] = args.cmvn_file
    if args.reverse_weight != -1.0 and "reverse_weight" in configs["model_conf"]:
        configs["model_conf"]["reverse_weight"] = args.reverse_weight
        print("Update reverse weight to", args.reverse_weight)
    if args.ctc_weight != -1:
        print("Update ctc weight to", args.ctc_weight)
        configs["model_conf"]["ctc_weight"] = args.ctc_weight
    configs["encoder_conf"]["use_dynamic_chunk"] = False

    model = init_model(configs)
    load_checkpoint(model, args.checkpoint)
    model.eval()

    if not os.path.exists(args.output_onnx_dir):
        os.mkdir(args.output_onnx_dir)
    encoder_onnx_path = os.path.join(args.output_onnx_dir, "encoder.onnx")
    export_enc_func = None
    if args.streaming:
        assert args.decoding_chunk_size > 0
        assert args.num_decoding_left_chunks > 0
        export_enc_func = export_online_encoder
    else:
        export_enc_func = export_offline_encoder

    onnx_config = export_enc_func(model, configs, args, logger, encoder_onnx_path)

    decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder.onnx")
    export_rescoring_decoder(
        model, configs, args, logger, decoder_onnx_path, args.decoder_fastertransformer
    )

    if args.fp16:
        try:
            import onnxmltools
            from onnxmltools.utils.float16_converter import convert_float_to_float16
        except ImportError:
            print("Please install onnxmltools!")
            sys.exit(1)
        encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path)
        encoder_onnx_model = convert_float_to_float16(encoder_onnx_model)
        encoder_onnx_path = os.path.join(args.output_onnx_dir, "encoder_fp16.onnx")
        onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path)
        decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path)
        decoder_onnx_model = convert_float_to_float16(decoder_onnx_model)
        decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder_fp16.onnx")
        onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path)

    # dump configurations
    config_dir = os.path.join(args.output_onnx_dir, "config.yaml")
    with open(config_dir, "w") as out:
        yaml.dump(onnx_config, out)
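
# Example invocations (a sketch; the script file name and all paths are
# placeholders, adjust them to your own trained model and setup):
#
#   # offline encoder + rescoring decoder, fp16
#   python export_onnx_gpu.py --config train.yaml --checkpoint final.pt \
#       --output_onnx_dir onnx_model --beam_size 10 --fp16
#
#   # streaming encoder + rescoring decoder
#   python export_onnx_gpu.py --config train.yaml --checkpoint final.pt \
#       --output_onnx_dir onnx_model --streaming --decoding_chunk_size 16 \
#       --num_decoding_left_chunks 5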