from typing import Dict, List, Optional, Tuple, Union

import torch
import torchaudio
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from wenet.transducer.predictor import PredictorBase
from wenet.transducer.search.greedy_search import basic_greedy_search
from wenet.transducer.search.prefix_beam_search import PrefixBeamSearch
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.utils.common import (IGNORE_ID, add_blank, add_sos_eos,
                                reverse_pad_list, TORCH_NPU_AVAILABLE)


class Transducer(ASRModel):
    """Transducer-ctc-attention hybrid Encoder-Predictor-Decoder model"""

    def __init__(
        self,
        vocab_size: int,
        blank: int,
        encoder: nn.Module,
        predictor: PredictorBase,
        joint: nn.Module,
        attention_decoder: Optional[Union[TransformerDecoder,
                                          BiTransformerDecoder]] = None,
        ctc: Optional[CTC] = None,
        ctc_weight: float = 0,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        transducer_weight: float = 1.0,
        attention_weight: float = 0.0,
        enable_k2: bool = False,
        delay_penalty: float = 0.0,
        warmup_steps: float = 25000,
        lm_only_scale: float = 0.25,
        am_only_scale: float = 0.0,
        special_tokens: dict = None,
    ) -> None:
        assert attention_weight + ctc_weight + transducer_weight == 1.0
        super().__init__(vocab_size,
                         encoder,
                         attention_decoder,
                         ctc,
                         ctc_weight,
                         ignore_id,
                         reverse_weight,
                         lsm_weight,
                         length_normalized_loss,
                         special_tokens=special_tokens)

        self.blank = blank
        self.transducer_weight = transducer_weight
        self.attention_decoder_weight = 1 - self.transducer_weight - self.ctc_weight

        self.predictor = predictor
        self.joint = joint
        self.bs = None

        # k2 rnnt loss
        self.enable_k2 = enable_k2
        self.delay_penalty = delay_penalty
        if delay_penalty != 0.0:
            assert self.enable_k2 is True
        self.lm_only_scale = lm_only_scale
        self.am_only_scale = am_only_scale
        self.warmup_steps = warmup_steps
        self.simple_am_proj: Optional[nn.Linear] = None
        self.simple_lm_proj: Optional[nn.Linear] = None
        if self.enable_k2:
            self.simple_am_proj = torch.nn.Linear(self.encoder.output_size(),
                                                  vocab_size)
            self.simple_lm_proj = torch.nn.Linear(
                self.predictor.output_size(), vocab_size)

        # Note(Mddct): decoder also means predictor in transducer,
        # but here decoder is attention decoder
        del self.criterion_att
        if attention_decoder is not None:
            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )
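
    # Illustrative sketch of how the three loss weights interact (the concrete
    # values below are hypothetical, not defaults of this class). The weights
    # must sum to 1.0, and attention_decoder_weight is derived as
    # 1 - transducer_weight - ctc_weight:
    #
    #   model = Transducer(vocab_size=5000, blank=0, encoder=encoder,
    #                      predictor=predictor, joint=joint,
    #                      attention_decoder=decoder, ctc=ctc,
    #                      transducer_weight=0.75, ctc_weight=0.1,
    #                      attention_weight=0.15)
    #
    # where `encoder`, `predictor`, `joint`, `decoder` and `ctc` are modules
    # built elsewhere in wenet from the yaml config.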

    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + predictor + joint + loss
        """
        self.device = device
        speech = batch['feats'].to(device)
        speech_lengths = batch['feats_lengths'].to(device)
        text = batch['target'].to(device)
        text_lengths = batch['target_lengths'].to(device)
        steps = batch.get('steps', 0)
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)

        # Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        # compute_loss
        loss_rnnt = self._compute_loss(encoder_out,
                                       encoder_out_lens,
                                       encoder_mask,
                                       text,
                                       text_lengths,
                                       steps=steps)

        loss = self.transducer_weight * loss_rnnt
        # optional attention decoder
        loss_att: Optional[torch.Tensor] = None
        if self.attention_decoder_weight != 0.0 and self.decoder is not None:
            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                    text, text_lengths)
        else:
            acc_att = None

        # optional ctc
        loss_ctc: Optional[torch.Tensor] = None
        if self.ctc_weight != 0.0 and self.ctc is not None:
            loss_ctc, _ = self.ctc(encoder_out, encoder_out_lens, text,
                                   text_lengths)
        else:
            loss_ctc = None

        if loss_ctc is not None:
            loss = loss + self.ctc_weight * loss_ctc.sum()
        if loss_att is not None:
            loss = loss + self.attention_decoder_weight * loss_att.sum()
        # NOTE: 'loss' must be in dict
        return {
            'loss': loss,
            'loss_att': loss_att,
            'loss_ctc': loss_ctc,
            'loss_rnnt': loss_rnnt,
            'th_accuracy': acc_att,
        }
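
    # A minimal training-step sketch using forward() above (the `dataloader`
    # and `optimizer` objects are hypothetical, shown only to make the
    # returned dict concrete):
    #
    #   for batch in dataloader:
    #       losses = model(batch, device)   # dict with 'loss', 'loss_rnnt',
    #       losses['loss'].backward()       # 'loss_ctc', 'loss_att', ...
    #       optimizer.step()
    #       optimizer.zero_grad()
    #
    # 'loss' is transducer_weight * loss_rnnt + ctc_weight * loss_ctc
    # + attention_decoder_weight * loss_att, where the last two terms are
    # added only when the corresponding modules are configured.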

    def init_bs(self):
        if self.bs is None:
            self.bs = PrefixBeamSearch(self.encoder, self.predictor,
                                       self.joint, self.ctc, self.blank)

    def _cal_transducer_score(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        hyps_lens: torch.Tensor,
        hyps_pad: torch.Tensor,
    ):
        # ignore id -> blank, add blank at head
        hyps_pad_blank = add_blank(hyps_pad, self.blank, self.ignore_id)
        xs_in_lens = encoder_mask.squeeze(1).sum(1).int()

        # 1. Forward predictor
        predictor_out = self.predictor(hyps_pad_blank)
        # 2. Forward joint
        joint_out = self.joint(encoder_out, predictor_out)
        rnnt_text = hyps_pad.to(torch.int64)
        rnnt_text = torch.where(rnnt_text == self.ignore_id, 0,
                                rnnt_text).to(torch.int32)
        # 3. Compute transducer loss
        loss_td = torchaudio.functional.rnnt_loss(joint_out,
                                                  rnnt_text,
                                                  xs_in_lens,
                                                  hyps_lens.int(),
                                                  blank=self.blank,
                                                  reduction='none')
        return loss_td * -1

    def _cal_attn_score(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        hyps_pad: torch.Tensor,
        hyps_lens: torch.Tensor,
    ):
        # (beam_size, max_hyps_len)
        ori_hyps_pad = hyps_pad

        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning
        # used for right to left decoder
        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
                                    self.ignore_id)
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
            self.reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        decoder_out = decoder_out.cpu().numpy()

        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
        # conventional transformer decoder.
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        r_decoder_out = r_decoder_out.cpu().numpy()
        return decoder_out, r_decoder_out

    def beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        beam_size: int = 5,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        ctc_weight: float = 0.3,
        transducer_weight: float = 0.7,
    ):
        """beam search

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether to do encoder forward in a
                streaming fashion
            ctc_weight (float): ctc probability weight in transducer
                prefix beam search.
                final_prob = ctc_weight * ctc_prob +
                             transducer_weight * transducer_prob
            transducer_weight (float): transducer probability weight in
                prefix beam search

        Returns:
            Tuple[List[int], float]: best hypothesis (leading blank stripped)
                and its score
        """
        self.init_bs()
        beam, _ = self.bs.prefix_beam_search(
            speech,
            speech_lengths,
            decoding_chunk_size,
            beam_size,
            num_decoding_left_chunks,
            simulate_streaming,
            ctc_weight,
            transducer_weight,
        )
        return beam[0].hyp[1:], beam[0].score
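
    # Usage sketch (assumed single-utterance feature tensors `feats` of shape
    # (1, T, feat_dim) and `feats_lens` of shape (1,)):
    #
    #   hyp, score = model.beam_search(feats, feats_lens, beam_size=5,
    #                                  ctc_weight=0.3, transducer_weight=0.7)
    #
    # `hyp` is the best token-id sequence and `score` its combined
    # ctc/transducer prefix-beam-search score.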

    def transducer_attention_rescoring(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        reverse_weight: float = 0.0,
        ctc_weight: float = 0.0,
        attn_weight: float = 0.0,
        transducer_weight: float = 0.0,
        search_ctc_weight: float = 1.0,
        search_transducer_weight: float = 0.0,
        beam_search_type: str = 'transducer',
    ) -> Tuple[List[int], float]:
        """beam search with attention/transducer rescoring

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether to do encoder forward in a
                streaming fashion
            ctc_weight (float): ctc probability weight used in rescoring.
                rescore_prob = ctc_weight * ctc_prob +
                               transducer_weight * (transducer_loss * -1) +
                               attn_weight * attn_prob
            attn_weight (float): attention probability weight used in
                rescoring.
            transducer_weight (float): transducer probability weight used in
                rescoring
            search_ctc_weight (float): ctc weight used in the transducer
                beam search (see self.beam_search)
            search_transducer_weight (float): transducer weight used in the
                transducer beam search (see self.beam_search)

        Returns:
            Tuple[List[int], float]: best rescored hypothesis and its score
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        if reverse_weight > 0.0:
            # decoder should be a bitransformer decoder if reverse_weight > 0.0
            assert hasattr(self.decoder, 'right_decoder')
        device = speech.device
        batch_size = speech.shape[0]
        # For attention rescoring we only support batch_size=1
        assert batch_size == 1

        # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
        self.init_bs()
        if beam_search_type == 'transducer':
            beam, encoder_out = self.bs.prefix_beam_search(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                beam_size=beam_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                ctc_weight=search_ctc_weight,
                transducer_weight=search_transducer_weight,
            )
            beam_score = [s.score for s in beam]
            hyps = [s.hyp[1:] for s in beam]
        elif beam_search_type == 'ctc':
            hyps, encoder_out = self._ctc_prefix_beam_search(
                speech,
                speech_lengths,
                beam_size=beam_size,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks,
                simulate_streaming=simulate_streaming)
            beam_score = [hyp[1] for hyp in hyps]
            hyps = [hyp[0] for hyp in hyps]
        assert len(hyps) == beam_size

        # build hyps and encoder output
        hyps_pad = pad_sequence([
            torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps
        ], True, self.ignore_id)  # (beam_size, max_hyps_len)
        hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                                 device=device,
                                 dtype=torch.long)  # (beam_size,)

        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = torch.ones(beam_size,
                                  1,
                                  encoder_out.size(1),
                                  dtype=torch.bool,
                                  device=device)

        # 2.1 calculate transducer score
        td_score = self._cal_transducer_score(
            encoder_out,
            encoder_mask,
            hyps_lens,
            hyps_pad,
        )
        # 2.2 calculate attention score
        decoder_out, r_decoder_out = self._cal_attn_score(
            encoder_out,
            encoder_mask,
            hyps_pad,
            hyps_lens,
        )

        # combine attention, ctc/search and transducer scores for rescoring
        best_score = -float('inf')
        best_index = 0
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp):
                score += decoder_out[i][j][w]
            score += decoder_out[i][len(hyp)][self.eos]
            td_s = td_score[i]
            # add right to left decoder score
            if reverse_weight > 0:
                r_score = 0.0
                for j, w in enumerate(hyp):
                    r_score += r_decoder_out[i][len(hyp) - j - 1][w]
                r_score += r_decoder_out[i][len(hyp)][self.eos]
                score = score * (1 - reverse_weight) + r_score * reverse_weight
            # add ctc score
            score = score * attn_weight + \
                beam_score[i] * ctc_weight + \
                td_s * transducer_weight
            if score > best_score:
                best_score = score
                best_index = i
        return hyps[best_index], best_score
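
    # Rescoring usage sketch (the weights below are hypothetical): run the
    # transducer prefix beam search with its own search weights, then rescore
    # the resulting beam with the attention decoder and the transducer score:
    #
    #   hyp, score = model.transducer_attention_rescoring(
    #       feats, feats_lens, beam_size=10,
    #       ctc_weight=0.3, attn_weight=0.5, transducer_weight=0.2,
    #       search_ctc_weight=0.5, search_transducer_weight=0.5)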

    def greedy_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        n_steps: int = 64,
    ) -> List[List[int]]:
        """ greedy search

        Args:
            speech (torch.Tensor): (batch=1, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether to do encoder forward in a
                streaming fashion

        Returns:
            List[List[int]]: best path result
        """
        # TODO(Mddct): batch decode
        assert speech.size(0) == 1
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        # TODO(Mddct): forward chunk by chunk
        _ = simulate_streaming
        # Let's assume B = batch_size
        encoder_out, encoder_mask = self.encoder(
            speech,
            speech_lengths,
            decoding_chunk_size,
            num_decoding_left_chunks,
        )
        encoder_out_lens = encoder_mask.squeeze(1).sum()
        hyps = basic_greedy_search(self,
                                   encoder_out,
                                   encoder_out_lens,
                                   n_steps=n_steps)
        return hyps
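
    # Greedy (batch=1) usage sketch with the same assumed `feats`/`feats_lens`
    # tensors as above; the result is a list containing one token-id sequence:
    #
    #   hyps = model.greedy_search(feats, feats_lens)
    #   best = hyps[0]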

    def forward_encoder_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.encoder.forward_chunk(xs, offset, required_cache_size,
                                          att_cache, cnn_cache)

    def forward_predictor_step(
        self, xs: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        assert len(cache) == 2
        # fake padding
        padding = torch.zeros(1, 1)
        return self.predictor.forward_step(xs, padding, cache)

    def forward_joint_step(self, enc_out: torch.Tensor,
                           pred_out: torch.Tensor) -> torch.Tensor:
        return self.joint(enc_out, pred_out)

    def forward_predictor_init_state(self) -> List[torch.Tensor]:
        return self.predictor.init_state(1, device=torch.device("cpu"))
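
    # Rough sketch of how a streaming runtime can drive the exported steps
    # above for chunk-by-chunk greedy decoding. Cache layouts, tensor shapes
    # and the exact emission rule depend on the configured encoder/predictor,
    # and `feature_chunks`, `required_cache_size` and `blank_token` are
    # hypothetical, so treat this as pseudocode rather than the runtime's
    # actual implementation:
    #
    #   pred_cache = model.forward_predictor_init_state()
    #   pred_out, pred_cache = model.forward_predictor_step(blank_token, pred_cache)
    #   offset, att_cache, cnn_cache = 0, torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0)
    #   for chunk in feature_chunks:
    #       enc_out, att_cache, cnn_cache = model.forward_encoder_chunk(
    #           chunk, offset, required_cache_size, att_cache, cnn_cache)
    #       offset += enc_out.size(1)
    #       for t in range(enc_out.size(1)):
    #           logp = model.forward_joint_step(enc_out[:, t:t + 1, :], pred_out)
    #           # if argmax(logp) != blank: emit the token and advance the
    #           # predictor via forward_predictor_step(token, pred_cache)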

    def _compute_loss(self,
                      encoder_out: torch.Tensor,
                      encoder_out_lens: torch.Tensor,
                      encoder_mask: torch.Tensor,
                      text: torch.Tensor,
                      text_lengths: torch.Tensor,
                      steps: int = 0) -> torch.Tensor:
        ys_in_pad = add_blank(text, self.blank, self.ignore_id)
        # predictor
        predictor_out = self.predictor(ys_in_pad)
        if self.simple_lm_proj is None and self.simple_am_proj is None:
            # joint
            joint_out = self.joint(encoder_out, predictor_out)
            # NOTE(Mddct): some loss implementations require padded targets
            # to be zero; torchaudio's rnnt_loss also requires torch.int32
            rnnt_text = text.to(torch.int64)
            rnnt_text = torch.where(rnnt_text == self.ignore_id, 0,
                                    rnnt_text).to(torch.int32)
            rnnt_text_lengths = text_lengths.to(torch.int32)
            encoder_out_lens = encoder_out_lens.to(torch.int32)
            loss = torchaudio.functional.rnnt_loss(joint_out,
                                                   rnnt_text,
                                                   encoder_out_lens,
                                                   rnnt_text_lengths,
                                                   blank=self.blank,
                                                   reduction="mean")
        else:
            # k2 pruned rnnt loss
            try:
                import k2
            except ImportError:
                raise ImportError(
                    'k2 is not installed, please install k2 to use the '
                    'pruned rnnt loss')
            delay_penalty = self.delay_penalty
            if steps < 2 * self.warmup_steps:
                delay_penalty = 0.0
            ys_in_pad = ys_in_pad.type(torch.int64)
            boundary = torch.zeros((encoder_out.size(0), 4),
                                   dtype=torch.int64,
                                   device=encoder_out.device)
            boundary[:, 3] = encoder_mask.squeeze(1).sum(1)
            boundary[:, 2] = text_lengths
            rnnt_text = torch.where(text == self.ignore_id, 0, text)

            lm = self.simple_lm_proj(predictor_out)
            am = self.simple_am_proj(encoder_out)
            amp_autocast = torch.cuda.amp.autocast
            if "npu" in self.device.__str__() and TORCH_NPU_AVAILABLE:
                amp_autocast = torch.npu.amp.autocast
            with amp_autocast(enabled=False):
                simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                    lm=lm.float(),
                    am=am.float(),
                    symbols=rnnt_text,
                    termination_symbol=self.blank,
                    lm_only_scale=self.lm_only_scale,
                    am_only_scale=self.am_only_scale,
                    boundary=boundary,
                    reduction="sum",
                    return_grad=True,
                    delay_penalty=delay_penalty,
                )

            # ranges : [B, T, prune_range]
            ranges = k2.get_rnnt_prune_ranges(
                px_grad=px_grad,
                py_grad=py_grad,
                boundary=boundary,
                s_range=5,
            )
            am_pruned, lm_pruned = k2.do_rnnt_pruning(
                am=self.joint.enc_ffn(encoder_out),
                lm=self.joint.pred_ffn(predictor_out),
                ranges=ranges,
            )

            logits = self.joint(
                am_pruned,
                lm_pruned,
                pre_project=False,
            )

            with amp_autocast(enabled=False):
                pruned_loss = k2.rnnt_loss_pruned(
                    logits=logits.float(),
                    symbols=rnnt_text,
                    ranges=ranges,
                    termination_symbol=self.blank,
                    boundary=boundary,
                    reduction="sum",
                    delay_penalty=delay_penalty,
                )

            simple_loss_scale = 0.5
            if steps < self.warmup_steps:
                simple_loss_scale = (1.0 - (steps / self.warmup_steps) *
                                     (1.0 - simple_loss_scale))
            pruned_loss_scale = 1.0
            if steps < self.warmup_steps:
                pruned_loss_scale = 0.1 + 0.9 * (steps / self.warmup_steps)
            loss = (simple_loss_scale * simple_loss +
                    pruned_loss_scale * pruned_loss)
            loss = loss / encoder_out.size(0)

        return loss
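
    # Worked example of the k2 warmup schedule above, assuming the default
    # warmup_steps = 25000: at steps = 0 the scales are simple_loss_scale = 1.0
    # and pruned_loss_scale = 0.1; at steps = 12500 they are 0.75 and 0.55;
    # from steps >= 25000 onward they stay at 0.5 and 1.0. The delay penalty,
    # if configured, only takes effect after 2 * warmup_steps = 50000 steps.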