# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on BEiT, timm, DINO, DeiT and MAE-priv code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------
import os
import random

import numpy as np
import torch
import torchvision.transforms.functional as TF
from torchvision import datasets, transforms

from utils import create_transform

from .data_constants import (IMAGE_TASKS, IMAGENET_DEFAULT_MEAN,
                             IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN,
                             IMAGENET_INCEPTION_STD)
from .dataset_folder import ImageFolder, MultiTaskImageFolder

def denormalize(img, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    # Invert channel-wise normalization (x * std + mean) in a single
    # TF.normalize call, using mean' = -mean/std and std' = 1/std.
    return TF.normalize(
        img.clone(),
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std]
    )
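
# A minimal usage sketch (illustrative values, not part of the repo): recover
# a displayable image from a normalized tensor before visualization.
#
# >>> import torch
# >>> x = torch.randn(3, 224, 224)  # stands in for a normalized RGB image
# >>> denormalize(x).shape          # undoes the ImageNet mean/std scaling
# torch.Size([3, 224, 224])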

class DataAugmentationForMAE(object):
    def __init__(self, args):
        imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
        mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
        std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD

        trans = [transforms.RandomResizedCrop(args.input_size)]
        if args.hflip > 0.0:
            trans.append(transforms.RandomHorizontalFlip(args.hflip))
        trans.extend([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])
        self.transform = transforms.Compose(trans)

    def __call__(self, image):
        return self.transform(image)

    def __repr__(self):
        repr_str = "(DataAugmentationForMAE,\n"
        repr_str += "  transform = %s,\n" % str(self.transform)
        repr_str += ")"
        return repr_str
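
# A minimal driving sketch (the Namespace is a hypothetical stand-in for the
# real argparse config; attribute names match those read in __init__ above):
#
# >>> from argparse import Namespace
# >>> from PIL import Image
# >>> args = Namespace(imagenet_default_mean_and_std=True, input_size=224, hflip=0.5)
# >>> aug = DataAugmentationForMAE(args)
# >>> aug(Image.new('RGB', (256, 256))).shape
# torch.Size([3, 224, 224])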

class DataAugmentationForMultiMAE(object):
    def __init__(self, args):
        imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
        self.rgb_mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
        self.rgb_std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
        self.input_size = args.input_size
        self.hflip = args.hflip

    def __call__(self, task_dict):
        flip = random.random() < self.hflip  # Whether to flip all modalities of this sample
        ijhw = None  # Crop parameters shared across all tasks

        # Randomly crop and flip, using the same parameters for every task
        # so that spatially aligned modalities stay aligned.
        for task in task_dict:
            if task not in IMAGE_TASKS:
                continue
            if ijhw is None:
                # Official MAE code uses (0.2, 1.0) for scale and (0.75, 1.3333) for ratio
                ijhw = transforms.RandomResizedCrop.get_params(
                    task_dict[task], scale=(0.2, 1.0), ratio=(0.75, 1.3333)
                )
            i, j, h, w = ijhw
            task_dict[task] = TF.crop(task_dict[task], i, j, h, w)
            task_dict[task] = task_dict[task].resize((self.input_size, self.input_size))
            if flip:
                task_dict[task] = TF.hflip(task_dict[task])

        # Convert to tensors, with task-specific value handling
        for task in task_dict:
            if task in ['depth']:
                # Depth is stored as 16-bit values; dividing by 2**16 maps it into [0, 1)
                img = torch.Tensor(np.array(task_dict[task]) / 2 ** 16)
                img = img.unsqueeze(0)  # 1 x H x W
            elif task in ['rgb']:
                img = TF.to_tensor(task_dict[task])
                img = TF.normalize(img, mean=self.rgb_mean, std=self.rgb_std)
            elif task in ['semseg', 'semseg_coco']:
                # TODO: add this to a config instead
                # Rescale to 0.25x size (stride 4)
                scale_factor = 0.25
                img = task_dict[task].resize((int(self.input_size * scale_factor), int(self.input_size * scale_factor)))
                # pil_to_tensor keeps uint8 class indices, whereas to_tensor would rescale them to floats in [0, 1]
                img = TF.pil_to_tensor(img).to(torch.long).squeeze(0)
            task_dict[task] = img

        return task_dict

    def __repr__(self):
        repr_str = "(DataAugmentationForMultiMAE,\n"
        repr_str += ")"
        return repr_str
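
# A minimal sketch of the expected input and output (PIL images keyed by task
# name; the image modes and sizes are illustrative, and it assumes 'rgb',
# 'depth', and 'semseg' appear in IMAGE_TASKS):
#
# >>> from argparse import Namespace
# >>> from PIL import Image
# >>> aug = DataAugmentationForMultiMAE(Namespace(
# ...     imagenet_default_mean_and_std=True, input_size=224, hflip=0.5))
# >>> out = aug({'rgb': Image.new('RGB', (256, 256)),
# ...            'depth': Image.new('I', (256, 256)),
# ...            'semseg': Image.new('L', (256, 256))})
# >>> out['rgb'].shape, out['depth'].shape, out['semseg'].shape
# (torch.Size([3, 224, 224]), torch.Size([1, 224, 224]), torch.Size([56, 56]))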

def build_pretraining_dataset(args):
    transform = DataAugmentationForMAE(args)
    print("Data Aug = %s" % str(transform))
    return ImageFolder(args.data_path, transform=transform)

def build_multimae_pretraining_dataset(args):
    transform = DataAugmentationForMultiMAE(args)
    return MultiTaskImageFolder(args.data_path, args.all_domains, transform=transform)
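
# A sketch of the on-disk layout MultiTaskImageFolder is typically pointed at
# (the directory names are an assumption based on the task keys used above):
#
#   args.data_path/
#       rgb/...      # RGB images
#       depth/...    # 16-bit depth maps
#       semseg/...   # semantic segmentation maps
#
# with, e.g., args.all_domains = ['rgb', 'depth', 'semseg'].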

def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    print("Transform = ")
    if isinstance(transform, tuple):
        for trans in transform:
            print(" - - - - - - - - - - ")
            for t in trans.transforms:
                print(t)
    else:
        for t in transform.transforms:
            print(t)
    print("---------------------------")

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        # root = os.path.join(args.data_path, 'train' if is_train else 'val')
        root = args.data_path if is_train else args.eval_data_path
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == "image_folder":
        root = args.data_path if is_train else args.eval_data_path
        dataset = ImageFolder(root, transform=transform)
        nb_classes = args.nb_classes
        assert len(dataset.class_to_idx) == nb_classes
    else:
        raise NotImplementedError()
    assert nb_classes == args.nb_classes
    print("Number of classes = %d" % args.nb_classes)

    return dataset, nb_classes
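
# A sketch of wiring build_dataset into a validation loader (Namespace fields
# mirror the attributes read above; the paths are placeholders):
#
# >>> from argparse import Namespace
# >>> args = Namespace(data_set='IMNET', data_path='/path/to/train',
# ...                  eval_data_path='/path/to/val', nb_classes=1000,
# ...                  input_size=224, imagenet_default_mean_and_std=True,
# ...                  crop_pct=None)
# >>> dataset, nb_classes = build_dataset(is_train=False, args=args)
# >>> loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)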

def build_transform(is_train, args):
    resize_im = args.input_size > 32
    imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
    mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
    std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        if args.crop_pct is None:
            if args.input_size < 384:
                args.crop_pct = 224 / 256
            else:
                args.crop_pct = 1.0
        size = int(args.input_size / args.crop_pct)
        t.append(
            # interpolation=3 is PIL bicubic; maintains the same ratio w.r.t. 224 images
            transforms.Resize(size, interpolation=3),
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
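
# A minimal usage sketch for the evaluation path (Namespace fields mirror the
# attributes read above; values are illustrative, not repo defaults):
#
# >>> from argparse import Namespace
# >>> args = Namespace(input_size=224, imagenet_default_mean_and_std=True, crop_pct=None)
# >>> eval_tf = build_transform(is_train=False, args=args)
# >>> [type(t).__name__ for t in eval_tf.transforms]
# ['Resize', 'CenterCrop', 'ToTensor', 'Normalize']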