| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import json | |
| from typing import Callable, Optional | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from einops.layers.torch import Rearrange | |
| from model.guide import GuideTransformer | |
| from model.modules.audio_encoder import Wav2VecEncoder | |
| from model.modules.rotary_embedding_torch import RotaryEmbedding | |
| from model.modules.transformer_modules import ( | |
| DecoderLayerStack, | |
| FiLMTransformerDecoderLayer, | |
| RegressionTransformer, | |
| TransformerEncoderLayerRotary, | |
| ) | |
| from model.utils import ( | |
| init_weight, | |
| PositionalEncoding, | |
| prob_mask_like, | |
| setup_lip_regressor, | |
| SinusoidalPosEmb, | |
| ) | |
| from model.vqvae import setup_tokenizer | |
| from torch.nn import functional as F | |
| from utils.misc import prGreen, prRed | |


class Audio2LipRegressionTransformer(torch.nn.Module):
    def __init__(
        self,
        n_vertices: int = 338,
        causal: bool = False,
        train_wav2vec: bool = False,
        transformer_encoder_layers: int = 2,
        transformer_decoder_layers: int = 4,
    ):
        super().__init__()
        self.n_vertices = n_vertices
        self.audio_encoder = Wav2VecEncoder()
        if not train_wav2vec:
            self.audio_encoder.eval()
            for param in self.audio_encoder.parameters():
                param.requires_grad = False
        self.regression_model = RegressionTransformer(
            transformer_encoder_layers=transformer_encoder_layers,
            transformer_decoder_layers=transformer_decoder_layers,
            d_model=512,
            d_cond=512,
            num_heads=4,
            causal=causal,
        )
        self.project_output = torch.nn.Linear(512, self.n_vertices * 3)

    def forward(self, audio):
        """
        :param audio: tensor of shape B x T x 1600
        :return: tensor of shape B x T x n_vertices x 3 containing reconstructed lip geometry
        """
        B, T = audio.shape[0], audio.shape[1]
        cond = self.audio_encoder(audio)
        x = torch.zeros(B, T, 512, device=audio.device)
        x = self.regression_model(x, cond)
        x = self.project_output(x)
        verts = x.view(B, T, self.n_vertices, 3)
        return verts
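

# ---------------------------------------------------------------------------
# Illustrative usage sketch: assuming the pretrained weights required by
# Wav2VecEncoder are available locally, the lip regressor maps chunks of raw
# audio to per-frame lip geometry, with the shapes documented in
# Audio2LipRegressionTransformer.forward above:
#
#     lip_model = Audio2LipRegressionTransformer()
#     audio = torch.zeros(2, 120, 1600)    # B x T x 1600 raw audio chunks
#     verts = lip_model(audio)             # -> B x T x 338 x 3 lip vertices
# ---------------------------------------------------------------------------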


class FiLMTransformer(nn.Module):
    def __init__(
        self,
        args,
        nfeats: int,
        latent_dim: int = 512,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        cond_feature_dim: int = 4800,
        activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
        use_rotary: bool = True,
        cond_mode: str = "audio",
        split_type: str = "train",
        device: str = "cuda",
        **kwargs,
    ) -> None:
        super().__init__()
        self.nfeats = nfeats
        self.cond_mode = cond_mode
        self.cond_feature_dim = cond_feature_dim
        self.add_frame_cond = args.add_frame_cond
        self.data_format = args.data_format
        self.split_type = split_type
        self.device = device

        # positional embeddings
        self.rotary = None
        self.abs_pos_encoding = nn.Identity()
        # if rotary, replace absolute embedding with a rotary embedding instance
        # (absolute becomes an identity)
        if use_rotary:
            self.rotary = RotaryEmbedding(dim=latent_dim)
        else:
            self.abs_pos_encoding = PositionalEncoding(
                latent_dim, dropout, batch_first=True
            )

        # time embedding processing
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(latent_dim),
            nn.Linear(latent_dim, latent_dim * 4),
            nn.Mish(),
        )
        self.to_time_cond = nn.Sequential(
            nn.Linear(latent_dim * 4, latent_dim),
        )
        self.to_time_tokens = nn.Sequential(
            nn.Linear(latent_dim * 4, latent_dim * 2),
            Rearrange("b (r d) -> b r d", r=2),
        )

        # null embeddings for guidance dropout
        self.seq_len = args.max_seq_length
        emb_len = 1998  # hardcoded for now
        self.null_cond_embed = nn.Parameter(torch.randn(1, emb_len, latent_dim))
        self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
        self.norm_cond = nn.LayerNorm(latent_dim)
        self.setup_audio_models()

        # set up pose/face specific parts of the model
        self.input_projection = nn.Linear(self.nfeats, latent_dim)
        if self.data_format == "pose":
            cond_feature_dim = 1024
            key_feature_dim = 104
            self.step = 30
            self.use_cm = True
            self.setup_guide_models(args, latent_dim, key_feature_dim)
            self.post_pose_layers = self._build_single_pose_conv(self.nfeats)
            self.post_pose_layers.apply(init_weight)
            self.final_conv = torch.nn.Conv1d(self.nfeats, self.nfeats, kernel_size=1)
            self.receptive_field = 25
        elif self.data_format == "face":
            self.use_cm = False
            cond_feature_dim = 1024 + 1014
            self.setup_lip_models()
            self.cond_encoder = nn.Sequential()
            for _ in range(2):
                self.cond_encoder.append(
                    TransformerEncoderLayerRotary(
                        d_model=latent_dim,
                        nhead=num_heads,
                        dim_feedforward=ff_size,
                        dropout=dropout,
                        activation=activation,
                        batch_first=True,
                        rotary=self.rotary,
                    )
                )
            self.cond_encoder.apply(init_weight)
        self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)
        self.non_attn_cond_projection = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, latent_dim),
            nn.SiLU(),
            nn.Linear(latent_dim, latent_dim),
        )

        # decoder
        decoderstack = nn.ModuleList([])
        for _ in range(num_layers):
            decoderstack.append(
                FiLMTransformerDecoderLayer(
                    latent_dim,
                    num_heads,
                    dim_feedforward=ff_size,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                    rotary=self.rotary,
                    use_cm=self.use_cm,
                )
            )
        self.seqTransDecoder = DecoderLayerStack(decoderstack)
        self.seqTransDecoder.apply(init_weight)
        self.final_layer = nn.Linear(latent_dim, self.nfeats)
        self.final_layer.apply(init_weight)

    def _build_single_pose_conv(self, nfeats: int) -> nn.ModuleList:
        post_pose_layers = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(nfeats, max(256, nfeats), kernel_size=3, dilation=1),
                torch.nn.Conv1d(max(256, nfeats), nfeats, kernel_size=3, dilation=2),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=1),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=2),
                torch.nn.Conv1d(nfeats, nfeats, kernel_size=3, dilation=3),
            ]
        )
        return post_pose_layers

    def _run_single_pose_conv(self, output: torch.Tensor) -> torch.Tensor:
        output = torch.nn.functional.pad(output, pad=[self.receptive_field - 1, 0])
        for layer in self.post_pose_layers:
            y = torch.nn.functional.leaky_relu(layer(output), negative_slope=0.2)
            if self.split_type == "train":
                y = torch.nn.functional.dropout(y, 0.2)
            if output.shape[1] == y.shape[1]:
                output = (output[:, :, -y.shape[-1] :] + y) / 2.0  # skip connection
            else:
                output = y
        return output

    def setup_guide_models(self, args, latent_dim: int, key_feature_dim: int) -> None:
        # set up conditioning info
        max_keyframe_len = len(list(range(self.seq_len))[:: self.step])
        self.null_pose_embed = nn.Parameter(
            torch.randn(1, max_keyframe_len, latent_dim)
        )
        prGreen(f"using keyframes: {self.null_pose_embed.shape}")
        self.frame_cond_projection = nn.Linear(key_feature_dim, latent_dim)
        self.frame_norm_cond = nn.LayerNorm(latent_dim)
        # at test time, set up the keyframe transformer
        self.resume_trans = None
        if self.split_type == "test":
            if hasattr(args, "resume_trans") and args.resume_trans is not None:
                self.resume_trans = args.resume_trans
                self.setup_guide_predictor(args.resume_trans)
            else:
                prRed("not using transformer, just using ground truth")

    def setup_guide_predictor(self, cp_path: str) -> None:
        cp_dir = cp_path.split("checkpoints/iter-")[0]
        with open(f"{cp_dir}/args.json") as f:
            trans_args = json.load(f)
        # set up tokenizer from the path recorded in the transformer args
        self.tokenizer = setup_tokenizer(trans_args["resume_pth"])
        # set up transformer
        self.transformer = GuideTransformer(
            tokens=self.tokenizer.n_clusters,
            num_layers=trans_args["layers"],
            dim=trans_args["dim"],
            emb_len=1998,
            num_audio_layers=trans_args["num_audio_layers"],
        )
        for param in self.transformer.parameters():
            param.requires_grad = False
        prGreen(f"loading TRANSFORMER checkpoint from {cp_path}")
        cp = torch.load(cp_path)
        missing_keys, unexpected_keys = self.transformer.load_state_dict(
            cp["model_state_dict"], strict=False
        )
        assert len(missing_keys) == 0, missing_keys
        assert len(unexpected_keys) == 0, unexpected_keys

    def setup_audio_models(self) -> None:
        self.audio_model, self.audio_resampler = setup_lip_regressor()

    def setup_lip_models(self) -> None:
        self.lip_model = Audio2LipRegressionTransformer()
        cp_path = "./assets/iter-0200000.pt"
        cp = torch.load(cp_path, map_location=torch.device(self.device))
        self.lip_model.load_state_dict(cp["model_state_dict"])
        for param in self.lip_model.parameters():
            param.requires_grad = False
        prGreen(f"adding lip conditioning {cp_path}")

    def parameters_w_grad(self):
        return [p for p in self.parameters() if p.requires_grad]

    def encode_audio(self, raw_audio: torch.Tensor) -> torch.Tensor:
        device = next(self.parameters()).device
        # resample each of the two audio channels separately
        a0 = self.audio_resampler(raw_audio[:, :, 0].to(device))
        a1 = self.audio_resampler(raw_audio[:, :, 1].to(device))
        with torch.no_grad():
            z0 = self.audio_model.feature_extractor(a0)
            z1 = self.audio_model.feature_extractor(a1)
        emb = torch.cat((z0, z1), dim=1).permute(0, 2, 1)
        return emb

    def encode_lip(self, audio: torch.Tensor, cond_embed: torch.Tensor) -> torch.Tensor:
        # take the first channel and chunk it into 1600-sample windows
        reshaped_audio = audio.reshape((audio.shape[0], -1, 1600, 2))[..., 0]
        B, T, _ = reshaped_audio.shape
        lip_cond = torch.zeros(
            (audio.shape[0], T, 338, 3),
            device=audio.device,
            dtype=audio.dtype,
        )
        # the lip regressor processes 4 seconds (120 frames) at a time
        for i in range(0, T, 120):
            lip_cond[:, i : i + 120, ...] = self.lip_model(
                reshaped_audio[:, i : i + 120, ...]
            )
        # interpolate the lip geometry to the audio-feature length and
        # concatenate it as additional conditioning features
        lip_cond = lip_cond.permute(0, 2, 3, 1).reshape((B, 338 * 3, -1))
        lip_cond = torch.nn.functional.interpolate(
            lip_cond, size=cond_embed.shape[1], mode="nearest-exact"
        ).permute(0, 2, 1)
        cond_embed = torch.cat((cond_embed, lip_cond), dim=-1)
        return cond_embed

    def encode_keyframes(
        self, y: dict, cond_drop_prob: float, batch_size: int
    ) -> torch.Tensor:
        pred = y["keyframes"]
        new_mask = y["mask"][..., :: self.step].squeeze((1, 2))
        pred[~new_mask] = 0.0  # pad the unknown
        pose_hidden = self.frame_cond_projection(pred.detach().clone().cuda())
        pose_embed = self.abs_pos_encoding(pose_hidden)
        pose_tokens = self.frame_norm_cond(pose_embed)
        # do conditional dropout for guide poses
        key_cond_drop_prob = cond_drop_prob
        keep_mask_pose = prob_mask_like(
            (batch_size,), 1 - key_cond_drop_prob, device=pose_tokens.device
        )
        keep_mask_pose_embed = rearrange(keep_mask_pose, "b -> b 1 1")
        null_pose_embed = self.null_pose_embed.to(pose_tokens.dtype)
        pose_tokens = torch.where(
            keep_mask_pose_embed,
            pose_tokens,
            null_pose_embed[:, : pose_tokens.shape[1], :],
        )
        return pose_tokens

    def forward(
        self,
        x: torch.Tensor,
        times: torch.Tensor,
        y: Optional[dict] = None,
        cond_drop_prob: float = 0.0,
    ) -> torch.Tensor:
        """
        :param x: noised input motion of shape B x T x nfeats
            (a B x nfeats x 1 x T input is also accepted and reshaped)
        :param times: diffusion timesteps of shape B
        :param y: conditioning dict containing "audio" (and "keyframes"/"mask" for pose data)
        :param cond_drop_prob: probability of dropping the conditioning (guidance dropout)
        :return: denoised output of shape B x T x nfeats
        """
        if x.dim() == 4:
            x = x.permute(0, 3, 1, 2).squeeze(-1)
        batch_size, device = x.shape[0], x.device

        # audio (and lip) conditioning
        if self.cond_mode == "uncond":
            cond_embed = torch.zeros(
                (x.shape[0], x.shape[1], self.cond_feature_dim),
                dtype=x.dtype,
                device=x.device,
            )
        else:
            cond_embed = y["audio"]
            cond_embed = self.encode_audio(cond_embed)
            if self.data_format == "face":
                cond_embed = self.encode_lip(y["audio"], cond_embed)
        pose_tokens = None
        if self.data_format == "pose":
            pose_tokens = self.encode_keyframes(y, cond_drop_prob, batch_size)
        assert cond_embed is not None, "cond emb should not be none"

        # process conditioning information
        x = self.input_projection(x)
        x = self.abs_pos_encoding(x)
        audio_cond_drop_prob = cond_drop_prob
        keep_mask = prob_mask_like(
            (batch_size,), 1 - audio_cond_drop_prob, device=device
        )
        keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
        keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
        cond_tokens = self.cond_projection(cond_embed)
        cond_tokens = self.abs_pos_encoding(cond_tokens)
        if self.data_format == "face":
            cond_tokens = self.cond_encoder(cond_tokens)
        null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
        cond_tokens = torch.where(
            keep_mask_embed, cond_tokens, null_cond_embed[:, : cond_tokens.shape[1], :]
        )
        mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
        cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)

        # create t conditioning
        t_hidden = self.time_mlp(times)
        t = self.to_time_cond(t_hidden)
        t_tokens = self.to_time_tokens(t_hidden)
        null_cond_hidden = self.null_cond_hidden.to(t.dtype)
        cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
        t += cond_hidden

        # cross-attention conditioning
        c = torch.cat((cond_tokens, t_tokens), dim=-2)
        cond_tokens = self.norm_cond(c)

        # pass through the transformer decoder
        output = self.seqTransDecoder(x, cond_tokens, t, memory2=pose_tokens)
        output = self.final_layer(output)
        if self.data_format == "pose":
            output = output.permute(0, 2, 1)
            output = self._run_single_pose_conv(output)
            output = self.final_conv(output)
            output = output.permute(0, 2, 1)
        return output
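

if __name__ == "__main__":
    from types import SimpleNamespace

    # Minimal construction sketch, assuming the local model/ and utils/ packages
    # and the pretrained audio weights loaded by setup_lip_regressor() are
    # available. The args fields mirror the attributes read in __init__ and
    # setup_guide_models(); the concrete values (and nfeats=104) are placeholder
    # assumptions, not the project's actual training configuration.
    args = SimpleNamespace(add_frame_cond=None, data_format="pose", max_seq_length=600)
    model = FiLMTransformer(args, nfeats=104, split_type="train", device="cpu")
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"FiLMTransformer (pose) trainable parameters: {n_trainable}")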