Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/data/custom_dataset_dataloader.py | |
| import operator | |
| import torch | |
| import torch.utils.data | |
| from detectron2.utils.comm import get_world_size | |
| from detectron2.config import configurable | |
| from torch.utils.data.sampler import BatchSampler, Sampler | |
| from detectron2.data.common import DatasetFromList, MapDataset | |
| from detectron2.data.dataset_mapper import DatasetMapper | |
| from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader | |
| from detectron2.data.samplers import TrainingSampler | |
| from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram | |
| from detectron2.data.build import filter_images_with_only_crowd_annotations | |
| from detectron2.data.build import filter_images_with_few_keypoints | |
| from detectron2.data.build import check_metadata_consistency | |
| from detectron2.data.catalog import MetadataCatalog, DatasetCatalog | |
| from detectron2.utils import comm | |
| import itertools | |
| from typing import Optional | |
def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    """Translate a config node into keyword arguments for the train loader.

    Args:
        cfg: config node. Reads ``DATALOADER.SAMPLER_TRAIN``, ``DATASETS.TRAIN``,
            ``DATALOADER.FILTER_EMPTY_ANNOTATIONS``, the keypoint / proposal
            options, ``SOLVER.IMS_PER_BATCH``, ``DATALOADER.NUM_WORKERS``,
            ``DATALOADER.DATASET_BS`` and ``DATALOADER.DATASET_RATIO``.
        mapper: callable applied to every dataset dict. Defaults to
            ``DatasetMapper(cfg, True)``.
        dataset: accepted for signature compatibility only; the dataset dicts
            are always loaded from ``cfg.DATASETS.TRAIN``.
        sampler: pre-built sampler. When given, ``SAMPLER_TRAIN`` is only used
            to choose how the dataset dicts are loaded.

    Returns:
        dict with the keys ``dataset``, ``sampler``, ``mapper``,
        ``total_batch_size``, ``num_workers``, ``dataset_bs`` and
        ``num_datasets``.

    Raises:
        ValueError: if no sampler is passed and ``SAMPLER_TRAIN`` is neither
            ``"TrainingSampler"`` nor ``"MultiDatasetSampler"``.
    """
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN

    # Multi-dataset samplers need every image dict tagged with its
    # 'dataset_source'; only the *_with_source loader adds that tag.
    if 'MultiDataset' in sampler_name:
        load_dicts = get_detection_dataset_dicts_with_source
    else:
        load_dicts = get_detection_dataset_dicts
    dataset_dicts = load_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
    )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        if sampler_name == "TrainingSampler":
            # Size the sampler from the dicts that were just loaded. The
            # `dataset` argument defaults to None, so len(dataset) would fail.
            sampler = TrainingSampler(len(dataset_dicts))
        elif sampler_name == "MultiDatasetSampler":
            sampler = MultiDatasetSampler(
                dataset_dicts,
                dataset_ratio=cfg.DATALOADER.DATASET_RATIO,
            )
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset_dicts,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        'dataset_bs': cfg.DATALOADER.DATASET_BS,
        'num_datasets': len(cfg.DATASETS.TRAIN)
    }
# NOTE(review): `configurable` is imported at the top of the file but no
# decorator is visible on this function in this view -- confirm whether a
# `@configurable(from_config=...)` line was lost.
def build_custom_train_loader(
    dataset, *, mapper, sampler,
    total_batch_size=16,
    num_workers=0,
    num_datasets=1,
    dataset_bs=1
):
    """Build the training loader from a dataset, a mapper and a sampler.

    Args:
        dataset: a list of dataset dicts, or an already constructed dataset.
        mapper: callable applied to each element through ``MapDataset``;
            ``None`` skips the mapping step.
        sampler: index sampler; ``None`` selects a ``TrainingSampler`` sized
            to the (wrapped) dataset.
        total_batch_size: forwarded to ``build_dataset_batch_data_loader``.
        num_workers: number of data loading workers.
        num_datasets: number of source datasets mixed in ``dataset``.
        dataset_bs: size of each yielded batch.

    Returns:
        Whatever ``build_dataset_batch_data_loader`` returns.
    """
    # Plain lists are wrapped in DatasetFromList; anything else is used as is.
    wrapped = DatasetFromList(dataset, copy=False) if isinstance(dataset, list) else dataset
    if mapper is not None:
        wrapped = MapDataset(wrapped, mapper)

    if sampler is None:
        sampler = TrainingSampler(len(wrapped))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)

    return build_dataset_batch_data_loader(
        dataset_bs=dataset_bs,
        dataset=wrapped,
        sampler=sampler,
        total_batch_size=total_batch_size,
        num_datasets=num_datasets,
        num_workers=num_workers,
    )
def build_dataset_batch_data_loader(
    dataset_bs, dataset, sampler, total_batch_size, num_datasets, num_workers=0
):
    """Wrap ``dataset`` in a per-element DataLoader and group it into batches.

    Args:
        dataset_bs: number of elements in each yielded batch.
        dataset: dataset handed to ``torch.utils.data.DataLoader``.
        sampler: index sampler handed to the DataLoader.
        total_batch_size: only validated here; it must be positive and
            divisible by the world size.
        num_datasets: more than one selects ``MultiDatasets`` grouping,
            otherwise ``SingleDataset`` is used.
        num_workers: number of data loading workers.

    Returns:
        A ``MultiDatasets`` or ``SingleDataset`` wrapping the DataLoader.
    """
    num_gpus = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % num_gpus == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
        total_batch_size, num_gpus
    )

    loader_options = dict(
        sampler=sampler,
        num_workers=num_workers,
        batch_sampler=None,
        # don't batch, but yield individual elements
        collate_fn=operator.itemgetter(0),
        worker_init_fn=worker_init_reset_seed,
    )
    element_loader = torch.utils.data.DataLoader(dataset, **loader_options)

    if num_datasets > 1:
        return MultiDatasets(element_loader, dataset_bs, num_datasets)
    return SingleDataset(element_loader, dataset_bs)
def get_detection_dataset_dicts_with_source(
    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    """Load several datasets and tag every image dict with its source index.

    Each dict of the i-th dataset in ``dataset_names`` gets
    ``d['dataset_source'] = i`` (the dicts are modified in place). The
    per-dataset lists are then concatenated in order and optionally filtered.

    Args:
        dataset_names: a dataset name or a non-empty sequence of names
            registered in ``DatasetCatalog``.
        filter_empty: when the data has annotations, drop images through
            ``filter_images_with_only_crowd_annotations``.
        min_keypoints: when positive and the data has annotations, drop images
            through ``filter_images_with_few_keypoints``.
        proposal_files: must be ``None``; proposals are not supported here.

    Returns:
        A single list holding the dicts of all datasets.
    """
    if isinstance(dataset_names, str):
        # Without this a bare string would be iterated character by character.
        dataset_names = [dataset_names]
    assert len(dataset_names)

    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    # Validate everything first so no dict is mutated if a dataset is empty.
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    for source_id, (dataset_name, dicts) in enumerate(zip(dataset_names, dataset_dicts)):
        for d in dicts:
            d['dataset_source'] = source_id

        if "annotations" in dicts[0]:
            try:
                class_names = MetadataCatalog.get(dataset_name).thing_classes
                # detectron2 documents the second argument as a list of
                # dataset names. A bare string is iterated per character,
                # raises AttributeError, and the histogram is silently skipped.
                check_metadata_consistency("thing_classes", [dataset_name])
                print_instances_class_histogram(dicts, class_names)
            except AttributeError:  # class names are not available for this dataset
                pass

    assert proposal_files is None

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
    return dataset_dicts
class MultiDatasetSampler(Sampler):
    """Infinite index sampler that mixes several datasets by a given ratio.

    Every element receives the weight
    ``max(sizes) / size_of_its_dataset * ratio / sum(ratios)``, so the total
    weight of each dataset is proportional to its ratio regardless of its
    size. With a single dataset the sampler yields random permutations
    instead.

    The index stream is generated from a seeded generator, and each rank
    takes every ``world_size``-th index starting at its own rank.

    NOTE: the weights are laid out dataset after dataset, so they only line up
    with the element indices if ``dataset_dicts`` is grouped by
    ``dataset_source`` in increasing order.
    """

    def __init__(
        self,
        dataset_dicts,
        dataset_ratio,
        seed: Optional[int] = None,
    ):
        """
        Args:
            dataset_dicts: dicts carrying an integer ``'dataset_source'`` key.
            dataset_ratio: one sampling ratio per source dataset.
            seed: seed of the index stream; when ``None`` it is taken from
                ``comm.shared_random_seed()``.

        Raises:
            ValueError: if a ``'dataset_source'`` value has no matching entry
                in ``dataset_ratio``.
        """
        num_datasets = len(dataset_ratio)
        sizes = [0] * num_datasets
        for d in dataset_dicts:
            source = d['dataset_source']
            # Explicit range check: comparing len(sizes) with
            # len(dataset_ratio) afterwards can never fail.
            if not 0 <= source < num_datasets:
                raise ValueError(
                    'dataset_source {} has no entry in dataset ratio of length {}'.format(
                        source, num_datasets
                    )
                )
            sizes[source] += 1
        print('dataset sizes', sizes)
        self.sizes = sizes

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        self.dataset_ids = torch.tensor(
            [d['dataset_source'] for d in dataset_dicts], dtype=torch.long)
        self.dataset_ratio = dataset_ratio

        largest = max(sizes)
        ratio_total = sum(dataset_ratio)
        self.weights = torch.cat([
            torch.ones(s) * largest / s * r / ratio_total
            for r, s in zip(dataset_ratio, sizes)
        ])
        self.sample_epoch_size = len(self.weights)

    def __iter__(self):
        # Ranks interleave: rank k takes elements k, k + world_size, ...
        yield from itertools.islice(
            self._infinite_indices(), self._rank, None, self._world_size)

    def _infinite_indices(self):
        """Yield an endless stream of plain ``int`` indices."""
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if len(self.dataset_ratio) > 1:
                # multiple datasets: weighted draw with replacement
                ids = torch.multinomial(
                    self.weights, self.sample_epoch_size, generator=g,
                    replacement=True)
                # Yield ints, matching the single-dataset branch below.
                yield from ids.tolist()
            else:
                # single dataset
                yield from torch.randperm(self.sizes[0], generator=g).tolist()
class SingleDataset(torch.utils.data.IterableDataset):
    """Group elements of an iterable into batches of similar aspect ratio.

    Elements are dicts with ``"width"`` and ``"height"`` keys. Elements with
    ``width > height`` go to one bucket and all others to a second bucket. A
    bucket is emitted as a list once it holds ``batch_sizes`` elements.
    """

    def __init__(self, dataset, batch_sizes):
        """
        Args:
            dataset: iterable of dicts with ``"width"`` and ``"height"``.
            batch_sizes: number of elements per yielded batch.
        """
        self.dataset = dataset
        self.batch_sizes = batch_sizes
        self._buckets = [[] for _ in range(2)]

    def __iter__(self):
        for d in self.dataset:
            w, h = d["width"], d["height"]
            bucket = self._buckets[0 if w > h else 1]
            bucket.append(d)
            if len(bucket) == self.batch_sizes:
                batch = bucket[:]
                # Clear the bucket before yielding: code after a yield is not
                # guaranteed to run. A bucket left full would grow past
                # batch_sizes on the next iteration and never be emitted again.
                del bucket[:]
                yield batch
class MultiDatasets(torch.utils.data.IterableDataset):
    """Batch elements per source dataset and aspect ratio, round-robin.

    Every source dataset owns two buckets, one for ``width > height`` and one
    for everything else. Batches are emitted in rotation over the datasets: a
    full bucket is only yielded when its dataset is the one selected by
    ``iter_idx % num_datasets``. While a bucket is full and waiting for its
    turn, further elements that map to it are dropped.
    """

    def __init__(self, dataset, batch_sizes, num_datasets):
        """
        Args:
            dataset: iterable of dicts with ``"width"``, ``"height"`` and an
                integer ``'dataset_source'`` in ``[0, num_datasets)``.
            batch_sizes: number of elements per yielded batch.
            num_datasets: number of source datasets.
        """
        self.dataset = dataset
        self.batch_sizes = batch_sizes
        self._buckets = [[] for _ in range(2 * num_datasets)]
        self.iter_idx = 0
        self.num_datasets = num_datasets

    def __iter__(self):
        for d in self.dataset:
            w, h = d["width"], d["height"]
            aspect_ratio_bucket_id = 0 if w > h else 1
            source = d['dataset_source']
            bucket = self._buckets[source * 2 + aspect_ratio_bucket_id]
            if len(bucket) < self.batch_sizes:
                bucket.append(d)
            selected_dataset = self.iter_idx % self.num_datasets
            if len(bucket) == self.batch_sizes and selected_dataset == source:
                self.iter_idx += 1
                batch = bucket[:]
                # Clear the bucket before yielding: code after a yield is not
                # guaranteed to run, and a bucket left full would be emitted
                # again later as a stale duplicate batch.
                del bucket[:]
                yield batch