# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Positional Encoding Module."""

import math
from typing import Tuple, Union

import torch
import torch.nn.functional as F
import numpy as np


class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    The table ``self.pe`` has shape ``(1, 2 * L - 1, d_model)``. Index
    ``L - 1`` holds relative position 0. Indices to its left hold positive
    relative positions ``L - 1 .. 1``, and indices to its right hold negative
    relative positions ``-1 .. -(L - 1)``.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Number of positions precomputed at construction;
            the table is grown on demand for longer inputs.

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct an EspnetRelPositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Only the size/dtype/device of this dummy tensor are used.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor) -> None:
        """Reset the positional encodings.

        Rebuilds ``self.pe`` for ``x.size(1)`` positions when the current
        table is too short. Otherwise it only moves the table to ``x``'s
        dtype/device.

        Args:
            x (torch.Tensor): Tensor whose ``size(1)``, dtype and device
                determine the table; its values are not read.
        """
        if self.pe is not None:
            # self.pe contains both positive and negative parts;
            # the length of self.pe is 2 * input_len - 1.
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` the position
        # of the key vector. We use positive relative positions when keys are
        # to the left (i > j) and negative relative positions otherwise (i < j).
        length = x.size(1)
        pe_positive = torch.zeros(length, self.d_model)
        pe_negative = torch.zeros(length, self.d_model)
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of the positive indices, then concat the positive
        # and negative halves. The negative half drops its row 0 so that
        # position 0 is not duplicated. This supports the relative-shift
        # trick of https://arxiv.org/abs/1901.02860.
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(
        self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the input and compute its relative position embedding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int or torch.Tensor): Passed through to
                :meth:`position_encoding`, which does not use it.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - ``x * sqrt(d_model)`` after dropout, (batch, time, `*`).
                - Relative position embedding after dropout,
                  (1, 2 * time - 1, d_model).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(
        self, offset: Union[int, torch.Tensor], size: int
    ) -> torch.Tensor:
        """Return the relative position embedding for ``size`` positions.

        Used for getting the encoding in a streaming fashion.

        Attention: in a non-streaming setting dropout is applied only once,
        at the whole-utterance level. In a streaming scenario this function
        is called several times with increasing input size, so the dropout
        will be applied several times.

        Args:
            offset (int or torch.Tensor): Start offset. Not used here; it is
                kept so the signature stays the same for existing callers.
            size (int): Required size of position encoding.

        Returns:
            torch.Tensor: Slice of the table covering relative positions
            ``size - 1 .. -(size - 1)``, shape (1, 2 * size - 1, d_model).
        """
        # Grow the table when it is too short. Otherwise the start index
        # below would go negative, and Python slicing would silently return
        # rows from the wrong end of the table.
        if self.pe.size(1) < 2 * size - 1:
            self.extend_pe(self.pe.new_zeros(1, size))
        center = self.pe.size(1) // 2  # index of relative position 0
        pos_emb = self.pe[:, center - size + 1: center + size]
        return pos_emb