| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import math | |
| import fairseq | |
| import numpy as np | |
| import torch | |
| import torchaudio.transforms as T | |
| from torch import nn | |
| def setup_lip_regressor() -> ("Audio2LipRegressionTransformer", T.Resample): | |
| cp_path = "./assets/vq-wav2vec.pt" | |
| audio_model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) | |
| audio_model = audio_model[0] | |
| for param in audio_model.parameters(): | |
| param.requires_grad = False | |
| audio_model.eval() | |
| audio_resampler = T.Resample(48000, 16000) | |
| return audio_model, audio_resampler | |
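

# Illustrative usage sketch: the one-second dummy waveform and the
# feature_extractor call follow fairseq's published vq-wav2vec example and are
# assumptions, not values taken from this module.
def _example_lip_regressor_features():
    audio_model, audio_resampler = setup_lip_regressor()
    wav_48khz = torch.randn(1, 48000)       # one second of dummy 48 kHz audio
    wav_16khz = audio_resampler(wav_48khz)  # -> (1, 16000)
    with torch.no_grad():
        z = audio_model.feature_extractor(wav_16khz)  # conv features (assumed fairseq API)
    return z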


def init_weight(m):
    # Xavier-initialize conv / linear layers; zero their biases when present.
    if (
        isinstance(m, nn.Conv1d)
        or isinstance(m, nn.Linear)
        or isinstance(m, nn.ConvTranspose1d)
    ):
        nn.init.xavier_normal_(m.weight)
        # m.bias.data.fill_(0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
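

# Minimal sketch of how init_weight is typically applied: nn.Module.apply walks
# every submodule, so conv layers get Xavier weights and zero biases. The small
# network here is only an illustration.
def _example_init_weight():
    net = nn.Sequential(
        nn.Conv1d(16, 32, kernel_size=3),
        nn.ReLU(),
        nn.Conv1d(32, 8, kernel_size=3),
    )
    net.apply(init_weight)  # recursively visits every submodule
    return net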


# absolute positional embedding used for vanilla transformer sequential data
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=800, batch_first=False):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        # Precompute the sinusoidal table once; stored as (max_len, 1, d_model).
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        if self.batch_first:
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)
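

# Usage sketch for the absolute positional encoding; the feature size and
# sequence length below are arbitrary assumptions for illustration.
def _example_positional_encoding():
    pos_enc = PositionalEncoding(d_model=256, dropout=0.1, batch_first=True)
    x = torch.randn(2, 100, 256)  # (batch, seq_len, d_model)
    out = pos_enc(x)              # same shape: positions added, then dropout
    assert out.shape == x.shape
    return out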


# very similar positional embedding used for diffusion timesteps
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
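

# Usage sketch: embed a batch of integer diffusion timesteps into `dim`-sized
# vectors (dim=128 is an arbitrary assumption).
def _example_timestep_embedding():
    emb_fn = SinusoidalPosEmb(dim=128)
    t = torch.randint(0, 1000, (4,))  # one timestep per batch element
    emb = emb_fn(t)                   # -> (4, 128): sin half concatenated with cos half
    assert emb.shape == (4, 128)
    return emb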


# dropout mask: boolean mask whose entries are True with probability `prob`
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
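

# Usage sketch: draw a per-sample keep mask, e.g. to randomly drop conditioning
# with probability 0.2 (so roughly 80% of entries come back True).
def _example_prob_mask():
    keep_mask = prob_mask_like((8,), prob=0.8, device=torch.device("cpu"))
    return keep_mask  # bool tensor of shape (8,)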


def extract(a, t, x_shape):
    # Gather per-timestep coefficients a[t] and reshape to (b, 1, ..., 1) so
    # they broadcast against a tensor of shape x_shape.
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
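

# Usage sketch: pick one schedule value per batch element and broadcast it over
# a feature tensor. The shapes and values are illustrative only.
def _example_extract():
    coeffs = torch.linspace(0.9, 0.1, steps=1000)  # e.g. a per-step coefficient table
    t = torch.tensor([0, 499, 999])                # one timestep per sample
    x = torch.randn(3, 16, 32)
    out = extract(coeffs, t, x.shape)              # -> (3, 1, 1), broadcastable with x
    return out * x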


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp in torch (rather than np.clip) so betas stays a tensor and the
        # final betas.numpy() call below also works for this branch.
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
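

# Usage sketch: build a 1000-step linear schedule and derive the cumulative
# alpha products used by DDPM-style samplers. The step count and default
# endpoints are common choices, assumed here for illustration.
def _example_beta_schedule():
    betas = make_beta_schedule("linear", n_timestep=1000)  # float64 numpy array
    alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
    return betas, alphas_cumprod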