Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import os | |
| from typing import Dict, List | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| from data_loaders.data import Social | |
| from data_loaders.tensors import social_collate | |
| from torch.utils.data import DataLoader | |
| from utils.misc import prGreen | |
| def get_dataset_loader( | |
| args, | |
| data_dict: Dict[str, np.ndarray], | |
| split: str = "train", | |
| chunk: bool = False, | |
| add_padding: bool = True, | |
| ) -> DataLoader: | |
| dataset = Social( | |
| args=args, | |
| data_dict=data_dict, | |
| split=split, | |
| chunk=chunk, | |
| add_padding=add_padding, | |
| ) | |
| loader = DataLoader( | |
| dataset, | |
| batch_size=args.batch_size, | |
| shuffle=not split == "test", | |
| num_workers=8, | |
| drop_last=True, | |
| collate_fn=social_collate, | |
| pin_memory=True, | |
| ) | |
| return loader | |
| def _load_pose_data( | |
| all_paths: List[str], audio_per_frame: int, flip_person: bool = False | |
| ) -> Dict[str, List]: | |
| data = [] | |
| face = [] | |
| audio = [] | |
| lengths = [] | |
| missing = [] | |
| for _, curr_path_name in enumerate(all_paths): | |
| if not curr_path_name.endswith("_body_pose.npy"): | |
| continue | |
| # load face information and deal with missing codes | |
| curr_code = np.load( | |
| curr_path_name.replace("_body_pose.npy", "_face_expression.npy") | |
| ).astype(float) | |
| # curr_code = np.array(curr_face["codes"], dtype=float) | |
| missing_list = np.load( | |
| curr_path_name.replace("_body_pose.npy", "_missing_face_frames.npy") | |
| ) | |
| if len(missing_list) == len(curr_code): | |
| print("skipping", curr_path_name, curr_code.shape) | |
| continue | |
| curr_missing = np.ones_like(curr_code) | |
| curr_missing[missing_list] = 0.0 | |
| # load pose information and deal with discontinuities | |
| curr_pose = np.load(curr_path_name) | |
| if "PXB184" in curr_path_name or "RLW104" in curr_path_name: # Capture 1 or 2 | |
| curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi) | |
| curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi) | |
| # load audio information | |
| curr_audio, _ = torchaudio.load( | |
| curr_path_name.replace("_body_pose.npy", "_audio.wav") | |
| ) | |
| curr_audio = curr_audio.T | |
| if flip_person: | |
| prGreen("[get_data.py] flipping the dataset of left right person") | |
| tmp = torch.zeros_like(curr_audio) | |
| tmp[:, 1] = curr_audio[:, 0] | |
| tmp[:, 0] = curr_audio[:, 1] | |
| curr_audio = tmp | |
| assert len(curr_pose) * audio_per_frame == len( | |
| curr_audio | |
| ), f"motion {curr_pose.shape} vs audio {curr_audio.shape}" | |
| data.append(curr_pose) | |
| face.append(curr_code) | |
| missing.append(curr_missing) | |
| audio.append(curr_audio) | |
| lengths.append(len(curr_pose)) | |
| data_dict = { | |
| "data": data, | |
| "face": face, | |
| "audio": audio, | |
| "lengths": lengths, | |
| "missing": missing, | |
| } | |
| return data_dict | |
| def load_local_data( | |
| data_root: str, audio_per_frame: int, flip_person: bool = False | |
| ) -> Dict[str, List]: | |
| if flip_person: | |
| if "PXB184" in data_root: | |
| data_root = data_root.replace("PXB184", "RLW104") | |
| elif "RLW104" in data_root: | |
| data_root = data_root.replace("RLW104", "PXB184") | |
| elif "TXB805" in data_root: | |
| data_root = data_root.replace("TXB805", "GQS883") | |
| elif "GQS883" in data_root: | |
| data_root = data_root.replace("GQS883", "TXB805") | |
| all_paths = [os.path.join(data_root, x) for x in os.listdir(data_root)] | |
| all_paths.sort() | |
| return _load_pose_data( | |
| all_paths, | |
| audio_per_frame, | |
| flip_person=flip_person, | |
| ) | |