Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import random | |
| from imaginaire.datasets.base import BaseDataset | |
class Dataset(BaseDataset):
    r"""Unpaired few-shot image dataset for use in FUNIT.

    Each sample pairs one ``images_content`` item with one ``images_style``
    item. The class of an item is the first path component of its sequence
    name.

    Args:
        cfg (Config): Loaded config object.
        is_inference (bool): In train or inference mode?
        is_test (bool): In test mode? Forwarded to ``BaseDataset``.
    """

    def __init__(self, cfg, is_inference=False, is_test=False):
        # Set before the base constructor runs; presumably read by
        # BaseDataset during setup -- TODO confirm.
        self.paired = False
        super(Dataset, self).__init__(cfg, is_inference, is_test)
        # class_name_to_idx is populated by _create_mapping, which is
        # presumably invoked from BaseDataset.__init__ -- TODO confirm.
        self.num_content_classes = len(
            self.class_name_to_idx['images_content'])
        self.num_style_classes = len(self.class_name_to_idx['images_style'])
        # Style class to draw from in inference mode (see
        # set_sample_class_idx); must be set before inference sampling.
        self.sample_class_idx = None
        # Constants used by _sample_keys in inference mode to pick a
        # deterministic content item that differs per style class.
        self.content_offset = 8888
        self.content_interval = 100

    def set_sample_class_idx(self, class_idx=None):
        r"""Set the style class to sample from and update the epoch length.

        Args:
            class_idx (int or None): Style class index to sample from. When
                ``None``, the epoch length is the size of the largest data
                type; otherwise it is the number of style items in that class.
        """
        self.sample_class_idx = class_idx
        if class_idx is None:
            self.epoch_length = max(
                len(lmdb_keys) for lmdb_keys in self.mapping.values())
        else:
            self.epoch_length = \
                len(self.mapping_class['images_style'][class_idx])

    def _create_mapping(self):
        r"""Creates mapping from idx to key in LMDB.

        Also builds:
          - ``self.class_name_to_idx``: per data type, class name -> class
            index, assigned contiguously in sorted class-name order.
          - ``self.mapping_class``: per data type, class index -> list of the
            key dicts belonging to that class.

        Returns:
            (tuple):
              - self.mapping (dict): Dict with data type as key mapping idx to
                LMDB key.
              - self.epoch_length (int): Number of samples in an epoch.
        """
        idx_to_key, class_names = {}, {}
        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
            for data_type, data_type_sequence_list in sequence_list.items():
                # Accumulate across ALL LMDB roots. Resetting these per root
                # would drop classes seen only in earlier roots and make the
                # class-idx lookup below raise KeyError.
                data_type_keys = idx_to_key.setdefault(data_type, [])
                data_type_classes = class_names.setdefault(data_type, set())
                for sequence_name, filenames in \
                        data_type_sequence_list.items():
                    # The class is the first path component of the sequence.
                    class_name = sequence_name.split('/')[0]
                    for filename in filenames:
                        data_type_keys.append({
                            'lmdb_root': self.lmdb_roots[lmdb_idx],
                            'lmdb_idx': lmdb_idx,
                            'sequence_name': sequence_name,
                            'filename': filename,
                            'class_name': class_name
                        })
                        # Only classes with at least one file get an index.
                        data_type_classes.add(class_name)
        self.mapping = idx_to_key
        self.epoch_length = max(
            len(lmdb_keys) for lmdb_keys in self.mapping.values())

        # Create mapping from class name to class idx (sorted for a
        # deterministic labelling).
        self.class_name_to_idx = {
            data_type: {name: idx
                        for idx, name in enumerate(sorted(names))}
            for data_type, names in class_names.items()
        }

        # Add class idx to mapping.
        for data_type, data_type_keys in self.mapping.items():
            name_to_idx = self.class_name_to_idx[data_type]
            for key in data_type_keys:
                key['class_idx'] = name_to_idx[key['class_name']]

        # Create a mapping from class idx to lmdb keys for each data type,
        # with exactly one bucket per distinct class.
        idx_to_key_class = {}
        for data_type, data_type_keys in self.mapping.items():
            buckets = {class_idx: [] for class_idx
                       in self.class_name_to_idx[data_type].values()}
            for key in data_type_keys:
                buckets[key['class_idx']].append(key)
            idx_to_key_class[data_type] = buckets
        self.mapping_class = idx_to_key_class
        return self.mapping, self.epoch_length

    def _sample_keys(self, index):
        r"""Gets files to load for this sample.

        In inference mode, the style item is ``index`` within the style class
        ``self.sample_class_idx``, and the content item is picked
        deterministically from ``index`` and that class. In training mode,
        both are drawn uniformly at random and ``index`` is ignored.

        Args:
            index (int): Index in [0, len(dataset)].

        Returns:
            (tuple):
              - keys (dict): Each key of this dict is a data type.
                - lmdb_key (dict):
                  - lmdb_idx (int): Chosen LMDB dataset root.
                  - sequence_name (str): Chosen sequence in chosen dataset.
                  - filename (str): Chosen filename in chosen sequence.
        """
        keys = {}
        if self.is_inference:  # evaluation mode
            # NOTE: requires set_sample_class_idx() to have been called with
            # an int; None here makes the arithmetic below fail.
            lmdb_keys_content = self.mapping['images_content']
            # Offset by style class so different classes see different
            # content items, stride by content_interval, wrap around.
            keys['images_content'] = \
                lmdb_keys_content[
                    ((index + self.content_offset * self.sample_class_idx) *
                     self.content_interval) % len(lmdb_keys_content)]
            lmdb_keys_style = \
                self.mapping_class['images_style'][self.sample_class_idx]
            keys['images_style'] = lmdb_keys_style[index]
        else:
            # Unpaired training: content and style are chosen independently.
            lmdb_keys_content = self.mapping['images_content']
            lmdb_keys_style = self.mapping['images_style']
            keys['images_content'] = random.choice(lmdb_keys_content)
            keys['images_style'] = random.choice(lmdb_keys_style)
        return keys

    def __getitem__(self, index):
        r"""Gets selected files.

        Args:
            index (int): Index into dataset.

        Returns:
            data (dict): Dict with all chosen data_types, plus ``is_flipped``,
            ``key`` (the sampled key dicts per data type), ``labels_content``
            and ``labels_style`` (class indices of the two sampled items).
        """
        # Select a sample from the available data.
        keys_per_data_type = self._sample_keys(index)

        # Read class labels by data-type name rather than relying on the
        # insertion order of keys_per_data_type.
        labels_content = keys_per_data_type['images_content']['class_idx']
        labels_style = keys_per_data_type['images_style']['class_idx']

        # Get keys and lmdbs.
        keys, lmdbs = {}, {}
        for data_type in self.dataset_data_types:
            key_info = keys_per_data_type[data_type]
            keys[data_type] = '%s/%s' % (key_info['sequence_name'],
                                         key_info['filename'])
            lmdbs[data_type] = self.lmdbs[data_type][key_info['lmdb_idx']]

        # Load all data for this index.
        data = self.load_from_dataset(keys, lmdbs)

        # Apply ops pre augmentation.
        data = self.apply_ops(data, self.pre_aug_ops)

        # Do augmentations for images (content and style are unpaired).
        data, is_flipped = self.perform_augmentation(
            data, paired=False, augment_ops=self.augmentor.augment_ops)

        # Apply ops post augmentation.
        data = self.apply_ops(data, self.post_aug_ops)
        data = self.apply_ops(data, self.full_data_post_aug_ops,
                              full_data=True)

        # Convert images to tensor.
        data = self.to_tensor(data)

        # Remove any extra dimensions: keep only the first element of each
        # image data type (presumably a singleton leading dim -- confirm).
        for data_type in self.image_data_types:
            data[data_type] = data[data_type][0]

        # Package output.
        data['is_flipped'] = is_flipped
        data['key'] = keys_per_data_type
        data['labels_content'] = labels_content
        data['labels_style'] = labels_style
        return data