Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| from imaginaire.datasets.paired_videos import Dataset as VideoDataset | |
class Dataset(VideoDataset):
    r"""Paired image dataset for use in pix2pixHD, SPADE.

    Implemented on top of the paired video dataset with a sequence length
    of 1, so every sample is a single frame.

    Args:
        cfg (Config): Loaded config object.
        is_inference (bool): In train or inference mode?
        is_test (bool): Forwarded unchanged to the parent ``VideoDataset``
            constructor.
    """

    def __init__(self, cfg, is_inference=False, is_test=False):
        # Assigned before the parent constructor runs.
        # NOTE(review): presumably the parent reads ``self.paired`` during
        # its own initialization, which is why this comes first — confirm.
        self.paired = True
        super(Dataset, self).__init__(cfg, is_inference,
                                      sequence_length=1,
                                      is_test=is_test)
        self.is_video_dataset = False

    def _create_mapping(self):
        r"""Creates mapping from idx to key in LMDB.

        Every filename of every sequence in every LMDB becomes its own
        one-frame sample, in iteration order.

        Returns:
            (tuple):
              - self.mapping (list of dict): List mapping idx to key. Each
                key holds ``lmdb_root``, ``lmdb_idx``, ``sequence_name`` and
                ``filenames`` (a one-element list).
              - self.epoch_length (int): Number of samples in an epoch.
        """
        # NOTE(review): ``self.sequence_lists`` and ``self.lmdb_roots`` are
        # not set in this class; presumably the parent populates them as
        # parallel per-LMDB lists — confirm.
        idx_to_key = []
        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
            for sequence_name, filenames in sequence_list.items():
                for filename in filenames:
                    idx_to_key.append({
                        'lmdb_root': self.lmdb_roots[lmdb_idx],
                        'lmdb_idx': lmdb_idx,
                        'sequence_name': sequence_name,
                        # Kept as a list (of one) so the key presumably
                        # matches the multi-frame format of the video parent.
                        'filenames': [filename],
                    })
        self.mapping = idx_to_key
        self.epoch_length = len(self.mapping)
        return self.mapping, self.epoch_length

    def _sample_keys(self, index):
        r"""Gets files to load for this sample.

        Args:
            index (int): Index into ``self.mapping``, i.e. in
                [0, len(self.mapping)).

        Returns:
            key (dict): The entry stored in ``self.mapping`` itself (not a
            copy), so callers should not mutate it.
              - lmdb_root: Root of the chosen LMDB dataset.
              - lmdb_idx (int): Chosen LMDB dataset root.
              - sequence_name (str): Chosen sequence in chosen dataset.
              - filenames (list of str): Chosen filenames in chosen sequence.

        Raises:
            AssertionError: If ``self.sequence_length`` is not 1.
            IndexError: If ``index`` is outside ``self.mapping``.
        """
        # Sanity check: each mapping entry holds exactly one frame.
        assert self.sequence_length == 1, \
            'Image dataset can only have sequence length = 1, not %d' % (
                self.sequence_length)
        return self.mapping[index]

    def set_sequence_length(self, sequence_length):
        r"""Set the length of sequence you want as output from dataloader.

        Ignore this as this is an image loader: the override is a no-op, so
        the requested length is discarded.

        Args:
            sequence_length (int): Length of output sequences (ignored).
        """

    def set_inference_sequence_idx(self, index):
        r"""Get frames from this sequence during inference.

        Overridden from super as this is not applicable for images.

        Args:
            index (int): Index of inference sequence.

        Raises:
            RuntimeError: Always, since an image dataset has no sequences.
        """
        raise RuntimeError('Image dataset does not have sequences.')

    def num_inference_sequences(self):
        r"""Number of sequences available for inference.

        Overridden from super as this is not applicable for images.

        Returns:
            (int): Never returns normally.

        Raises:
            RuntimeError: Always, since an image dataset has no sequences.
        """
        raise RuntimeError('Image dataset does not have sequences.')