Spaces:
Sleeping
Sleeping
| ### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py | |
| import os | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from torch import Tensor, nn | |
| from typing import Dict, Iterable, Optional | |
# Fixed audio front-end settings (module-level constants).
SAMPLE_RATE = 16000  # waveform sampling rate
N_FFT = 1024  # FFT size
N_MELS = 128  # number of mel bins
HOP_LENGTH = int(SAMPLE_RATE * 0.01)  # 1% of SAMPLE_RATE -> 160 samples
DURATION = 10  # clip length; multiplied by SAMPLE_RATE below
N_SAMPLES = int(SAMPLE_RATE * DURATION)  # samples per clip
N_FRAMES = 1 + N_SAMPLES // HOP_LENGTH  # hop-spaced frames per clip
def sinusoids(length, channels, max_timescale=10000):
    """Return a fixed sinusoidal positional-embedding table.

    The first ``channels // 2`` columns hold sines and the remaining
    ``channels // 2`` columns hold cosines of the position index multiplied
    by geometrically spaced inverse timescales, running from ``1`` down to
    ``1 / max_timescale``.

    Args:
        length: Number of positions (rows of the table).
        channels: Embedding width (columns); must be a positive even number.
        max_timescale: Largest timescale of the geometric progression.

    Returns:
        A tensor of shape ``(length, channels)``.

    Raises:
        ValueError: If ``channels`` is not a positive even number. An odd
            width would otherwise silently yield ``channels - 1`` columns.
    """
    if channels <= 0 or channels % 2 != 0:
        raise ValueError(
            f"channels must be a positive even number, got {channels}"
        )

    half = channels // 2
    # With a single sin/cos pair (channels == 2) the progression has only one
    # term; clamping the denominator avoids a division by zero that would
    # otherwise fill the whole table with NaN.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    # Outer product: (length, 1) * (1, half) -> (length, half).
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MelEncoder(nn.Module):
    """Waveform to log-mel time-frequency representation.

    Chains three torchaudio transforms: ``Spectrogram`` -> ``MelScale`` ->
    ``AmplitudeToDB``.

    Args:
        sample_rate: Sampling rate passed to ``MelScale``.
        f_min: Lowest frequency of the mel filterbank.
        f_max: Highest frequency of the mel filterbank.
        n_fft: FFT size; the filterbank is built for ``n_fft // 2 + 1`` bins.
        win_length: Analysis window length.
        hop_length: Hop between successive frames.
        n_mels: Number of mel bins.
        power: Forwarded to ``Spectrogram``; ``None`` requests the complex
            spectrogram.
        pad: Forwarded to ``Spectrogram``.
        normalized: Forwarded to ``Spectrogram``.
        center: Forwarded to ``Spectrogram``.
        pad_mode: Forwarded to ``Spectrogram``.
    """

    def __init__(
        self,
        sample_rate=16000,
        f_min=0,
        f_max=8000,
        n_fft=1024,
        win_length=1024,
        hop_length=int(0.01 * 16000),
        n_mels=128,
        power=None,
        pad=0,
        normalized=False,
        center=True,
        pad_mode="reflect",
    ):
        super().__init__()
        # Kept only so the attribute remains available to external code;
        # forward() never reads it.
        self.window = torch.hann_window(win_length)
        # pad / normalized / center / pad_mode used to be accepted and then
        # dropped; they are now forwarded. Their defaults equal torchaudio's
        # own, so default construction behaves exactly as before.
        self.spec_fn = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad=pad,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
        )
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            n_stft=n_fft // 2 + 1,
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, wav):
        """Convert a waveform tensor into a dB-scaled mel spectrogram.

        Args:
            wav: Waveform tensor accepted by ``torchaudio`` ``Spectrogram``.

        Returns:
            The output of ``AmplitudeToDB`` applied to the mel-projected
            spectrogram.
        """
        spec = self.spec_fn(wav)
        # NOTE(review): only the real part is squared here. For a complex
        # spectrogram (power=None) this is Re(X)^2, not |X|^2, and for a
        # non-None power the already-powered values are squared again. Left
        # numerically unchanged because existing checkpoints may depend on
        # these exact features -- confirm before switching to spec.abs().
        power_spec = spec.real.abs().pow(2)
        mel_spec = self.mel_scale(power_spec)
        mel_spec = self.amplitude_to_db(mel_spec)  # log10 scaling with a floor (amin)
        return mel_spec
class AudioEncoder(nn.Module):
    """Convolutional audio front-end with fixed sinusoidal positions.

    Pipeline: waveform -> ``MelEncoder`` -> ``Conv1d`` + GELU -> a stack of
    stride-2 ``Conv1d`` + GELU layers -> add positional embedding.

    Args:
        n_mels: Number of mel bins produced by the mel encoder and consumed
            by the first convolution.
        n_ctx: Number of positions in the positional-embedding table; must
            equal the sequence length left after the strided convolutions.
        audio_dim: Channel width of the convolutional features.
        text_dim: Width of the positional-embedding table. No projection is
            applied between the convolutions and the positional sum, so this
            must equal ``audio_dim``.
        num_of_stride_conv: Number of stride-2 convolutions; each one roughly
            halves the sequence length.
    """

    def __init__(
        self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
    ):
        super().__init__()
        self.mel_encoder = MelEncoder(n_mels=n_mels)
        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
        self.conv_stack = nn.ModuleList(
            [
                nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
                for _ in range(num_of_stride_conv)
            ]
        )
        # There is no audio_dim -> text_dim projection, so the table built
        # here is added directly to audio_dim-wide features in forward().
        self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))

    def forward(self, x: Tensor):
        """Encode a batch of single-channel waveforms.

        Args:
            x: Tensor of shape ``(batch_size, waveform)``.

        Returns:
            Tensor of shape ``(batch_size, n_ctx, audio_dim)``.

        Raises:
            ValueError: If the encoded features do not match the shape of the
                positional-embedding table.
        """
        x = self.mel_encoder(x)  # (batch_size, n_mels, frames)
        x = F.gelu(self.conv1(x))
        for conv in self.conv_stack:
            x = F.gelu(conv(x))
        x = x.permute(0, 2, 1)  # (batch_size, frames', audio_dim)

        # Fail loudly instead of relying on broadcasting, which either raises
        # an opaque error or silently expands a size-1 dimension.
        if x.shape[1:] != self.positional_embedding.shape:
            raise ValueError(
                "incorrect audio shape: encoded features are "
                f"{tuple(x.shape[1:])} but the positional embedding is "
                f"{tuple(self.positional_embedding.shape)}"
            )
        x = (x + self.positional_embedding).to(x.dtype)
        return x