import torchvision
from einops import rearrange
import numpy as np
import math
import torchaudio
import torch
import importlib
from data_utils import create_masks_from_landmarks_box
import torch.nn.functional as F


def save_audio_video(
    video,
    audio=None,
    frame_rate=25,
    sample_rate=16000,
    save_path="temp.mp4",
):
    """Save video (and optionally audio) to a single file.

    video: (t, c, h, w)
    audio: (channels, samples)
    """
    save_path = str(save_path)
    if isinstance(video, torch.Tensor):
        video = video.cpu().numpy()
    video_tensor = rearrange(video, "t c h w -> t h w c").astype(np.uint8)
    print("video_tensor shape", video_tensor.shape)
    if audio is not None:
        # Audio is expected as a tensor of shape (channels, samples)
        print("audio shape", audio.shape)
        torchvision.io.write_video(
            save_path,
            video_tensor,
            fps=frame_rate,
            audio_array=audio,
            audio_fps=sample_rate,
            video_codec="h264",  # Explicit codecs avoid default-codec write errors
            audio_codec="aac",
        )
    else:
        torchvision.io.write_video(
            save_path,
            video_tensor,
            fps=frame_rate,
            video_codec="h264",  # Explicit codecs avoid default-codec write errors
        )
    return save_path
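
# Illustrative usage sketch (shapes and paths are assumptions, not part of the module):
# 2 seconds of 512x512 video at 25 fps with mono 16 kHz audio.
#   video = torch.randint(0, 255, (50, 3, 512, 512), dtype=torch.uint8)
#   audio = torch.zeros(1, 32000)
#   save_audio_video(video, audio=audio, save_path="demo.mp4")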


def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
    """Trim or zero-pad audio along the last dimension to the target length."""
    len_file = audio.shape[-1]
    if max_len_sec or max_len_raw:
        max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav
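
# Example sketch (assumed shapes): pad a 1 s mono clip to 2 s at 16 kHz.
#   wav = torch.zeros(1, 16000)
#   padded = trim_pad_audio(wav, sr=16000, max_len_sec=2.0)  # -> shape (1, 32000)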


def get_raw_audio(audio_path, audio_rate, fps=25):
    """Load audio, resample to audio_rate, and reshape into per-video-frame chunks."""
    audio, sr = torchaudio.load(audio_path, channels_first=True)
    if audio.shape[0] > 1:
        # Downmix to mono
        audio = audio.mean(0, keepdim=True)
    audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=audio_rate)[0]
    samples_per_frame = math.ceil(audio_rate / fps)
    n_frames = audio.shape[-1] / samples_per_frame
    if not n_frames.is_integer():
        # Pad so the waveform splits evenly into whole frames
        audio = trim_pad_audio(
            audio, audio_rate, max_len_raw=math.ceil(n_frames) * samples_per_frame
        )
    audio = rearrange(audio, "(f s) -> f s", s=samples_per_frame)
    return audio
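
# Sketch of the expected output layout (file path and rates are hypothetical):
# at 16 kHz and 25 fps each video frame gets ceil(16000 / 25) = 640 samples,
# so a 2 s file yields a tensor of shape (50, 640).
#   frames_audio = get_raw_audio("speech.wav", audio_rate=16000, fps=25)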


def calculate_splits(tensor, min_last_size):
    # Check the total number of elements in the tensor
    total_size = tensor.size(1)  # size along the second dimension
    # If total size is less than the minimum size for the last split, return the tensor as a single split
    if total_size <= min_last_size:
        return [tensor]

    # Calculate number of splits and size of each split
    num_splits = (total_size - min_last_size) // min_last_size + 1
    base_size = (total_size - min_last_size) // num_splits

    # Create split sizes list
    split_sizes = [base_size] * (num_splits - 1)
    split_sizes.append(
        total_size - sum(split_sizes)
    )  # Ensure the last split has at least min_last_size

    # Adjust sizes to ensure they sum exactly to total_size
    sum_sizes = sum(split_sizes)
    while sum_sizes != total_size:
        for i in range(num_splits):
            if sum_sizes < total_size:
                split_sizes[i] += 1
                sum_sizes += 1
            if sum_sizes >= total_size:
                break

    # Split the tensor
    splits = torch.split(tensor, split_sizes, dim=1)
    return splits
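
# Illustrative example (assumed sizes): splitting 10 columns with min_last_size=4
# gives chunks along dim=1 whose sizes sum to 10 and whose last chunk is >= 4.
#   chunks = calculate_splits(torch.zeros(2, 10), min_last_size=4)
#   print([c.shape[1] for c in chunks])  # -> [3, 7]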


def make_into_multiple_of(x, multiple, dim=0):
    """Zero-pad the tensor along `dim` so that its size is a multiple of `multiple`."""
    if x.shape[dim] % multiple != 0:
        x = torch.cat(
            [
                x,
                torch.zeros(
                    *x.shape[:dim],
                    multiple - (x.shape[dim] % multiple),
                    *x.shape[dim + 1 :],
                    dtype=x.dtype,
                    device=x.device,
                ),
            ],
            dim=dim,
        )
    return x
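
# Example sketch: pad a length-10 tensor along dim=0 up to a multiple of 4.
#   x = torch.ones(10, 2)
#   y = make_into_multiple_of(x, multiple=4, dim=0)  # -> shape (12, 2); rows 10-11 are zeros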


def default(value, default_value):
    return default_value if value is None else value


def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False, invalidate_cache=True):
    module, cls = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
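
# Typical config-driven instantiation; the target class here is only an example.
#   config = {
#       "target": "torch.nn.Linear",
#       "params": {"in_features": 8, "out_features": 4},
#   }
#   layer = instantiate_from_config(config)  # equivalent to torch.nn.Linear(8, 4)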


def load_landmarks(
    landmarks: np.ndarray,
    original_size,
    target_size=(64, 64),
    nose_index=28,
):
    """
    Load and process facial landmarks to create masks.

    Args:
        landmarks: Facial landmarks array
        original_size: Original size (H, W) of the video frames
        target_size: Target size for the output mask
        nose_index: Index of the nose landmark

    Returns:
        Processed landmarks mask
    """
    expand_box = 0.0
    if len(landmarks.shape) == 2:
        landmarks = landmarks[None, ...]
    mask = create_masks_from_landmarks_box(
        landmarks,
        (original_size[0], original_size[1]),
        box_expand=expand_box,
        nose_index=nose_index,
    )
    # Downsample the mask to the target (latent) resolution
    mask = F.interpolate(mask.unsqueeze(1).float(), size=target_size, mode="nearest")
    return mask
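
# Usage sketch (assumes 68-point landmarks in 512x512 pixel coordinates; the exact
# output shape depends on what create_masks_from_landmarks_box returns per frame):
#   lms = np.zeros((68, 2))  # placeholder landmark array
#   mask = load_landmarks(lms, original_size=(512, 512), target_size=(64, 64))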


def create_pipeline_inputs(
    audio: torch.Tensor,
    audio_interpolation: torch.Tensor,
    num_frames: int,
    video_emb: torch.Tensor,
    landmarks: np.ndarray,
    overlap: int = 1,
    add_zero_flag: bool = False,
    mask_arms: np.ndarray = None,
    nose_index: int = 28,
):
    """
    Create inputs for the keyframe generation and interpolation pipeline.

    Args:
        audio: Audio embeddings for keyframe generation
        audio_interpolation: Audio embeddings for interpolation
        num_frames: Number of frames per segment
        video_emb: Video embeddings
        landmarks: Facial landmarks for mask generation
        overlap: Number of frames to overlap between segments
        add_zero_flag: Whether to insert the first keyframe every num_frames elements
        mask_arms: Optional mask for the arms region
        nose_index: Index of the nose landmark point

    Returns:
        Tuple containing all necessary inputs for the pipeline
    """
    audio_interpolation_chunks = []
    audio_image_preds = []
    gt_chunks = []
    gt_keyframes_chunks = []
    # Adjustment for overlap to ensure segments are created properly
    step = num_frames - overlap

    # Ensure there's at least one step forward on each iteration
    if step < 1:
        step = 1

    audio_image_preds_idx = []
    audio_interp_preds_idx = []
    masks_chunks = []
    masks_interpolation_chunks = []
    for i in range(0, audio.shape[0] - num_frames + 1, step):
        try:
            audio[i + num_frames - 1]
        except IndexError:
            break  # Last chunk is smaller than num_frames

        segment_end = i + num_frames
        gt_chunks.append(video_emb[i:segment_end])
        masks = load_landmarks(
            landmarks[i:segment_end],
            (512, 512),
            target_size=(64, 64),
            nose_index=nose_index,
        )
        if mask_arms is not None:
            masks = np.logical_and(
                masks, np.logical_not(mask_arms[i:segment_end, None, ...])
            )
        masks_interpolation_chunks.append(masks)

        if i not in audio_image_preds_idx:
            audio_image_preds.append(audio[i])
            masks_chunks.append(masks[0])
            gt_keyframes_chunks.append(video_emb[i])
            audio_image_preds_idx.append(i)

        if segment_end - 1 not in audio_image_preds_idx:
            audio_image_preds_idx.append(segment_end - 1)
            audio_image_preds.append(audio[segment_end - 1])
            masks_chunks.append(masks[-1])
            gt_keyframes_chunks.append(video_emb[segment_end - 1])

        audio_interpolation_chunks.append(audio_interpolation[i:segment_end])
        audio_interp_preds_idx.append([i, segment_end - 1])

    # If the flag is on, insert the first keyframe every num_frames audio elements
    if add_zero_flag:
        first_element = audio_image_preds[0]
        len_audio_image_preds = (
            len(audio_image_preds) + (len(audio_image_preds) + 1) % num_frames
        )
        for i in range(0, len_audio_image_preds, num_frames):
            audio_image_preds.insert(i, first_element)
            audio_image_preds_idx.insert(i, None)
            masks_chunks.insert(i, masks_chunks[0])
            gt_keyframes_chunks.insert(i, gt_keyframes_chunks[0])

    to_remove = [idx is None for idx in audio_image_preds_idx]
    audio_image_preds_idx_clone = [idx for idx in audio_image_preds_idx]
    if add_zero_flag:
        # Remove the inserted placeholder indices from the list
        audio_image_preds_idx = [
            sample for i, sample in zip(to_remove, audio_image_preds_idx) if not i
        ]

    interpolation_cond_list = []
    for i in range(0, len(audio_image_preds_idx) - 1, overlap if overlap > 0 else 2):
        interpolation_cond_list.append(
            [audio_image_preds_idx[i], audio_image_preds_idx[i + 1]]
        )

    # Since we generate num_frames at a time, we need to ensure that the last chunk is of size num_frames
    # Calculate the number of frames needed to make audio_image_preds a multiple of num_frames
    frames_needed = (num_frames - (len(audio_image_preds) % num_frames)) % num_frames

    # Extend each list by repeating its last element
    audio_image_preds = audio_image_preds + [audio_image_preds[-1]] * frames_needed
    masks_chunks = masks_chunks + [masks_chunks[-1]] * frames_needed
    gt_keyframes_chunks = (
        gt_keyframes_chunks + [gt_keyframes_chunks[-1]] * frames_needed
    )
    to_remove = to_remove + [True] * frames_needed
    audio_image_preds_idx_clone = (
        audio_image_preds_idx_clone + [audio_image_preds_idx_clone[-1]] * frames_needed
    )

    print(
        f"Added {frames_needed} repeated frames at the end to make audio_image_preds a multiple of {num_frames}"
    )

    # random_cond_idx = np.random.randint(0, len(video_emb))
    random_cond_idx = 0

    assert len(to_remove) == len(audio_image_preds), (
        "to_remove and audio_image_preds must have the same length"
    )

    return (
        gt_chunks,
        gt_keyframes_chunks,
        audio_interpolation_chunks,
        audio_image_preds,
        video_emb[random_cond_idx],
        masks_chunks,
        masks_interpolation_chunks,
        to_remove,
        audio_interp_preds_idx,
        audio_image_preds_idx_clone,
    )
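
# End-to-end usage sketch; all shapes and embedding dimensions below are assumptions,
# not values prescribed by this module.
#   audio = torch.zeros(100, 768)            # per-frame keyframe audio embeddings
#   audio_interp = torch.zeros(100, 768)     # per-frame interpolation audio embeddings
#   video_emb = torch.zeros(100, 4, 64, 64)  # per-frame video latents
#   lms = np.zeros((100, 68, 2))             # per-frame facial landmarks
#   outputs = create_pipeline_inputs(
#       audio, audio_interp, num_frames=14, video_emb=video_emb,
#       landmarks=lms, overlap=1, add_zero_flag=True,
#   )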