import os, tarfile, glob, shutil
import yaml
import numpy as np
from tqdm import tqdm
from PIL import Image
import albumentations
from omegaconf import OmegaConf
from torch.utils.data import Dataset
from skimage.feature import canny
from skimage.color import rgb2gray

from taming.data.base import ImagePaths
from taming.util import download, retrieve
import taming.data.utils as bdu


def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
    synsets = []
    with open(path_to_yaml) as f:
        # safe_load avoids the Loader deprecation/removal in newer PyYAML
        di2s = yaml.safe_load(f)
    for idx in indices:
        synsets.append(str(di2s[idx]))
    print("Using {} different synsets for construction of Restricted ImageNet.".format(len(synsets)))
    return synsets


def str_to_indices(string):
    """Expects a string in the format '32-123, 256, 280-321'"""
    assert not string.endswith(","), "provided string '{}' ends with a comma, please remove it".format(string)
    subs = string.split(",")
    indices = []
    for sub in subs:
        subsubs = sub.split("-")
        assert len(subsubs) > 0
        if len(subsubs) == 1:
            indices.append(int(subsubs[0]))
        else:
            # note: the upper bound of a range is exclusive, i.e. "32-35"
            # expands to [32, 33, 34]
            rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
            indices.extend(rang)
    return sorted(indices)
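
# A minimal usage sketch of the two helpers above (values shown are
# illustrative; the synset for index 0 assumes the standard ILSVRC2012
# index-to-synset mapping):
#
#   >>> str_to_indices("30-32, 50")
#   [30, 31, 50]            # range upper bounds are exclusive
#   >>> give_synsets_from_indices([0])
#   ['n01440764']           # assuming the yaml maps index 0 to this synset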


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not isinstance(self.config, dict):
            self.config = OmegaConf.to_container(self.config)
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict) == SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if not os.path.exists(self.idx2syn):
            download(URL, self.idx2syn)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        self.class_labels = [class_dict[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)
        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }
        self.data = ImagePaths(self.abspaths,
                               labels=labels,
                               size=retrieve(self.config, "size", default=0),
                               random_crop=self.random_crop)
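
    # NOTE: each item of self.data is a dict produced by taming.data.base.ImagePaths;
    # assuming its usual behavior, a sample looks roughly like
    #   {"image": float32 HxWx3 array in [-1, 1],
    #    "relpath": ..., "synsets": ..., "class_label": ..., "human_label": ...}
    # (a sketch based on the labels passed above, not a guaranteed interface).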


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        if not bdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            bdu.mark_prepared(self.root)
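
# A minimal usage sketch (assumes the `academictorrents` package is installed
# and that downloading/extracting ~150 GB into the cache directory is acceptable):
#
#   ds = ImageNetTrain({"size": 256})
#   example = ds[0]  # dict with "image" plus the label arrays set up in _load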


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        if not bdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            bdu.mark_prepared(self.root)


def get_preprocessor(size=None, random_crop=False, additional_targets=None,
                     crop_size=None):
    if size is not None and size > 0:
        transforms = list()
        rescaler = albumentations.SmallestMaxSize(max_size=size)
        transforms.append(rescaler)
        if not random_crop:
            cropper = albumentations.CenterCrop(height=size, width=size)
            transforms.append(cropper)
        else:
            cropper = albumentations.RandomCrop(height=size, width=size)
            transforms.append(cropper)
            flipper = albumentations.HorizontalFlip()
            transforms.append(flipper)
        preprocessor = albumentations.Compose(transforms,
                                              additional_targets=additional_targets)
    elif crop_size is not None and crop_size > 0:
        if not random_crop:
            cropper = albumentations.CenterCrop(height=crop_size, width=crop_size)
        else:
            cropper = albumentations.RandomCrop(height=crop_size, width=crop_size)
        transforms = [cropper]
        preprocessor = albumentations.Compose(transforms,
                                              additional_targets=additional_targets)
    else:
        preprocessor = lambda **kwargs: kwargs
    return preprocessor
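
# A minimal sketch of the three branches (shapes assume a hypothetical input
# image of 300x400x3; `additional_targets` is passed through to
# albumentations.Compose so extra arrays receive the same transform):
#
#   prep = get_preprocessor(size=256)          # rescale + center crop (+ flip if random_crop)
#   prep(image=img)["image"].shape             # -> (256, 256, 3)
#   prep = get_preprocessor(crop_size=256)     # crop only, no rescaling
#   prep = get_preprocessor()                  # identity passthrough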


def rgba_to_depth(x):
    """Reinterpret the bytes of an RGBA uint8 image as a float32 depth map."""
    assert x.dtype == np.uint8
    assert len(x.shape) == 3 and x.shape[2] == 4
    # the four uint8 channels of each pixel are the four bytes of one float32;
    # .view() is the supported idiom for this reinterpretation (direct dtype
    # assignment is deprecated in newer numpy)
    y = x.copy().view(dtype=np.float32)
    y = y.reshape(x.shape[:2])
    return np.ascontiguousarray(y)
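
# Round-trip sketch: packing a float32 depth map into RGBA bytes and decoding
# it again (illustrative only; the real .png depth files are produced elsewhere):
#
#   d = np.random.rand(4, 4).astype(np.float32)
#   rgba = d[:, :, None].view(np.uint8)      # (4, 4, 4) uint8
#   assert np.allclose(rgba_to_depth(rgba), d)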


class BaseWithDepth(Dataset):
    DEFAULT_DEPTH_ROOT = "data/imagenet_depth"

    def __init__(self, config=None, size=None, random_crop=False,
                 crop_size=None, root=None):
        self.config = config
        self.base_dset = self.get_base_dset()
        self.preprocessor = get_preprocessor(
            size=size,
            crop_size=crop_size,
            random_crop=random_crop,
            additional_targets={"depth": "image"})
        self.crop_size = crop_size
        if self.crop_size is not None:
            self.rescaler = albumentations.Compose(
                [albumentations.SmallestMaxSize(max_size=self.crop_size)],
                additional_targets={"depth": "image"})
        if root is not None:
            self.DEFAULT_DEPTH_ROOT = root

    def __len__(self):
        return len(self.base_dset)

    def preprocess_depth(self, path):
        rgba = np.array(Image.open(path))
        depth = rgba_to_depth(rgba)
        # normalize to [0, 1], then map to [-1, 1]
        depth = (depth - depth.min())/max(1e-8, depth.max() - depth.min())
        depth = 2.0*depth - 1.0
        return depth

    def __getitem__(self, i):
        e = self.base_dset[i]
        e["depth"] = self.preprocess_depth(self.get_depth_path(e))
        # upscale if necessary
        h, w, c = e["image"].shape
        if self.crop_size and min(h, w) < self.crop_size:
            # have to upscale to be able to crop - this just uses bilinear
            out = self.rescaler(image=e["image"], depth=e["depth"])
            e["image"] = out["image"]
            e["depth"] = out["depth"]
        transformed = self.preprocessor(image=e["image"], depth=e["depth"])
        e["image"] = transformed["image"]
        e["depth"] = transformed["depth"]
        return e


class ImageNetTrainWithDepth(BaseWithDepth):
    # default to random_crop=True
    def __init__(self, random_crop=True, sub_indices=None, **kwargs):
        self.sub_indices = sub_indices
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base_dset(self):
        if self.sub_indices is None:
            return ImageNetTrain()
        else:
            return ImageNetTrain({"sub_indices": self.sub_indices})

    def get_depth_path(self, e):
        fid = os.path.splitext(e["relpath"])[0] + ".png"
        fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
        return fid


class ImageNetValidationWithDepth(BaseWithDepth):
    def __init__(self, sub_indices=None, **kwargs):
        self.sub_indices = sub_indices
        super().__init__(**kwargs)

    def get_base_dset(self):
        if self.sub_indices is None:
            return ImageNetValidation()
        else:
            return ImageNetValidation({"sub_indices": self.sub_indices})

    def get_depth_path(self, e):
        fid = os.path.splitext(e["relpath"])[0] + ".png"
        fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
        return fid
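
# Depth maps are expected to mirror the image layout on disk, e.g. for a
# validation image with relpath "n01440764/ILSVRC2012_val_00000293.JPEG" the
# depth file is looked up at
# "data/imagenet_depth/val/n01440764/ILSVRC2012_val_00000293.png"
# (the filename is hypothetical; the mapping itself follows get_depth_path).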


class RINTrainWithDepth(ImageNetTrainWithDepth):
    def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
        sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config, size=size, random_crop=random_crop,
                         sub_indices=sub_indices, crop_size=crop_size)


class RINValidationWithDepth(ImageNetValidationWithDepth):
    def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
        sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config, size=size, random_crop=random_crop,
                         sub_indices=sub_indices, crop_size=crop_size)


class DRINExamples(Dataset):
    def __init__(self):
        self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
        with open("data/drin_examples.txt", "r") as f:
            relpaths = f.read().splitlines()
        self.image_paths = [os.path.join("data/drin_images",
                                         relpath) for relpath in relpaths]
        self.depth_paths = [os.path.join("data/drin_depth",
                                         relpath.replace(".JPEG", ".png")) for relpath in relpaths]

    def __len__(self):
        return len(self.image_paths)

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image/127.5 - 1.0).astype(np.float32)
        return image

    def preprocess_depth(self, path):
        rgba = np.array(Image.open(path))
        depth = rgba_to_depth(rgba)
        depth = (depth - depth.min())/max(1e-8, depth.max() - depth.min())
        depth = 2.0*depth - 1.0
        return depth

    def __getitem__(self, i):
        e = dict()
        e["image"] = self.preprocess_image(self.image_paths[i])
        e["depth"] = self.preprocess_depth(self.depth_paths[i])
        # the image is already 256x256 at this point, so this second pass is a
        # geometric no-op for it and mainly rescales/crops the depth map
        transformed = self.preprocessor(image=e["image"], depth=e["depth"])
        e["image"] = transformed["image"]
        e["depth"] = transformed["depth"]
        return e
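
# Usage sketch (assumes data/drin_examples.txt and the referenced image and
# depth files are present):
#
#   ds = DRINExamples()
#   ex = ds[0]
#   ex["image"].shape, ex["depth"].shape   # (256, 256, 3) and (256, 256)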


def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
    if factor is None or factor == 1:
        return x

    dtype = x.dtype
    assert dtype in [np.float32, np.float64]
    assert x.min() >= -1
    assert x.max() <= 1

    keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
                "bicubic": Image.BICUBIC}[keepmode]

    lr = (x + 1.0)*127.5
    lr = lr.clip(0, 255).astype(np.uint8)
    lr = Image.fromarray(lr)

    h, w, _ = x.shape
    nh = h//factor
    nw = w//factor
    assert nh > 0 and nw > 0, (nh, nw)

    # downscaling always uses bicubic; keepmode only controls the optional
    # upscaling back to the original shape
    lr = lr.resize((nw, nh), Image.BICUBIC)
    if keepshapes:
        lr = lr.resize((w, h), keepmode)
    lr = np.array(lr)/127.5 - 1.0
    lr = lr.astype(dtype)

    return lr
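
# Sketch of the two modes (values are illustrative):
#
#   x = np.zeros((256, 256, 3), np.float32)      # any image in [-1, 1]
#   imscale(x, 4).shape                          # -> (64, 64, 3)
#   imscale(x, 4, keepshapes=True).shape         # -> (256, 256, 3), degraded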


class ImageNetScale(Dataset):
    def __init__(self, size=None, crop_size=None, random_crop=False,
                 up_factor=None, hr_factor=None, keep_mode="bicubic"):
        self.base = self.get_base()

        self.size = size
        self.crop_size = crop_size if crop_size is not None else self.size
        self.random_crop = random_crop
        self.up_factor = up_factor
        self.hr_factor = hr_factor
        self.keep_mode = keep_mode

        transforms = list()

        if self.size is not None and self.size > 0:
            rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            self.rescaler = rescaler
            transforms.append(rescaler)

        if self.crop_size is not None and self.crop_size > 0:
            if len(transforms) == 0:
                self.rescaler = albumentations.SmallestMaxSize(max_size=self.crop_size)
            if not self.random_crop:
                cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
            else:
                cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            transforms.append(cropper)

        if len(transforms) > 0:
            if self.up_factor is not None:
                additional_targets = {"lr": "image"}
            else:
                additional_targets = None
            self.preprocessor = albumentations.Compose(transforms,
                                                       additional_targets=additional_targets)
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = example["image"]

        # adjust resolution
        image = imscale(image, self.hr_factor, keepshapes=False)
        h, w, c = image.shape
        if self.crop_size and min(h, w) < self.crop_size:
            # have to upscale to be able to crop - this just uses bilinear
            image = self.rescaler(image=image)["image"]

        if self.up_factor is None:
            image = self.preprocessor(image=image)["image"]
            example["image"] = image
        else:
            lr = imscale(image, self.up_factor, keepshapes=True,
                         keepmode=self.keep_mode)
            out = self.preprocessor(image=image, lr=lr)
            example["image"] = out["image"]
            example["lr"] = out["lr"]

        return example


class ImageNetScaleTrain(ImageNetScale):
    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        return ImageNetTrain()


class ImageNetScaleValidation(ImageNetScale):
    def get_base(self):
        return ImageNetValidation()
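
# Super-resolution style usage sketch (assumes the ImageNet data has already
# been prepared as above):
#
#   ds = ImageNetScaleValidation(size=256, up_factor=4)
#   ex = ds[0]
#   # ex["image"]: (256, 256, 3) high-res target in [-1, 1]
#   # ex["lr"]:    (256, 256, 3) bicubic-degraded input at the same shape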


class ImageNetEdges(ImageNetScale):
    def __init__(self, up_factor=1, **kwargs):
        # up_factor is fixed to 1: it is only needed so that the preprocessor
        # registers "lr" as an additional target; the low-res image is
        # replaced by an edge map in __getitem__
        super().__init__(up_factor=1, **kwargs)

    def __getitem__(self, i):
        example = self.base[i]
        image = example["image"]
        h, w, c = image.shape
        if self.crop_size and min(h, w) < self.crop_size:
            # have to upscale to be able to crop - this just uses bilinear
            image = self.rescaler(image=image)["image"]

        lr = canny(rgb2gray(image), sigma=2)
        lr = lr.astype(np.float32)
        # replicate the single edge channel to three channels
        lr = lr[:, :, None][:, :, [0, 0, 0]]

        out = self.preprocessor(image=image, lr=lr)
        example["image"] = out["image"]
        example["lr"] = out["lr"]

        return example


class ImageNetEdgesTrain(ImageNetEdges):
    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        return ImageNetTrain()


class ImageNetEdgesValidation(ImageNetEdges):
    def get_base(self):
        return ImageNetValidation()
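
# Edge-conditioning sketch: each example pairs the image with a 3-channel
# binary Canny edge map under the "lr" key, e.g.
#
#   ds = ImageNetEdgesValidation(crop_size=256)
#   ex = ds[0]
#   set(np.unique(ex["lr"])) <= {0.0, 1.0}   # edges are 0/1 float32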