| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import os | |
| from typing import Dict, Iterable, List, Union | |
| import numpy as np | |
| import torch | |
| from torch.utils import data | |
| from utils.misc import prGreen | |
| class Social(data.Dataset): | |
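    """Dataset of paired motion ("pose" or "face" codes), audio, and per-frame
    missing-data masks.

    Splits are carved off the end of the sequence list: all but the last 6
    sequences train, the next 2 validate, and the final 4 test. Test data is
    chunked into fixed-length windows; train/val items are randomly cropped
    in __getitem__.
    """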

    def __init__(
        self,
        args,
        data_dict: Dict[str, Iterable],
        split: str = "train",
        chunk: bool = False,
        add_padding: bool = True,
    ) -> None:
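        """Select the requested split, optionally chunk it (test only), and
        load normalization statistics from `<data_root>/data_stats.pth`."""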
        if args.data_format == "face":
            prGreen("[dataset.py] training face only model")
            data_dict["data"] = data_dict["face"]
        elif args.data_format == "pose":
            prGreen("[dataset.py] training pose only model")
            missing = []
            for d in data_dict["data"]:
                missing.append(np.ones_like(d))
            data_dict["missing"] = missing

        # set up variables for dataloader
        self.data_format = args.data_format
        self.add_frame_cond = args.add_frame_cond
        self._register_keyframe_step()
        self.data_root = args.data_root
        self.max_seq_length = args.max_seq_length
        if hasattr(args, "curr_seq_length") and args.curr_seq_length is not None:
            self.max_seq_length = args.curr_seq_length
        prGreen(f"[dataset.py] sequences of {self.max_seq_length}")
        self.add_padding = add_padding
        self.audio_per_frame = 1600
        self.max_audio_length = self.max_seq_length * self.audio_per_frame
        self.min_seq_length = 400

        # set up training/validation splits
        train_idx = list(range(0, len(data_dict["data"]) - 6))
        val_idx = list(range(len(data_dict["data"]) - 6, len(data_dict["data"]) - 4))
        test_idx = list(range(len(data_dict["data"]) - 4, len(data_dict["data"])))
        self.split = split
        if split == "train":
            self._pick_sequences(data_dict, train_idx)
        elif split == "val":
            self._pick_sequences(data_dict, val_idx)
        else:
            self._pick_sequences(data_dict, test_idx)
        self.chunk = chunk
        if split == "test":
            print("[dataset.py] chunking data...")
            self._chunk_data()
        self._load_std()
        prGreen(
            f"[dataset.py] {split} | {len(self.data)} sequences ({self.data[0].shape}) | total len {self.total_len}"
        )

    def inv_transform(
        self, data: Union[np.ndarray, torch.Tensor], data_type: str
    ) -> Union[np.ndarray, torch.Tensor]:
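        """Undo z-normalization for the given data_type ("pose", "face", or
        "audio"), mapping `data` back to the original scale: data * std + mean."""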
        if data_type == "pose":
            std = self.std
            mean = self.mean
        elif data_type == "face":
            std = self.face_std
            mean = self.face_mean
        elif data_type == "audio":
            std = self.audio_std
            mean = self.audio_mean
        else:
            assert False, f"datatype not defined: {data_type}"
        if torch.is_tensor(data):
            return data * torch.tensor(
                std, device=data.device, requires_grad=False
            ) + torch.tensor(mean, device=data.device, requires_grad=False)
        else:
            return data * std + mean

    def _pick_sequences(self, data_dict: Dict[str, Iterable], idx: List[int]) -> None:
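        """Keep only the sequences at `idx` (with their masks, audio, and
        lengths); `total_len` counts frames across the kept sequences."""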
        self.data = np.take(data_dict["data"], idx, axis=0)
        self.missing = np.take(data_dict["missing"], idx, axis=0)
        self.audio = np.take(data_dict["audio"], idx, axis=0)
        self.lengths = np.take(data_dict["lengths"], idx, axis=0)
        self.total_len = sum([len(d) for d in self.data])

    def _load_std(self) -> None:
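        """Load the pose/face/audio means and stds used for z-normalization
        and its inverse."""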
        stats = torch.load(os.path.join(self.data_root, "data_stats.pth"))
        print(
            f'[dataset.py] loading from... {os.path.join(self.data_root, "data_stats.pth")}'
        )
        self.mean = stats["pose_mean"].reshape(-1)
        self.std = stats["pose_std"].reshape(-1)
        self.face_mean = stats["code_mean"]
        self.face_std = stats["code_std"]
        self.audio_mean = stats["audio_mean"]
        self.audio_std = stats["audio_std_flat"]

    def _chunk_data(self) -> None:
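        """Split each sequence into non-overlapping windows of `max_seq_length`
        frames (dropping the tail), pair each window with its audio span, and
        shuffle the resulting chunks."""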
        chunk_data = []
        chunk_missing = []
        chunk_lengths = []
        chunk_audio = []
        # create sequences of set lengths
        for d_idx in range(len(self.data)):
            curr_data = self.data[d_idx]
            curr_missing = self.missing[d_idx]
            curr_audio = self.audio[d_idx]
            end_range = len(self.data[d_idx]) - self.max_seq_length
            for chunk_idx in range(0, end_range, self.max_seq_length):
                chunk_end = chunk_idx + self.max_seq_length
                curr_data_chunk = curr_data[chunk_idx:chunk_end, :]
                curr_missing_chunk = curr_missing[chunk_idx:chunk_end, :]
                curr_audio_chunk = curr_audio[
                    chunk_idx * self.audio_per_frame : chunk_end * self.audio_per_frame,
                    :,
                ]
                if curr_data_chunk.shape[0] < self.max_seq_length:
                    # do not add a short chunk to the list
                    continue
                chunk_lengths.append(curr_data_chunk.shape[0])
                chunk_data.append(curr_data_chunk)
                chunk_missing.append(curr_missing_chunk)
                chunk_audio.append(curr_audio_chunk)
        idx = np.random.permutation(len(chunk_data))
        print("==> shuffle", idx)
        self.data = np.take(chunk_data, idx, axis=0)
        self.missing = np.take(chunk_missing, idx, axis=0)
        self.lengths = np.take(chunk_lengths, idx, axis=0)
        self.audio = np.take(chunk_audio, idx, axis=0)
        self.total_len = len(self.data)

    def _register_keyframe_step(self) -> None:
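        """Set the keyframe subsampling stride: every 30th frame when
        `add_frame_cond == 1`, every frame when it is None."""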
        if self.add_frame_cond == 1:
            self.step = 30
        elif self.add_frame_cond is None:
            self.step = 1

    def _pad_sequence(
        self, sequence: np.ndarray, actual_length: int, max_length: int
    ) -> np.ndarray:
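        """Zero-pad `sequence` along the time axis from `actual_length` up to
        `max_length` frames."""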
        sequence = np.concatenate(
            (
                sequence,
                np.zeros((max_length - actual_length, sequence.shape[-1])),
            ),
            axis=0,
        )
        return sequence

    def _get_idx(self, item: int) -> int:
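        """Map a flat frame index to the index of the sequence containing it."""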
        cumulative_len = 0
        seq_idx = 0
        # advance until the cumulative frame count passes `item`; use >= so
        # that item 0 maps to sequence 0 and boundary frames map forward
        while item >= cumulative_len:
            cumulative_len += len(self.data[seq_idx])
            seq_idx += 1
        return seq_idx - 1

    def _get_random_subsection(
        self, data_dict: Dict[str, Iterable]
    ) -> Dict[str, np.ndarray]:
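        """Crop a random window of up to `max_seq_length` frames (resampled
        until its missing-mask is not all zero), slice the matching audio and
        keyframes, and zero-pad everything up to its maximum length."""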
        isnonzero = False
        while not isnonzero:
            start = np.random.randint(0, data_dict["m_length"] - self.max_seq_length)
            if self.add_padding:
                length = (
                    np.random.randint(self.min_seq_length, self.max_seq_length)
                    if not self.split == "test"
                    else self.max_seq_length
                )
            else:
                length = self.max_seq_length
            curr_missing = data_dict["missing"][start : start + length]
            isnonzero = np.any(curr_missing)
        missing = curr_missing
        motion = data_dict["motion"][start : start + length, :]
        keyframes = motion[:: self.step]
        audio = data_dict["audio"][
            start * self.audio_per_frame : (start + length) * self.audio_per_frame,
            :,
        ]
        data_dict["m_length"] = len(motion)
        data_dict["k_length"] = len(keyframes)
        data_dict["a_length"] = len(audio)
        if data_dict["m_length"] < self.max_seq_length:
            motion = self._pad_sequence(
                motion, data_dict["m_length"], self.max_seq_length
            )
            missing = self._pad_sequence(
                missing, data_dict["m_length"], self.max_seq_length
            )
            audio = self._pad_sequence(
                audio, data_dict["a_length"], self.max_audio_length
            )
            max_step_length = len(np.zeros(self.max_seq_length)[:: self.step])
            keyframes = self._pad_sequence(
                keyframes, data_dict["k_length"], max_step_length
            )
        data_dict["motion"] = motion
        data_dict["keyframes"] = keyframes
        data_dict["audio"] = audio
        data_dict["missing"] = missing
        return data_dict

    def __len__(self) -> int:
        return self.total_len

    def __getitem__(self, item: int) -> Dict[str, np.ndarray]:
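        """Return one normalized example (motion, audio, keyframes, missing
        mask, and their lengths). Train/val items are random crops; test items
        are the pre-chunked windows."""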
        # figure out which sequence to randomly sample from
        if not self.split == "test":
            item = self._get_idx(item)
        motion = self.data[item]
        audio = self.audio[item]
        m_length = self.lengths[item]
        missing = self.missing[item]
        a_length = len(audio)
        # Z Normalization
        if self.data_format == "pose":
            motion = (motion - self.mean) / self.std
        elif self.data_format == "face":
            motion = (motion - self.face_mean) / self.face_std
        audio = (audio - self.audio_mean) / self.audio_std
        keyframes = motion[:: self.step]
        k_length = len(keyframes)
        data_dict = {
            "motion": motion,
            "m_length": m_length,
            "audio": audio,
            "a_length": a_length,
            "keyframes": keyframes,
            "k_length": k_length,
            "missing": missing,
        }
        if not self.split == "test" and not self.chunk:
            data_dict = self._get_random_subsection(data_dict)
        if self.data_format == "face":
            data_dict["motion"] *= data_dict["missing"]
        return data_dict