File size: 4,392 Bytes
15389e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import hashlib
import json
from abc import ABC, abstractmethod, abstractstaticmethod
from collections import OrderedDict
from typing import Dict, List

import numpy
import torch

from core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from core.datasets.indexed_dataset import MMapIndexedDataset
from core.datasets.utils import Split


class MegatronDataset(ABC, torch.utils.data.Dataset):
    """The wrapper class from which dataset classes should inherit e.g. GPTDataset

    On construction the base class validates its inputs, stores them as attributes, builds an
    ordered set of identifying fields (plus a JSON description and MD5 hash of those fields),
    and finally calls the subclass hook `_finalize`.

    Args:
        indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
        MegatronDataset

        indexed_indices (numpy.ndarray): The set of the documents indices to expose

        num_samples (int): The number of samples to draw from the indexed dataset

        index_split (Split): The indexed_indices Split

        config (BlendedMegatronDatasetConfig): The container for all config sourced parameters

    Raises:
        AssertionError: If indexed_indices is empty, num_samples is not positive, the
        multimodal flag of the class and of indexed_dataset disagree, or the split-by-sequence
        and split-by-document flags are not mutually exclusive
    """

    def __init__(
        self,
        indexed_dataset: MMapIndexedDataset,
        indexed_indices: numpy.ndarray,
        num_samples: int,
        index_split: Split,
        config: BlendedMegatronDatasetConfig,
    ) -> None:
        assert indexed_indices.size > 0, "indexed_indices must not be empty"
        assert num_samples > 0, f"num_samples must be positive, got {num_samples}"
        assert self.is_multimodal() == indexed_dataset.multimodal, (
            f"{type(self).__name__}.is_multimodal() does not match indexed_dataset.multimodal"
        )
        assert self.is_split_by_sequence() != self.is_split_by_document(), (
            f"{type(self).__name__} must be split by sequence or by document, but not both"
        )

        self.indexed_dataset = indexed_dataset
        self.indexed_indices = indexed_indices
        self.num_samples = num_samples
        self.index_split = index_split
        self.config = config

        # Insertion order is preserved so that the serialized description, and therefore the
        # hash derived from it, is reproducible for identical inputs.
        self.unique_identifiers = OrderedDict()
        self.unique_identifiers["class"] = type(self).__name__
        self.unique_identifiers["path_prefix"] = self.indexed_dataset.path_prefix
        self.unique_identifiers["num_samples"] = self.num_samples
        self.unique_identifiers["index_split"] = self.index_split.name
        for attr in self._key_config_attributes():
            self.unique_identifiers[attr] = getattr(self.config, attr)
        # "add_bos" is optional on the config; a missing attribute is recorded as False.
        self.unique_identifiers["add_bos"] = getattr(self.config, "add_bos", False)

        self.unique_description = json.dumps(self.unique_identifiers, indent=4)
        # MD5 of the UTF-8 encoded description; used as an identifier, see
        # _key_config_attributes.
        self.unique_description_hash = hashlib.md5(
            self.unique_description.encode("utf-8")
        ).hexdigest()

        # Subclass hook; runs last so that every attribute set above is available to it.
        self._finalize()

    @abstractmethod
    def _finalize(self) -> None:
        """Build the dataset and assert any subclass-specific conditions

        Called at the end of __init__, after all base class attributes have been set.
        """
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Return the length of the dataset

        Returns:
            int: See the concrete subclass implementation
        """
        pass

    @abstractmethod
    def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
        """Return from the dataset

        Args:
            idx (int): The index into the dataset

        Returns:
            Dict[str, numpy.ndarray]: See the concrete subclass implementation
        """
        pass

    # NOTE: abc.abstractstaticmethod is deprecated since Python 3.3; stacking staticmethod on
    # top of abstractmethod is the documented equivalent.
    @staticmethod
    @abstractmethod
    def is_multimodal() -> bool:
        """Return True if the inheritor class and its internal MMapIndexedDataset are multimodal

        The value is compared against indexed_dataset.multimodal in __init__.

        Returns:
            bool: See the concrete subclass implementation
        """
        pass

    @staticmethod
    @abstractmethod
    def is_split_by_sequence() -> bool:
        """Return whether the dataset is split by sequence

        For example, the GPT train/valid/test split is document agnostic

        Returns:
            bool: See the concrete subclass implementation
        """
        pass

    @classmethod
    def is_split_by_document(cls) -> bool:
        """Return whether the dataset is split by document

        For example, the BERT train/valid/test split is document aware

        Returns:
            bool: The negation of cls.is_split_by_sequence
        """
        return not cls.is_split_by_sequence()

    @staticmethod
    def _key_config_attributes() -> List[str]:
        """Return all config attributes which contribute to uniquely identifying the dataset.

        These attributes will be used to build a uniquely identifying string and MD5 hash which
        will be used to cache/load the dataset from run to run.

        Returns:
            List[str]: The key config attributes
        """
        return ["split", "random_seed", "sequence_length"]