Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import copy | |
| import random | |
| import torch | |
| from imaginaire.datasets.paired_videos import Dataset as VideoDataset | |
| from imaginaire.model_utils.fs_vid2vid import select_object | |
| from imaginaire.utils.distributed import master_only_print as print | |
class Dataset(VideoDataset):
    r"""Paired video dataset for use in few-shot vid2vid.

    Every item consists of a target frame sequence plus ``few_shot_K``
    reference frames; the reference data is stored in the returned dict
    under keys prefixed with ``few_shot_``.

    Args:
        cfg (Config): Loaded config object.
        is_inference (bool): In train or inference mode?
        sequence_length (int): What sequence of images to provide?
        few_shot_K (int): How many images to provide for few-shot? Falls
            back to ``cfg.data.initial_few_shot_K`` when ``None``.
        is_test (bool): Forwarded unchanged to the parent dataset.
    """

    def __init__(self, cfg, is_inference=False, sequence_length=None,
                 few_shot_K=None, is_test=False):
        self.paired = True
        # Get initial few shot K. This is set before the parent constructor
        # runs, exactly as in the original ordering.
        if few_shot_K is None:
            self.few_shot_K = cfg.data.initial_few_shot_K
        else:
            self.few_shot_K = few_shot_K
        # Initialize.
        super().__init__(
            cfg, is_inference, sequence_length=sequence_length,
            is_test=is_test)

    def set_inference_sequence_idx(self, index, k_shot_index,
                                   k_shot_frame_index):
        r"""Get frames from this sequence during inference.

        Also resets ``self.epoch_length`` to the number of frames in the
        selected inference sequence.

        Args:
            index (int): Index of inference sequence.
            k_shot_index (int): Index of sequence from which k_shot is sampled.
            k_shot_frame_index (int): Index of frame to sample.
        """
        assert self.is_inference
        assert index < len(self.mapping)
        assert k_shot_index < len(self.mapping)
        # Each mapping entry is a dict describing one sequence, so the frame
        # index must be bounded by its number of filenames rather than by the
        # number of keys in that dict.
        assert k_shot_frame_index < len(
            self.mapping[k_shot_index]['filenames'])

        self.inference_sequence_idx = index
        self.inference_k_shot_sequence_index = k_shot_index
        self.inference_k_shot_frame_index = k_shot_frame_index
        self.epoch_length = len(
            self.mapping[self.inference_sequence_idx]['filenames'])

    def set_sequence_length(self, sequence_length, few_shot_K=None):
        r"""Set the length of sequence you want as output from dataloader.

        If ``sequence_length + few_shot_K`` exceeds
        ``self.sequence_length_max``, the sequence length is reduced to fit
        and a message is printed. The mapping is rebuilt afterwards.

        Args:
            sequence_length (int): Length of output sequences.
            few_shot_K (int): Number of few-shot frames. Defaults to the
                current ``self.few_shot_K``.
        """
        if few_shot_K is None:
            few_shot_K = self.few_shot_K
        assert isinstance(sequence_length, int)
        assert isinstance(few_shot_K, int)
        if (sequence_length + few_shot_K) > self.sequence_length_max:
            print('Requested sequence length (%d) + few shot K (%d) > '
                  'max sequence length (%d). '
                  % (sequence_length, few_shot_K, self.sequence_length_max))
            sequence_length = self.sequence_length_max - few_shot_K
            print('Reduced sequence length to %s' % (sequence_length))
        self.sequence_length = sequence_length
        self.few_shot_K = few_shot_K
        # Recalculate mapping as some sequences might no longer be useful.
        self.mapping, self.epoch_length = self._create_mapping()
        print('Epoch length:', self.epoch_length)

    def _create_mapping(self):
        r"""Creates mapping from idx to key in LMDB.

        Only sequences with at least ``sequence_length + few_shot_K`` frames
        are kept.

        Returns:
            (tuple):
              - self.mapping (dict or list): During training, a dict from
                sequence length to the list of sequence records of that
                length. During inference, a flat list of all kept records.
              - self.epoch_length (int): Number of samples in an epoch.
        """
        # Create dict mapping length to sequence.
        length_to_key, num_selected_seq = {}, 0
        has_additional_lists = len(self.additional_lists) > 0
        min_num_frames = self.sequence_length + self.few_shot_K
        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
            for sequence_name, filenames in sequence_list.items():
                num_frames = len(filenames)
                if num_frames < min_num_frames:
                    continue
                if has_additional_lists:
                    obj_indices = self.additional_lists[lmdb_idx][
                        sequence_name]
                else:
                    # No per-frame object index available: use 0 everywhere.
                    obj_indices = [0] * num_frames
                length_to_key.setdefault(num_frames, []).append({
                    'lmdb_root': self.lmdb_roots[lmdb_idx],
                    'lmdb_idx': lmdb_idx,
                    'sequence_name': sequence_name,
                    'filenames': filenames,
                    'obj_indices': obj_indices,
                })
                num_selected_seq += 1
        self.mapping = length_to_key
        self.epoch_length = num_selected_seq

        # At inference time, we want to use all sequences,
        # irrespective of length.
        if self.is_inference:
            self.mapping = [
                sequence
                for sequences in length_to_key.values()
                for sequence in sequences
            ]

        return self.mapping, self.epoch_length

    def _sample_keys(self, index):
        r"""Gets files to load for this sample.

        During inference, the frame at ``index`` of the configured inference
        sequence is paired with the configured k-shot frame. During training,
        ``index`` is not used: a random sequence, time step and start
        position are drawn, and the few-shot frames are sampled from the same
        sequence outside the chosen window.

        Args:
            index (int): Index in [0, len(dataset)].
        Returns:
            (tuple):
              - key (dict): Sequence record whose ``filenames`` and
                ``obj_indices`` are restricted to the chosen frames. It also
                carries ``lmdb_root``, ``lmdb_idx`` and ``sequence_name``.
              - few_shot_key (dict): Same layout, for the few-shot frames.
        """
        if self.is_inference:
            assert index < self.epoch_length
            chosen_sequence = self.mapping[self.inference_sequence_idx]
            chosen_filenames = [chosen_sequence['filenames'][index]]
            chosen_obj_indices = [chosen_sequence['obj_indices'][index]]

            k_shot_sequence = self.mapping[
                self.inference_k_shot_sequence_index]
            k_shot_frame = self.inference_k_shot_frame_index
            k_shot_chosen_filenames = [
                k_shot_sequence['filenames'][k_shot_frame]]
            k_shot_chosen_obj_indices = [
                k_shot_sequence['obj_indices'][k_shot_frame]]
        else:
            # Pick a time step for temporal augmentation.
            time_step = random.randint(1, self.augmentor.max_time_step)
            required_sequence_length = 1 + \
                (self.sequence_length - 1) * time_step

            # If step is too large, default to step size of 1.
            if required_sequence_length + self.few_shot_K > \
                    self.sequence_length_max:
                required_sequence_length = self.sequence_length
                time_step = 1

            # Find valid sequences.
            min_num_frames = required_sequence_length + self.few_shot_K
            valid_sequences = []
            for sequence_length, sequences in self.mapping.items():
                if sequence_length >= min_num_frames:
                    valid_sequences.extend(sequences)

            # Pick a sequence.
            chosen_sequence = random.choice(valid_sequences)
            num_frames = len(chosen_sequence['filenames'])

            # Choose filenames.
            max_start_idx = num_frames - required_sequence_length
            start_idx = random.randint(0, max_start_idx)
            end_idx = start_idx + required_sequence_length
            chosen_filenames = chosen_sequence['filenames'][
                start_idx:end_idx:time_step]
            chosen_obj_indices = chosen_sequence['obj_indices'][
                start_idx:end_idx:time_step]

            # Find the K few shot filenames: any frame outside the window
            # [start_idx, end_idx) is a candidate.
            valid_range = list(range(start_idx)) + \
                list(range(end_idx, num_frames))
            k_shot_chosen = sorted(
                random.sample(valid_range, self.few_shot_K))
            k_shot_chosen_filenames = [
                chosen_sequence['filenames'][idx] for idx in k_shot_chosen]
            k_shot_chosen_obj_indices = [
                chosen_sequence['obj_indices'][idx] for idx in k_shot_chosen]
            assert not (set(chosen_filenames) & set(k_shot_chosen_filenames))

            assert len(chosen_filenames) == self.sequence_length
            assert len(k_shot_chosen_filenames) == self.few_shot_K

            # During training the few-shot frames come from the same sequence.
            k_shot_sequence = chosen_sequence

        # Prepare few shot key.
        few_shot_key = copy.deepcopy(k_shot_sequence)
        few_shot_key['filenames'] = k_shot_chosen_filenames
        few_shot_key['obj_indices'] = k_shot_chosen_obj_indices

        # Prepare output key.
        key = copy.deepcopy(chosen_sequence)
        key['filenames'] = chosen_filenames
        key['obj_indices'] = chosen_obj_indices
        return key, few_shot_key

    def _prepare_data(self, keys):
        r"""Load data and perform augmentation.

        Args:
            keys (dict): Key into LMDB/folder dataset for this item.
        Returns:
            data (dict): Dict with all chosen data_types. It additionally
            holds a pre-post-augmentation copy of every keypoint data type
            under ``<data_type>_xy``, the flip flag under ``is_flipped`` and
            the per-data-type sequence keys under ``key``.
        """
        # Unpack keys.
        lmdb_idx = keys['lmdb_idx']
        sequence_name = keys['sequence_name']
        filenames = keys['filenames']
        obj_indices = keys['obj_indices']

        # Get key and lmdbs.
        sequence_keys, lmdbs = {}, {}
        for data_type in self.dataset_data_types:
            sequence_keys[data_type] = self._create_sequence_keys(
                sequence_name, filenames)
            lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]

        # Load all data for this index.
        data = self.load_from_dataset(sequence_keys, lmdbs)

        # Apply ops pre augmentation.
        data = self.apply_ops(data, self.pre_aug_ops)

        # Select the object in data using the object indices.
        data = select_object(data, obj_indices)

        # Do augmentations for images.
        data, is_flipped = self.perform_augmentation(
            data, paired=True, augment_ops=self.augmentor.augment_ops)

        # Create copy of keypoint data types before post aug.
        kp_data = {}
        for data_type in self.keypoint_data_types:
            kp_data[data_type + '_xy'] = copy.deepcopy(data[data_type])

        # Apply ops post augmentation.
        data = self.apply_ops(data, self.post_aug_ops)
        data = self.apply_ops(
            data, self.full_data_post_aug_ops, full_data=True)

        # Convert images to tensor.
        data = self.to_tensor(data)

        # Pack the sequence of images: give every frame a new leading
        # dimension and concatenate the frames along it.
        for data_type in self.image_data_types:
            data[data_type] = torch.cat(
                [frame.unsqueeze(0) for frame in data[data_type]], dim=0)

        # Add keypoint xy to data.
        data.update(kp_data)

        data['is_flipped'] = is_flipped
        data['key'] = sequence_keys

        return data

    def _getitem(self, index):
        r"""Gets selected files.

        Args:
            index (int): Index into dataset.
        Returns:
            data (dict): Dict with all chosen data_types, plus the few-shot
            data under keys prefixed with ``few_shot_``.
        """
        # Select a sample from the available data.
        keys, few_shot_keys = self._sample_keys(index)

        data = self._prepare_data(keys)
        few_shot_data = self._prepare_data(few_shot_keys)

        # Add few shot data into data.
        for name, value in few_shot_data.items():
            data['few_shot_' + name] = value

        # Apply full data ops.
        if self.is_inference:
            if index == 0:
                # Nothing to inject; any 'common_attr' is cached below once
                # the full data ops have run.
                pass
            elif index < self.cfg.data.num_workers:
                # Recompute item 0 to obtain its 'common_attr'.
                # NOTE(review): presumably needed because each dataloader
                # worker holds its own dataset copy -- confirm.
                data_0 = self._getitem(0)
                if 'common_attr' in data_0:
                    self.common_attr = data['common_attr'] = \
                        data_0['common_attr']
            else:
                if hasattr(self, 'common_attr'):
                    data['common_attr'] = self.common_attr

        data = self.apply_ops(data, self.full_data_ops, full_data=True)

        if self.is_inference and index == 0 and 'common_attr' in data:
            self.common_attr = data['common_attr']

        return data