Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import json | |
| import os | |
| import cv2 | |
| import lmdb | |
| import numpy as np | |
| import torch.utils.data as data | |
| from PIL import Image | |
| from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS | |
| from imaginaire.utils.distributed import master_only_print as print | |
| import imageio | |
class LMDBDataset(data.Dataset):
    r"""This deals with opening, and reading from an LMDB dataset.

    The LMDB environment is opened read-only. A ``metadata.json`` file located
    one directory above ``root`` is loaded into ``self.extensions``; it is
    indexed by data type and yields the file extension used to decide how the
    stored values are decoded.

    Args:
        root (str): Path to the LMDB file.
    """

    def __init__(self, root):
        self.root = os.path.expanduser(root)
        # Open the *expanded* path so that a '~' in root resolves to the same
        # location that metadata.json is read from below.
        self.env = lmdb.open(self.root, max_readers=126, readonly=True,
                             lock=False, readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']

        # Read metadata.
        with open(os.path.join(self.root, '..', 'metadata.json')) as fin:
            self.extensions = json.load(fin)

        print('LMDB file at %s opened.' % (root))

    def getitem_by_path(self, path, data_type):
        r"""Load data item stored for key = path.

        Args:
            path (bytes or str): Key into LMDB dataset. LMDB keys are bytes;
                a str key is UTF-8 encoded before the lookup.
            data_type (str): Key into self.extensions e.g.
                data/data_segmaps/...
        Returns:
            img (PIL.Image) for extensions in IMG_EXTENSIONS, img (np.ndarray)
            for extensions in HDR_IMG_EXTENSIONS, or buf (bytes): the raw
            contents of the LMDB value for this key otherwise (None if the
            key is absent).
        Raises:
            KeyError: If an image / HDR item is requested but the key is not
                present in the database.
            ValueError: If the stored bytes cannot be decoded as an image.
        """
        # Figure out decoding params.
        ext = self.extensions[data_type]
        is_image = ext in IMG_EXTENSIONS
        is_hdr = (not is_image) and ext in HDR_IMG_EXTENSIONS

        # Get value from key.
        if isinstance(path, str):
            path = path.encode()
        with self.env.begin(write=False) as txn:
            buf = txn.get(path)

        # Raw payloads are handed back untouched (None when key is missing).
        if not (is_image or is_hdr):
            return buf

        if buf is None:
            print(path)
            raise KeyError(
                'Key %r not found in LMDB at %s' % (path, self.root))

        # Decode and return.
        if is_image:
            return self._decode_image(buf, ext, path)
        return self._decode_hdr(buf, path)

    @staticmethod
    def _decode_image(buf, ext, path):
        r"""Decode an encoded image byte string into an RGB-ordered PIL image.

        Args:
            buf (bytes): Encoded image file contents.
            ext (str): File extension used to pick the OpenCV read flag.
            path (bytes): LMDB key, only used for error reporting.
        Returns:
            img (PIL.Image): Decoded image.
        """
        jpeg_tags = ('JPEG', 'JPG', 'jpeg', 'jpg')
        if 'tif' not in ext and any(tag in ext for tag in jpeg_tags):
            mode = 3  # IMREAD_COLOR | IMREAD_ANYDEPTH.
        else:
            mode = -1  # IMREAD_UNCHANGED: keep bit depth and channels.

        # The encoded stream is always a sequence of bytes, whatever the bit
        # depth of the decoded image; the read flag controls the output depth.
        encoded = np.frombuffer(buf, dtype=np.uint8)
        try:
            img = cv2.imdecode(encoded, mode)
        except Exception:
            print(path)
            raise
        # cv2.imdecode reports undecodable data by returning None.
        if img is None:
            print(path)
            raise ValueError('Could not decode image for key %r' % (path,))

        # BGR to RGB if 3 channels.
        if img.ndim == 3 and img.shape[-1] == 3:
            img = img[:, :, ::-1]
        return Image.fromarray(img)

    @staticmethod
    def _decode_hdr(buf, path):
        r"""Decode an HDR image byte string with imageio.

        Args:
            buf (bytes): Encoded HDR file contents.
            path (bytes): LMDB key, only used for error reporting.
        Returns:
            img (np.ndarray): Decoded image.
        """
        try:
            imageio.plugins.freeimage.download()
            img = imageio.imread(buf)
        except Exception:
            # Report which key failed, then let the real error surface
            # instead of continuing with an undefined image.
            print(path)
            raise
        return img  # Return a numpy array

    def __len__(self):
        r"""Return number of keys in LMDB dataset."""
        return self.length