#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import gzip
import numpy as np
import io
from PIL import Image
from torch.utils.data import Dataset

try:
    from PIL import UnidentifiedImageError
    unidentified_error_available = True
except ImportError:
    # UnidentifiedImageError isn't available in older versions of PIL
    unidentified_error_available = False


class DiskTarDataset(Dataset):
    def __init__(self,
                 tarfile_path='dataset/imagenet/ImageNet-21k/metadata/tar_files.npy',
                 tar_index_dir='dataset/imagenet/ImageNet-21k/metadata/tarindex_npy',
                 preload=False,
                 num_synsets="all"):
        """
        - preload (bool): if True, memory-map every tar file up front;
          recommended to keep False so each tar is memory-mapped lazily
          on first access
        - num_synsets (int or the string "all"): set to a small integer
          for debugging; only that subset of the dataset is loaded
        """
        tar_files = np.load(tarfile_path)

        chunk_datasets = []
        dataset_lens = []
        if isinstance(num_synsets, int):
            assert num_synsets < len(tar_files)
            tar_files = tar_files[:num_synsets]
        for tar_file in tar_files:
            dataset = _TarDataset(tar_file, tar_index_dir, preload=preload)
            chunk_datasets.append(dataset)
            dataset_lens.append(len(dataset))

        self.chunk_datasets = chunk_datasets
        self.dataset_lens = np.array(dataset_lens).astype(np.int32)
        self.dataset_cumsums = np.cumsum(self.dataset_lens)
        self.num_samples = sum(self.dataset_lens)

        # map each global sample index to the chunk (synset) it belongs to
        labels = np.zeros(self.dataset_lens.sum(), dtype=np.int64)
        sI = 0
        for k in range(len(self.dataset_lens)):
            assert (sI + self.dataset_lens[k]) <= len(labels), \
                f"{k} {sI + self.dataset_lens[k]} vs. {len(labels)}"
            labels[sI:(sI + self.dataset_lens[k])] = k
            sI += self.dataset_lens[k]
        self.labels = labels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        assert index >= 0 and index < len(self)

        # find the dataset file we need to go to
        d_index = np.searchsorted(self.dataset_cumsums, index)
        # edge case: if index sits exactly on a chunk boundary, move right
        if index in self.dataset_cumsums:
            d_index += 1

        assert d_index == self.labels[index], \
            f"{d_index} vs. {self.labels[index]} mismatch for {index}"

        # change index to local dataset index
        if d_index == 0:
            local_index = index
        else:
            local_index = index - self.dataset_cumsums[d_index - 1]
        data_bytes = self.chunk_datasets[d_index][local_index]

        exception_to_catch = UnidentifiedImageError if unidentified_error_available else Exception
        try:
            image = Image.open(data_bytes).convert("RGB")
        except exception_to_catch:
            # unreadable image: substitute a gray 224x224 placeholder
            # and flag the sample with label -1
            image = Image.fromarray(np.ones((224, 224, 3), dtype=np.uint8) * 128)
            d_index = -1

        # label is the dataset (synset) we indexed into
        return image, d_index, index

    def __repr__(self):
        return f"DiskTarDataset(subdatasets={len(self.dataset_lens)},samples={self.num_samples})"
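

# A minimal sketch of the global-to-local index mapping used in
# DiskTarDataset.__getitem__ above, pulled out so it can be sanity-checked in
# isolation. `_demo_index_mapping` is a hypothetical helper added for
# illustration only; it is not part of the original interface.
def _demo_index_mapping(dataset_lens, index):
    cumsums = np.cumsum(np.asarray(dataset_lens, dtype=np.int64))
    d_index = int(np.searchsorted(cumsums, index))
    # same edge-case handling as above: an index sitting exactly on a
    # chunk boundary belongs to the chunk on the right
    if index in cumsums:
        d_index += 1
    local_index = index if d_index == 0 else int(index - cumsums[d_index - 1])
    return d_index, local_index

# e.g. with chunk lengths [3, 2, 4]: global index 2 -> (0, 2),
# 3 -> (1, 0), 5 -> (2, 0), 8 -> (2, 3)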


class _TarDataset(object):
    def __init__(self, filename, npy_index_dir, preload=False):
        # translated from
        # fbcode/experimental/deeplearning/matthijs/comp_descs/tardataset.lua
        self.filename = filename
        self.names = []
        self.offsets = []
        self.npy_index_dir = npy_index_dir
        names, offsets = self.load_index()
        self.num_samples = len(names)
        if preload:
            self.data = np.memmap(filename, mode='r', dtype='uint8')
            self.offsets = offsets
        else:
            # memory-map lazily in __getitem__ on first access
            self.data = None

    def __len__(self):
        return self.num_samples

    def load_index(self):
        basename = os.path.basename(self.filename)
        basename = os.path.splitext(basename)[0]
        names = np.load(os.path.join(self.npy_index_dir, f"{basename}_names.npy"))
        offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy"))
        return names, offsets

    def __getitem__(self, idx):
        if self.data is None:
            self.data = np.memmap(self.filename, mode='r', dtype='uint8')
            _, self.offsets = self.load_index()
        # offsets are stored in units of 512-byte tar blocks
        ofs = self.offsets[idx] * 512
        fsize = 512 * (self.offsets[idx + 1] - self.offsets[idx])
        data = self.data[ofs:ofs + fsize]

        # skip the tar header block; GNU @LongLink entries carry two extra
        # blocks (the @LongLink header and the long file name) before the
        # real header
        if data[:13].tobytes() == b'././@LongLink':
            data = data[3 * 512:]
        else:
            data = data[512:]

        # just to make it more fun a few JPEGs are GZIP compressed...
        # catch this case by checking for the gzip magic bytes
        if tuple(data[:2]) == (0x1f, 0x8b):
            with gzip.GzipFile(fileobj=io.BytesIO(data.tobytes()), mode='r') as g:
                sdata = g.read()
        else:
            sdata = data.tobytes()
        return io.BytesIO(sdata)
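

# Minimal usage sketch, assuming the metadata .npy files under the default
# paths above exist on disk. __getitem__ returns a PIL image, so a transform
# or custom collate_fn would be needed before batching with a torch
# DataLoader; the calls below are illustrative only.
if __name__ == "__main__":
    dataset = DiskTarDataset(num_synsets=2)  # small subset for a quick check
    print(dataset)
    image, synset_label, global_index = dataset[0]
    # synset_label is the chunk (synset) id, or -1 if the image was unreadable
    print(image.size, synset_label, global_index)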