# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import hashlib
import json
from abc import ABC, abstractmethod, abstractstaticmethod
from collections import OrderedDict
from typing import Dict, List

import numpy
import torch

from core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from core.datasets.indexed_dataset import MMapIndexedDataset
from core.datasets.utils import Split
class MegatronDataset(ABC, torch.utils.data.Dataset):
    """The wrapper class from which dataset classes should inherit e.g. GPTDataset

    Args:
        indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
        MegatronDataset

        indexed_indices (numpy.ndarray): The set of the documents indices to expose

        num_samples (int): The number of samples to draw from the indexed dataset

        index_split (Split): The indexed_indices Split

        config (BlendedMegatronDatasetConfig): The container for all config sourced parameters
    """

    def __init__(
        self,
        indexed_dataset: MMapIndexedDataset,
        indexed_indices: numpy.ndarray,
        num_samples: int,
        index_split: Split,
        config: BlendedMegatronDatasetConfig,
    ) -> None:
        assert indexed_indices.size > 0
        assert num_samples > 0
        # The subclass's declared modality must agree with the underlying data
        assert self.is_multimodal() == indexed_dataset.multimodal
        # Exactly one split strategy must be in effect (sequence XOR document)
        assert self.is_split_by_sequence() != self.is_split_by_document()

        self.indexed_dataset = indexed_dataset
        self.indexed_indices = indexed_indices
        self.num_samples = num_samples
        self.index_split = index_split
        self.config = config

        # Ordered description of everything that uniquely identifies this dataset
        # instance; the MD5 of its JSON dump is used as a cache key across runs.
        self.unique_identifiers = OrderedDict()
        self.unique_identifiers["class"] = type(self).__name__
        self.unique_identifiers["path_prefix"] = self.indexed_dataset.path_prefix
        self.unique_identifiers["num_samples"] = self.num_samples
        self.unique_identifiers["index_split"] = self.index_split.name
        for attr in self._key_config_attributes():
            self.unique_identifiers[attr] = getattr(self.config, attr)
        # NOTE(review): this hardcoded entry bypasses the _key_config_attributes
        # mechanism; presumably "add_bos" is optional on some configs (hence the
        # getattr default) — confirm, and consider folding it into
        # _key_config_attributes once it exists on every config.
        self.unique_identifiers["add_bos"] = getattr(self.config, "add_bos", False)

        self.unique_description = json.dumps(self.unique_identifiers, indent=4)
        self.unique_description_hash = hashlib.md5(
            self.unique_description.encode("utf-8")
        ).hexdigest()

        self._finalize()

    def _finalize(self) -> None:
        """Build the dataset and assert any subclass-specific conditions

        No-op hook by default; subclasses override as needed.
        """
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Return the length of the dataset

        Returns:
            int: See abstract implementation
        """
        pass

    @abstractmethod
    def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
        """Return from the dataset

        Args:
            idx (int): The index into the dataset

        Returns:
            Dict[str, numpy.ndarray]: See abstract implementation
        """
        pass

    # Stacked @staticmethod + @abstractmethod replaces the deprecated
    # abc.abstractstaticmethod; without @staticmethod, the instance calls in
    # __init__ would pass `self` to these zero-argument functions and raise
    # TypeError.
    @staticmethod
    @abstractmethod
    def is_multimodal() -> bool:
        """Return True if the inheritor class and its internal MMapIndexedDataset are multimodal

        Returns:
            bool: See abstract implementation
        """
        pass

    @staticmethod
    @abstractmethod
    def is_split_by_sequence() -> bool:
        """Return whether the dataset is split by sequence

        For example, the GPT train/valid/test split is document agnostic

        Returns:
            bool: See abstract implementation
        """
        pass

    @classmethod
    def is_split_by_document(cls) -> bool:
        """Return whether the dataset is split by document

        For example, the BERT train/valid/test split is document aware

        Returns:
            bool: The negation of cls.is_split_by_sequence
        """
        return not cls.is_split_by_sequence()

    @staticmethod
    def _key_config_attributes() -> List[str]:
        """Return all config attributes which contribute to uniquely identifying the dataset.

        These attributes will be used to build a uniquely identifying string and MD5 hash
        which will be used to cache/load the dataset from run to run.

        Returns:
            List[str]: The key config attributes
        """
        return ["split", "random_seed", "sequence_length"]