# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Any, Callable, Dict, List, Optional, Union

import torch.utils.data as torchdata

from custom_detectron2.config import configurable
from custom_detectron2.data.common import DatasetFromList, MapDataset
from custom_detectron2.data.dataset_mapper import DatasetMapper
from custom_detectron2.data.samplers import InferenceSampler
from custom_detectron2.data.build import (
    get_detection_dataset_dicts,
    trivial_batch_collator,
)

| """ | |
| This file contains the default logic to build a dataloader for training or testing. | |
| """ | |
| __all__ = [ | |
| "build_detection_test_loader", | |
| ] | |
def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]

    dataset = get_detection_dataset_dicts(
        dataset_name,
        filter_empty=False,
        proposal_files=[
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)]
            for x in dataset_name
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        "sampler": InferenceSampler(len(dataset))
        if not isinstance(dataset, torchdata.IterableDataset)
        else None,
    }

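# Note (illustrative, not part of the upstream code): the dict returned by
# _test_loader_from_config feeds the @configurable decorator below, so a
# config-driven call such as build_detection_test_loader(cfg, "my_test") is
# expanded roughly as:
#
#     kwargs = _test_loader_from_config(cfg, "my_test")
#     data_loader = build_detection_test_loader(**kwargs)
#
# "my_test" is a hypothetical registered dataset name used only for illustration.
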
@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(
    dataset: Union[List[Any], torchdata.Dataset],
    *,
    mapper: Callable[[Dict[str, Any]], Any],
    sampler: Optional[torchdata.Sampler] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    collate_fn: Optional[Callable[[List[Any]], Any]] = None,
) -> torchdata.DataLoader:
| """ | |
| Similar to `build_detection_train_loader`, with default batch size = 1, | |
| and sampler = :class:`InferenceSampler`. This sampler coordinates all workers | |
| to produce the exact set of all samples. | |
| Args: | |
| dataset: a list of dataset dicts, | |
| or a pytorch dataset (either map-style or iterable). They can be obtained | |
| by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. | |
| mapper: a callable which takes a sample (dict) from dataset | |
| and returns the format to be consumed by the model. | |
| When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. | |
| sampler: a sampler that produces | |
| indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, | |
| which splits the dataset across all workers. Sampler must be None | |
| if `dataset` is iterable. | |
| batch_size: the batch size of the data loader to be created. | |
| Default to 1 image per worker since this is the standard when reporting | |
| inference time in papers. | |
| num_workers: number of parallel data loading workers | |
| collate_fn: same as the argument of `torch.utils.data.DataLoader`. | |
| Defaults to do no collation and return a list of data. | |
| Returns: | |
| DataLoader: a torch DataLoader, that loads the given detection | |
| dataset, with test-time transformation and batching. | |
| Examples: | |
| :: | |
| data_loader = build_detection_test_loader( | |
| DatasetRegistry.get("my_test"), | |
| mapper=DatasetMapper(...)) | |
| # or, instantiate with a CfgNode: | |
| data_loader = build_detection_test_loader(cfg, "my_test") | |
| """ | |
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        if sampler is None:
            sampler = InferenceSampler(len(dataset))
    return torchdata.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=False,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
    )
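
# --- Usage sketch (illustrative, not part of the library) ---
# Assumes a dataset named "my_test" is registered in DatasetCatalog and that
# `cfg` is a standard detectron2-style CfgNode; both names are placeholders.
#
#     from custom_detectron2.data.dataset_mapper import DatasetMapper
#
#     # Explicit-argument form: pass dataset dicts and a test-time mapper.
#     loader = build_detection_test_loader(
#         get_detection_dataset_dicts("my_test", filter_empty=False),
#         mapper=DatasetMapper(cfg, is_train=False),
#     )
#
#     # Config-driven form, resolved through @configurable / _test_loader_from_config.
#     loader = build_detection_test_loader(cfg, "my_test")
#
#     for batch in loader:
#         # With the default trivial_batch_collator, each `batch` is a list of
#         # length `batch_size` containing the mapped dataset dicts.
#         ...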