Spaces:
Build error
Build error
| import numbers | |
| import os | |
| import queue as Queue | |
| import threading | |
| from functools import partial | |
| from typing import Iterable | |
| import mxnet as mx | |
| import numpy as np | |
| import torch | |
| from torch import distributed | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| from torchvision.datasets import ImageFolder | |
| from utils.utils_distributed_sampler import DistributedSampler | |
| from utils.utils_distributed_sampler import get_dist_info | |
| from utils.utils_distributed_sampler import worker_init_fn | |
def get_dataloader(
    root_dir,
    local_rank,
    batch_size,
    dali=False,
    seed=2048,
    num_workers=2,
) -> Iterable:
    """Create the training iterator for ``root_dir``.

    The data source is selected from ``root_dir``:

    * the literal ``"synthetic"`` -> ``SyntheticDataset`` (DALI is never used);
    * a directory containing ``train.rec`` and ``train.idx`` -> MXNet RecordIO,
      read by ``MXFaceDataset``, or by a DALI pipeline when ``dali`` is true;
    * anything else -> ``torchvision.datasets.ImageFolder`` with a random
      horizontal flip and normalisation with mean 0.5 / std 0.5.

    Args:
        root_dir: Dataset directory, or ``"synthetic"``.
        local_rank: CUDA device index used by this process.
        batch_size: Per-process batch size.
        dali: Read the RecordIO files with a DALI pipeline instead of a
            PyTorch ``DataLoader``. Ignored for ``"synthetic"``.
        seed: Passed to the distributed sampler and to ``worker_init_fn``;
            ``None`` disables the per-worker init function.
        num_workers: Number of ``DataLoader`` workers.

    Returns:
        The iterator from ``dali_data_iter`` when DALI is used, otherwise a
        ``DataLoaderX`` driven by a ``DistributedSampler``.

    Raises:
        ValueError: ``dali`` was requested but ``root_dir`` does not contain
            the ``train.rec`` / ``train.idx`` files the DALI reader needs.
    """
    rec = os.path.join(root_dir, "train.rec")
    idx = os.path.join(root_dir, "train.idx")
    is_synthetic = root_dir == "synthetic"
    has_record_io = (not is_synthetic) and os.path.exists(rec) and os.path.exists(idx)

    # DALI: decide before building any dataset so the RecordIO file is not
    # opened by MXFaceDataset only to be thrown away. The DALI reader only
    # understands RecordIO, so fail clearly instead of pointing it at
    # nonexistent files.
    if dali and not is_synthetic:
        if not has_record_io:
            raise ValueError(
                f"DALI loading requires {rec} and {idx}; "
                f"no MXNet RecordIO dataset found in {root_dir!r}"
            )
        return dali_data_iter(
            batch_size=batch_size,
            rec_file=rec,
            idx_file=idx,
            num_threads=2,
            local_rank=local_rank,
        )

    if is_synthetic:
        train_set = SyntheticDataset()
    elif has_record_io:
        train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank)
    else:
        transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        train_set = ImageFolder(root_dir, transform)

    rank, world_size = get_dist_info()
    train_sampler = DistributedSampler(
        train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=seed
    )
    init_fn = (
        None
        if seed is None
        else partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
    )
    return DataLoaderX(
        local_rank=local_rank,
        dataset=train_set,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=True,  # required for the non_blocking host-to-device copies
        drop_last=True,
        worker_init_fn=init_fn,
    )
class BackgroundGenerator(threading.Thread):
    """Prefetch items from an iterable on a background daemon thread.

    The worker thread pulls items from ``generator`` into a bounded queue
    while the consumer iterates over this object, so producing the next item
    overlaps with consuming the current one.

    Args:
        generator: Any iterable; it is consumed once, on the worker thread.
        local_rank: CUDA device made current on the worker thread (via
            ``torch.cuda.set_device``) before iteration starts. ``None`` skips
            that call, e.g. for CPU-only pipelines.
        max_prefetch: ``maxsize`` of the internal ``queue.Queue``.

    An exception raised by ``generator`` is re-raised in the consumer on the
    corresponding ``next()`` call rather than leaving the consumer blocked
    forever on an empty queue. Once the stream has ended (normally or with an
    error), further ``next()`` calls raise ``StopIteration``.
    """

    # Private end-of-stream marker; a fresh object() can never collide with a
    # real item, so iterables that yield None are passed through intact.
    _END = object()

    def __init__(self, generator, local_rank, max_prefetch=6):
        super().__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        # Set once the end marker has been consumed so later next() calls
        # raise StopIteration instead of blocking on the drained queue.
        self._exhausted = False
        self.start()

    def run(self):
        # Every queue entry is an (item, error) pair; the end marker carries
        # the worker's exception (or None on normal exhaustion).
        try:
            if self.local_rank is not None:
                torch.cuda.set_device(self.local_rank)
            for item in self.generator:
                self.queue.put((item, None))
        except BaseException as exc:  # forwarded to the consumer, not swallowed
            self.queue.put((self._END, exc))
        else:
            self.queue.put((self._END, None))

    def next(self):
        if self._exhausted:
            raise StopIteration
        item, error = self.queue.get()
        if item is self._END:
            self._exhausted = True
            if error is not None:
                raise error
            raise StopIteration
        return item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
class DataLoaderX(DataLoader):
    """``DataLoader`` that prefetches batches and copies them to the GPU early.

    Host-side batches are produced on a ``BackgroundGenerator`` thread. Each
    batch is copied to device ``local_rank`` on a dedicated CUDA stream one
    step ahead of the consumer, so the host-to-device transfer can overlap
    with compute on the current stream.

    Args:
        local_rank: CUDA device index that batches are moved to.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.
    """

    def __init__(self, local_rank, **kwargs):
        super().__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super().__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        """Fetch the next host batch and start its async copy on ``self.stream``.

        Sets ``self.batch`` to ``None`` once the underlying iterator is exhausted.
        Each element of the batch must support ``.to(device=..., non_blocking=True)``.
        """
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k, item in enumerate(self.batch):
                self.batch[k] = item.to(device=self.local_rank, non_blocking=True)

    def __next__(self):
        current = torch.cuda.current_stream()
        # Order the consumer stream after the copies issued in preload().
        current.wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        # These tensors were allocated on self.stream but are used on the
        # current stream. Record that use so the caching allocator does not
        # hand their memory back to self.stream (e.g. for the next preload
        # copy) while work queued on the current stream may still read it.
        for tensor in batch:
            tensor.record_stream(current)
        self.preload()
        return batch
class MXFaceDataset(Dataset):
    """Dataset over an MXNet indexed RecordIO pair (``train.rec`` / ``train.idx``).

    Each item is ``(sample, label)``. ``sample`` is the record payload decoded
    with ``mx.image.imdecode`` and, when ``self.transform`` is set, converted to
    a PIL image, randomly flipped horizontally, turned into a tensor and
    normalised with mean 0.5 / std 0.5. ``label`` is a ``torch.long`` tensor
    built from the record label (its first element if it is not a scalar).
    """

    def __init__(self, root_dir, local_rank):
        super().__init__()
        augment = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(augment)
        self.root_dir = root_dir
        self.local_rank = local_rank
        index_path = os.path.join(root_dir, "train.idx")
        record_path = os.path.join(root_dir, "train.rec")
        self.imgrec = mx.recordio.MXIndexedRecordIO(index_path, record_path, "r")
        meta, _ = mx.recordio.unpack(self.imgrec.read_idx(0))
        if meta.flag <= 0:
            # No header record: every key listed in the index is a sample.
            self.imgidx = np.array(list(self.imgrec.keys))
            return
        # Record 0 is a header; its label[0] bounds the sample ids, which are
        # 1 .. label[0] - 1.
        first, second = int(meta.label[0]), int(meta.label[1])
        self.header0 = (first, second)
        self.imgidx = np.array(range(1, first))

    def __getitem__(self, index):
        record = self.imgrec.read_idx(self.imgidx[index])
        meta, payload = mx.recordio.unpack(record)
        raw_label = meta.label
        if not isinstance(raw_label, numbers.Number):
            raw_label = raw_label[0]
        label = torch.tensor(raw_label, dtype=torch.long)
        sample = mx.image.imdecode(payload).asnumpy()
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)
class SyntheticDataset(Dataset):
    """Fixed-length dataset returning the same random image for every index.

    A single ``(3, 112, 112)`` float32 image is drawn once at construction
    from ``np.random`` pixel values in ``[0, 255)`` and normalised with mean 0.5
    and std 0.5 after scaling by 1/255. Every index yields that same tensor
    object together with the integer label ``1``.
    """

    def __init__(self):
        super().__init__()
        pixels = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
        # HWC -> CHW, then to float before scaling.
        chw = torch.from_numpy(pixels.transpose(2, 0, 1)).float()
        self.img = ((chw / 255) - 0.5) / 0.5
        self.label = 1

    def __getitem__(self, index):
        return self.img, self.label

    def __len__(self):
        return 1000000
def dali_data_iter(
    batch_size: int,
    rec_file: str,
    idx_file: str,
    num_threads: int,
    initial_fill=32768,
    random_shuffle=True,
    prefetch_queue_depth=1,
    local_rank=0,
    name="reader",
    mean=(127.5, 127.5, 127.5),
    std=(127.5, 127.5, 127.5),
):
    """
    Build a DALI GPU pipeline over an MXNet RecordIO file and wrap its iterator.

    Parameters:
    ----------
    batch_size: int
        Per-device batch size.
    rec_file, idx_file: str
        Paths to the ``.rec`` data file and its ``.idx`` index.
    num_threads: int
        CPU threads used by the DALI pipeline.
    initial_fill: int
        Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored.
    random_shuffle: bool
        Whether the reader shuffles records.
    prefetch_queue_depth: int
        Passed to ``Pipeline`` as its prefetch queue depth.
    local_rank: int
        CUDA device id the pipeline runs on.
    name: str
        Reader name; also passed to the iterator as ``reader_name``.
    mean, std: sequence of float
        Per-channel constants for ``crop_mirror_normalize``, in decoded pixel
        units; the defaults map 0..255 to roughly [-1, 1].

    Returns:
    -------
    DALIWarper
        Yields ``(images, labels)`` CUDA tensors and forwards ``reset()`` to
        the underlying DALI iterator.

    The reader is sharded by ``torch.distributed`` rank / world size; when the
    default process group is not initialised a single shard is used.
    """
    if distributed.is_available() and distributed.is_initialized():
        rank: int = distributed.get_rank()
        world_size: int = distributed.get_world_size()
    else:
        # e.g. single-process runs: get_rank() would raise here.
        rank, world_size = 0, 1

    import nvidia.dali.fn as fn
    import nvidia.dali.types as types
    from nvidia.dali.pipeline import Pipeline
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator

    pipe = Pipeline(
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=local_rank,
        prefetch_queue_depth=prefetch_queue_depth,
    )
    with pipe:
        # Every operator, including the random flip condition, is created
        # inside the pipeline scope so it belongs to this pipeline's graph.
        condition_flip = fn.random.coin_flip(probability=0.5)
        jpegs, labels = fn.readers.mxnet(
            path=rec_file,
            index_path=idx_file,
            initial_fill=initial_fill,
            num_shards=world_size,
            shard_id=rank,
            random_shuffle=random_shuffle,
            pad_last_batch=False,
            name=name,
        )
        images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
        images = fn.crop_mirror_normalize(
            images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip
        )
        pipe.set_outputs(images, labels)
    pipe.build()
    return DALIWarper(
        DALIClassificationIterator(
            pipelines=[pipe],
            reader_name=name,
        )
    )
class DALIWarper(object):
    """Adapt a DALI classification iterator to yield ``(images, labels)`` tuples.

    Args:
        dali_iter: Iterator whose ``__next__`` returns a sequence whose first
            element is a mapping with ``"data"`` and ``"label"`` tensors, and
            which provides ``reset()``.

    Both returned tensors are moved with ``.cuda()``; labels are cast to
    ``torch.long``.
    """

    def __init__(self, dali_iter):
        self.iter = dali_iter

    def __next__(self):
        data_dict = self.iter.__next__()[0]
        tensor_data = data_dict["data"].cuda()
        tensor_label: torch.Tensor = data_dict["label"].cuda().long()
        # Drop only the trailing per-sample label axis (e.g. (B, 1) -> (B,)).
        # A bare squeeze() would also remove the batch axis when B == 1 and
        # return a 0-d label tensor.
        if tensor_label.dim() > 1:
            tensor_label = tensor_label.squeeze(-1)
        return tensor_data, tensor_label

    def __iter__(self):
        return self

    def reset(self):
        self.iter.reset()