# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import random
import tempfile
import warnings
from collections import OrderedDict

import cv2
import numpy as np
import torch
from PIL import Image

from imaginaire.datasets.base import BaseDataset


class Dataset(BaseDataset):
    r"""Dataset for paired few-shot videos.

    Args:
        cfg (Config): Loaded config object.
        is_inference (bool): In train or inference mode?
        is_test (bool): In test mode?
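
    Example (illustrative; the config is normally built by imaginaire's
    training scripts, so its exact fields are assumptions here):
        >>> dataset = Dataset(cfg, is_inference=False)
        >>> sample = dataset[0]
        >>> sample['source_images'].shape  # (few_shot_K, C, H, W)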
| """ | |

    def __init__(self, cfg, is_inference=False, is_test=False):
        self.paired = True
        super(Dataset, self).__init__(cfg, is_inference, is_test)
        self.is_video_dataset = True
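        # Number of source (few-shot) frames sampled per clip; one additional
        # frame is sampled to serve as the driving frame.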
        self.few_shot_K = 1
        self.first_last_only = getattr(cfg.data, 'first_last_only', False)
        self.sample_far_frames_more = getattr(
            cfg.data, 'sample_far_frames_more', False)

    def get_label_lengths(self):
        r"""Get the number of channels of all labels to be concatenated.

        Returns:
            label_lengths (OrderedDict): Dict mapping image data_type to
                number of channels.
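
        Example (illustrative): a label with 10 one-hot classes contributes
        10 channels (11 with use_dont_care); plain image data contributes its
        raw channel count, e.g. 3 for RGB.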
| """ | |
| label_lengths = OrderedDict() | |
| for data_type in self.input_labels: | |
| data_cfg = self.cfgdata | |
| if hasattr(data_cfg, 'one_hot_num_classes') and \ | |
| data_type in data_cfg.one_hot_num_classes: | |
| label_lengths[data_type] = data_cfg.one_hot_num_classes[data_type] | |
| if getattr(data_cfg, 'use_dont_care', False): | |
| label_lengths[data_type] += 1 | |
| else: | |
| label_lengths[data_type] = self.num_channels[data_type] | |
| return label_lengths | |

    def num_inference_sequences(self):
        r"""Number of sequences available for inference.

        Returns:
            (int)
        """
        assert self.is_inference
        return len(self.mapping)

    def _create_mapping(self):
        r"""Creates mapping from idx to key in LMDB.

        Returns:
            (tuple):
              - self.mapping (list): List of dicts, one per video file, each
                holding the LMDB root, LMDB index, sequence name and
                filenames needed to retrieve that video.
              - self.epoch_length (int): Number of samples in an epoch.
        """
        # Build a flat list with one entry per video file.
        mapping = []
        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
            for sequence_name, filenames in sequence_list.items():
                for filename in filenames:
                    # This file is corrupt.
                    if filename == 'z-KziTO_5so_0019_start0_end85_h596_w596':
                        continue
                    mapping.append({
                        'lmdb_root': self.lmdb_roots[lmdb_idx],
                        'lmdb_idx': lmdb_idx,
                        'sequence_name': sequence_name,
                        'filenames': [filename],
                    })
        self.mapping = mapping
        self.epoch_length = len(mapping)
        return self.mapping, self.epoch_length

    def _sample_keys(self, index):
        r"""Gets files to load for this sample.

        Args:
            index (int): Index in [0, len(dataset)).

        Returns:
            key (dict):
                - lmdb_idx (int): Chosen LMDB dataset root.
                - sequence_name (str): Chosen sequence in chosen dataset.
                - filenames (list of str): Chosen filenames in chosen
                  sequence.
        """
        if self.is_inference:
            assert index < self.epoch_length
            raise NotImplementedError
        else:
            # Select a video at random.
            key = random.choice(self.mapping)
        return key

    def _create_sequence_keys(self, sequence_name, filenames):
        r"""Create the LMDB key for this piece of information.

        Args:
            sequence_name (str): Which sequence from the chosen dataset.
            filenames (list of str): List of filenames in this sequence.

        Returns:
            keys (list): List of full keys.
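
        Example (illustrative names): sequence_name='person_01' with
        filenames=['clip_0000'] yields ['person_01/clip_0000'].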
| """ | |
| assert isinstance(filenames, list), 'Filenames should be a list.' | |
| keys = [] | |
| for filename in filenames: | |
| keys.append('%s/%s' % (sequence_name, filename)) | |
| return keys | |

    def _getitem(self, index):
        r"""Gets selected files.

        Args:
            index (int): Index into dataset.

        Returns:
            data (dict): Dict with all chosen data_types.
        """
        # Select a sample from the available data.
        keys = self._sample_keys(index)
        # Unpack keys.
        lmdb_idx = keys['lmdb_idx']
        sequence_name = keys['sequence_name']
        filenames = keys['filenames']
        # Get keys and lmdbs.
        keys, lmdbs = {}, {}
        for data_type in self.dataset_data_types:
            keys[data_type] = self._create_sequence_keys(
                sequence_name, filenames)
            lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
        # Load all data for this index.
        data = self.load_from_dataset(keys, lmdbs)
        # Get frames from video.
        try:
            temp = tempfile.NamedTemporaryFile()
            temp.write(data['videos'][0])
            temp.seek(0)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                cap = cv2.VideoCapture(temp.name)
                num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                if self.first_last_only:
                    chosen_idxs = [0, num_frames - 1]
                else:
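                    # Pick one driving frame uniformly at random; the
                    # remaining frames are candidates for the source frames.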
                    chosen_idx = random.sample(range(num_frames), 1)[0]
                    few_shot_choose_range = \
                        list(range(chosen_idx)) + \
                        list(range(chosen_idx + 1, num_frames))
                    if self.sample_far_frames_more:
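                        # Weight candidates by their temporal distance from
                        # the driving frame, so far-away frames are sampled
                        # more often.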
                        choose_weight = \
                            list(reversed(range(chosen_idx))) + \
                            list(range(num_frames - chosen_idx - 1))
                        few_shot_idx = random.choices(
                            few_shot_choose_range, choose_weight,
                            k=self.few_shot_K)
                    else:
                        few_shot_idx = random.sample(
                            few_shot_choose_range, k=self.few_shot_K)
                    chosen_idxs = few_shot_idx + [chosen_idx]
                chosen_images = []
                for idx in chosen_idxs:
                    # Seek to the chosen frame index and decode it.
                    cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                    _, frame = cap.read()
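                    # OpenCV decodes to BGR; reverse the channel axis to get
                    # RGB before handing the frame to PIL.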
                    chosen_images.append(Image.fromarray(frame[:, :, ::-1]))
                cap.release()
        except Exception:
            print('Issue with file:', sequence_name, filenames)
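            # Fall back to blank frames so one corrupt video does not crash
            # the training run.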
            blank = np.zeros((512, 512, 3), dtype=np.uint8)
            chosen_images = [Image.fromarray(blank), Image.fromarray(blank)]
        data['videos'] = chosen_images
        # Apply ops pre augmentation.
        data = self.apply_ops(data, self.pre_aug_ops)
        # Do augmentations for images.
        data, is_flipped = self.perform_augmentation(
            data, paired=True, augment_ops=self.augmentor.augment_ops)
        # Individual video frame augmentation is used in face-vid2vid.
        data = self.perform_individual_video_frame(
            data, self.augmentor.individual_video_frame_augmentation_ops)
        # Apply ops post augmentation.
        data = self.apply_ops(data, self.post_aug_ops)
        # Convert images to tensor.
        data = self.to_tensor(data)
        # Pack the sequence of images.
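        # Each image data_type becomes a single (num_frames, C, H, W) tensor.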
        for data_type in self.image_data_types:
            for idx in range(len(data[data_type])):
                data[data_type][idx] = data[data_type][idx].unsqueeze(0)
            data[data_type] = torch.cat(data[data_type], dim=0)
        if not self.is_video_dataset:
            # Remove any extra dimensions.
            for data_type in self.image_data_types:
                if data_type in data:
                    data[data_type] = data[data_type].squeeze(0)
        # Prepare output.
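        # The first few_shot_K frames are the source frames; the remaining
        # frame(s) drive the generation.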
        data['driving_images'] = data['videos'][self.few_shot_K:]
        data['source_images'] = data['videos'][:self.few_shot_K]
        data.pop('videos')
        data['is_flipped'] = is_flipped
        data['key'] = keys
        data['original_h_w'] = torch.IntTensor([
            self.augmentor.original_h, self.augmentor.original_w])
        # Apply full data ops.
        data = self.apply_ops(data, self.full_data_ops, full_data=True)
        return data

    def __getitem__(self, index):
        return self._getitem(index)
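

# Minimal usage sketch (hypothetical: the config object and its fields are
# normally constructed by imaginaire's training scripts, and this dataset is
# usually built through imaginaire's own data loading utilities):
#
#   dataset = Dataset(cfg, is_inference=False)
#   sample = dataset[0]
#   # sample['source_images']:  (few_shot_K, C, H, W) source frames.
#   # sample['driving_images']: (1, C, H, W) driving frame.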