# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import io
import json

# import cv2
import boto3
from botocore.config import Config
import numpy as np
import torch.utils.data as data
from PIL import Image
import imageio
from botocore.exceptions import ClientError

from imaginaire.datasets.cache import Cache
from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS

Image.MAX_IMAGE_PIXELS = None


class ObjectStoreDataset(data.Dataset):
    r"""Dataset that handles opening and reading items from an AWS S3 bucket.

    Args:
        root (str): Path to the AWS S3 bucket, given as 'bucket/prefix'.
        aws_credentials_file (str): Path to a JSON file containing AWS credentials.
        data_type (str): Which data type should this dataset load?
        cache (obj, optional): Config object with ``root`` and ``size_GB``
            attributes; if given, downloaded objects are cached locally.
    """
    def __init__(self, root, aws_credentials_file, data_type='', cache=None):
        # Cache.
        self.cache = False
        if cache is not None:
            # raise NotImplementedError
            self.cache = Cache(cache.root, cache.size_GB)

        # Get bucket info, and keys to info about dataset.
        with open(aws_credentials_file) as fin:
            self.credentials = json.load(fin)

        parts = root.split('/')
        self.bucket = parts[0]
        self.all_filenames_key = '/'.join(parts[1:]) + '/all_filenames.json'
        self.metadata_key = '/'.join(parts[1:]) + '/metadata.json'

        # Get list of filenames.
        filename_info = self._get_object(self.all_filenames_key)
        self.sequence_list = json.loads(filename_info.decode('utf-8'))

        # Get length.
        length = 0
        for _, value in self.sequence_list.items():
            length += len(value)
        self.length = length

        # Read metadata.
        metadata_info = self._get_object(self.metadata_key)
        self.extensions = json.loads(metadata_info.decode('utf-8'))
        self.data_type = data_type

        print('AWS S3 bucket at %s opened.' % (root + '/' + self.data_type))
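
    # The JSON files referenced above are assumed to look roughly as follows
    # (illustrative sketch only; the exact keys and names depend on how the
    # bucket was built):
    #
    #   aws_credentials_file -- its keys are forwarded verbatim as
    #   boto3.client() keyword arguments, e.g.
    #       {"aws_access_key_id": "...", "aws_secret_access_key": "..."}
    #
    #   all_filenames.json -- maps each sequence to its list of item names,
    #   e.g. {"seq_0001": ["frame_0000", "frame_0001"]}; the dataset length is
    #   the total number of items across all sequences.
    #
    #   metadata.json -- maps each data_type to its file extension,
    #   e.g. {"images": "jpg", "seg_maps": "png"}.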

    def _get_object(self, key):
        r"""Download object from bucket.

        Args:
            key (str): Key inside bucket.
        """
        # Look up value in cache.
        object_content = self.cache.read(key) if self.cache else False
        if not object_content:
            # Either no cache used or key not found in cache.
            config = Config(connect_timeout=30,
                            signature_version="s3",
                            retries={"max_attempts": 999999})
            s3 = boto3.client('s3', **self.credentials, config=config)
            try:
                s3_response_object = s3.get_object(Bucket=self.bucket, Key=key)
                object_content = s3_response_object['Body'].read()
            except Exception as e:
                print('%s not found' % (key))
                print(e)
            # Save content to cache, but only when the download succeeded.
            if self.cache and object_content:
                self.cache.write(key, object_content)
        return object_content

    def getitem_by_path(self, path, data_type):
        r"""Load the data item stored for key = path.

        Args:
            path (str): Path into AWS S3 bucket, without data_type prefix.
            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
        Returns:
            img (PIL.Image or numpy.ndarray) or buf (bytes): Decoded image, or
            raw object bytes for this key.
        """
        # Figure out decoding params.
        ext = self.extensions[data_type]
        is_image = False
        is_hdr = False
        parts = path.split('/')
        key = parts[0] + '/' + data_type + '/' + '/'.join(parts[1:]) + '.' + ext
        if ext in IMG_EXTENSIONS:
            is_image = True
            if 'tif' in ext:
                _, mode = np.uint16, -1
            elif 'JPEG' in ext or 'JPG' in ext \
                    or 'jpeg' in ext or 'jpg' in ext:
                _, mode = np.uint8, 3
            else:
                _, mode = np.uint8, -1
        elif ext in HDR_IMG_EXTENSIONS:
            is_hdr = True
        else:
            is_image = False

        # Get value from key.
        buf = self._get_object(key)
        # Decode and return.
        if is_image:
            # This is totally a hack.
            # We should have a better way to handle grayscale images.
            img = Image.open(io.BytesIO(buf))
            if mode == 3:
                img = img.convert('RGB')
            return img
        elif is_hdr:
            try:
                imageio.plugins.freeimage.download()
                img = imageio.imread(buf)
            except Exception:
                # Report the offending path, then re-raise so a failed decode
                # does not fall through to an undefined `img`.
                print(path)
                raise
            return img  # Return a numpy array.
        else:
            return buf
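
    # Example of the key layout assumed by getitem_by_path() above, with
    # hypothetical names: for path='seq_0001/frame_0000', data_type='images'
    # and extension 'jpg', the object fetched from the bucket is
    # 'seq_0001/images/frame_0000.jpg'.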

    def __len__(self):
        r"""Return the number of items in the dataset."""
        return self.length
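

# Minimal usage sketch (not part of the original module). The bucket layout,
# file paths and the 'images' data_type below are assumptions for
# illustration only.
if __name__ == '__main__':
    dataset = ObjectStoreDataset(
        root='my-bucket/datasets/my_dataset',         # assumed 'bucket/prefix'
        aws_credentials_file='aws_credentials.json',  # assumed local JSON path
        data_type='images')                           # assumed data_type key
    print('Number of items:', len(dataset))
    # Fetch a single item by its sequence-relative path (hypothetical names).
    item = dataset.getitem_by_path('seq_0001/frame_0000', 'images')
    print(type(item))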