Spaces:
Runtime error
Runtime error
| # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction | |
| # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han | |
| # International Conference on Computer Vision (ICCV), 2023 | |
import copy
import math
import os
from typing import Any, Optional, Union

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

from efficientvit.apps.data_provider import DataProvider
from efficientvit.apps.data_provider.augment import RandAug
from efficientvit.apps.data_provider.random_resolution import MyRandomResizedCrop, get_interpolate
from efficientvit.apps.utils import partial_update_config
from efficientvit.models.utils import val2list
| __all__ = ["ImageNetDataProvider"] | |
class ImageNetDataProvider(DataProvider):
    """Data provider for an ImageNet-style image-classification folder layout.

    Reads ``<data_dir>/train`` and ``<data_dir>/val`` with
    ``torchvision.datasets.ImageFolder``. Training images get a random resized
    crop, a random horizontal flip and optional extra augmentations from
    ``data_aug``; evaluation images are resized and center-cropped.
    """

    name = "imagenet"
    data_dir = "/dataset/imagenet"
    n_classes = 1000

    # Default resize/crop settings; individual entries can be overridden per
    # instance through the ``rrc_config`` constructor argument.
    _DEFAULT_RRC_CONFIG = {
        "train_interpolate": "random",
        "test_interpolate": "bicubic",
        "test_crop_ratio": 1.0,
    }

    def __init__(
        self,
        data_dir: Optional[str] = None,
        rrc_config: Optional[dict] = None,
        data_aug: Optional[Union[dict, list[dict]]] = None,
        # Arguments below are forwarded unchanged to DataProvider.__init__.
        train_batch_size=128,
        test_batch_size=128,
        valid_size: Optional[Union[int, float]] = None,
        n_worker=8,
        image_size: Union[int, list[int]] = 224,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        train_ratio: Optional[float] = None,
        drop_last: bool = False,
    ):
        """Configure the provider.

        Args:
            data_dir: Dataset root; falls back to the class-level ``data_dir``.
            rrc_config: Partial overrides for ``_DEFAULT_RRC_CONFIG``.
            data_aug: One augmentation spec dict or a list of them. Each dict
                needs a ``"name"`` key (``"randaug"`` or ``"erase"``).
            train_batch_size, test_batch_size, valid_size, n_worker, image_size,
            num_replicas, rank, train_ratio, drop_last: Passed positionally to
                ``DataProvider.__init__``.
        """
        self.data_dir = data_dir or self.data_dir
        # Deep-copy the class-level defaults so applying rrc_config cannot alter
        # the dict shared by every instance.
        self.rrc_config = partial_update_config(
            copy.deepcopy(self._DEFAULT_RRC_CONFIG),
            rrc_config or {},
        )
        self.data_aug = data_aug
        # These attributes are set before super().__init__(), presumably because
        # the base constructor builds the transforms/datasets that read them.
        # NOTE: forwarded positionally; the order must match DataProvider.__init__.
        super().__init__(
            train_batch_size,
            test_batch_size,
            valid_size,
            n_worker,
            image_size,
            num_replicas,
            rank,
            train_ratio,
            drop_last,
        )

    def build_valid_transform(self, image_size: Optional[tuple[int, int]] = None) -> transforms.Compose:
        """Build the evaluation transform: resize, center-crop, ToTensor, Normalize.

        Args:
            image_size: Target size; defaults to ``self.active_image_size``.
                Only its first element is used, so the output crop is square.

        Returns:
            A ``transforms.Compose`` yielding normalized tensors.
        """
        side = (image_size or self.active_image_size)[0]
        # With an int size, Resize matches the image's shorter edge. Enlarging it
        # by 1 / test_crop_ratio means the center crop keeps that fraction of the
        # resized image (ratio 1.0 resizes straight to the crop size).
        resize_size = int(math.ceil(side / self.rrc_config["test_crop_ratio"]))
        return transforms.Compose(
            [
                transforms.Resize(
                    resize_size,
                    interpolation=get_interpolate(self.rrc_config["test_interpolate"]),
                ),
                transforms.CenterCrop(side),
                transforms.ToTensor(),
                transforms.Normalize(**self.mean_std),
            ]
        )

    def build_train_transform(self, image_size: Optional[tuple[int, int]] = None) -> transforms.Compose:
        """Build the training transform.

        Pipeline: random resized crop -> random horizontal flip -> ``randaug``
        ops (before ToTensor) -> ToTensor -> Normalize -> ``erase`` ops (after
        Normalize).

        Args:
            image_size: Currently unused. MyRandomResizedCrop is constructed
                without an explicit output size (presumably it uses the active
                random resolution itself; TODO confirm).

        Raises:
            NotImplementedError: If a ``self.data_aug`` entry has an unknown name.
        """
        train_transforms = [
            MyRandomResizedCrop(interpolation=self.rrc_config["train_interpolate"]),
            transforms.RandomHorizontalFlip(),
        ]
        # Ops appended after ToTensor/Normalize, i.e. applied to the tensor.
        post_aug = []
        if self.data_aug is not None:
            for aug_op in val2list(self.data_aug):
                aug_name = aug_op["name"]
                if aug_name == "randaug":
                    train_transforms.append(RandAug(aug_op, mean=self.mean_std["mean"]))
                elif aug_name == "erase":
                    # Imported here, so timm is only required when "erase" is configured.
                    from timm.data.random_erasing import RandomErasing

                    post_aug.append(RandomErasing(aug_op["p"], device="cpu"))
                else:
                    raise NotImplementedError(f"unsupported data augmentation: {aug_name!r}")

        return transforms.Compose(
            [
                *train_transforms,
                transforms.ToTensor(),
                transforms.Normalize(**self.mean_std),
                *post_aug,
            ]
        )

    def build_datasets(self) -> tuple[Any, Any, Any]:
        """Create the train, validation and test datasets.

        The test set is ``<data_dir>/val`` with the evaluation transform. The
        validation set comes from ``self.sample_val_dataset``, which receives
        the training dataset and the evaluation transform and returns the
        (train, val) pair.

        Returns:
            ``(train_dataset, val_dataset, test_dataset)``.
        """
        train_transform = self.build_train_transform()
        valid_transform = self.build_valid_transform()

        train_dataset = ImageFolder(os.path.join(self.data_dir, "train"), train_transform)
        test_dataset = ImageFolder(os.path.join(self.data_dir, "val"), valid_transform)

        train_dataset, val_dataset = self.sample_val_dataset(train_dataset, valid_transform)
        return train_dataset, val_dataset, test_dataset