# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import logging
import os
import time
from dataclasses import dataclass
from typing import Dict, Tuple, Union

import numpy
import torch

# from megatron import get_tokenizer
from core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from core.datasets.indexed_dataset import MMapIndexedDataset
from core.datasets.megatron_dataset import MegatronDataset
from core.datasets.utils import Split, log_single_rank

logger = logging.getLogger(__name__)

@dataclass
class GPTDatasetConfig(BlendedMegatronDatasetConfig):
    """Configuration object for Megatron Core GPT datasets

    Attributes:
        return_document_ids (bool): Whether to return the document ids when querying the dataset

        reset_position_ids (bool): Option to reset the position IDs at end-of-document (EOD) tokens

        reset_attention_mask (bool): Option to reset the attention mask at end-of-document (EOD)
        tokens

        eod_mask_loss (bool): Option to mask the loss at EOD tokens

        eod_id (int): The token id that marks the end of a document

        add_bos (bool): Option to prepend a BOS token to each sample

        enable_shuffle (bool): Option to shuffle the document and shuffle indices; defaults to
        False for reproducibility
    """

    return_document_ids: bool = False

    reset_position_ids: bool = False

    reset_attention_mask: bool = False

    eod_mask_loss: bool = False

    eod_id: int = 0

    add_bos: bool = False

    enable_shuffle: bool = False
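
# Illustrative sketch (not part of this file): the base BlendedMegatronDatasetConfig supplies at
# least `random_seed`, `sequence_length`, and `path_to_cache`, which are read below; any other
# required base fields are omitted here. A hypothetical config for 2048-token GPT samples might
# look roughly like:
#
#   config = GPTDatasetConfig(
#       random_seed=1234,
#       sequence_length=2048,
#       path_to_cache="/tmp/gpt_indices",  # hypothetical cache location
#       eod_id=0,
#       eod_mask_loss=True,
#   )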

class GPTDataset(MegatronDataset):
    """The base GPT dataset

    Args:
        indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
        MegatronDataset

        indexed_indices (numpy.ndarray): The set of document indices to expose

        num_samples (int): The number of samples to draw from the indexed dataset

        index_split (Split): The indexed_indices Split

        config (GPTDatasetConfig): The GPT-specific container for all config sourced parameters
    """

    def __init__(
        self,
        indexed_dataset: MMapIndexedDataset,
        indexed_indices: numpy.ndarray,
        num_samples: int,
        index_split: Split,
        config: GPTDatasetConfig,
    ) -> None:
        super().__init__(indexed_dataset, indexed_indices, num_samples, index_split, config)
        # The BOS/EOD token ids used when config.add_bos is enabled would normally come from
        # the tokenizer, e.g.:
        # tokenizer = get_tokenizer()
        # self.bos_id = tokenizer.bos
        # self.eod_id = tokenizer.eod

    def _finalize(self) -> None:
        """Abstract method implementation

        Load or build/cache the document, sample, and shuffle indices
        """
        assert isinstance(self.config, GPTDatasetConfig)

        (
            self.document_index,
            self.sample_index,
            self.shuffle_index,
        ) = self._build_document_sample_shuffle_indices()

    def __len__(self) -> int:
        """Abstract method implementation

        Returns:
            int: The length of the dataset
        """
        return self.sample_index.shape[0] - 1

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Abstract method implementation

        Args:
            idx (int): The index into the dataset

        Returns:
            Dict[str, torch.Tensor]: The text ids wrapped in a dictionary
        """
        text, _ = self._query_document_sample_shuffle_indices(idx)
        text = torch.from_numpy(text)

        tokens_ = text.long()
        labels = tokens_[1:].contiguous()
        tokens = tokens_[:-1].contiguous()

        attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
            tokens,
            self.config.eod_id,
            self.config.reset_position_ids,
            self.config.reset_attention_mask,
            self.config.eod_mask_loss,
        )

        return {
            "input_ids": tokens,
            "labels": labels,
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
            "position_ids": position_ids,
        }
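
    # Illustrative sketch of the shift performed above, assuming a toy queried sample (the query
    # nominally returns sequence_length + 1 token ids):
    #   text   = [10, 11, 12, 13, 14]
    #   tokens = [10, 11, 12, 13]   # text[:-1], the model inputs
    #   labels = [11, 12, 13, 14]   # text[1:], the next-token targets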

    @staticmethod
    def is_multimodal() -> bool:
        """Abstract method implementation

        Returns:
            bool: False
        """
        return False

    @staticmethod
    def is_split_by_sequence() -> bool:
        """Abstract method implementation

        Returns:
            bool: True
        """
        return True

    def _query_document_sample_shuffle_indices(
        self, idx: int
    ) -> Tuple[numpy.ndarray, numpy.ndarray]:
        """Get the text (token ids) and document ids for a given index

        Args:
            idx (int): The index into the dataset

        Returns:
            Tuple[numpy.ndarray, numpy.ndarray]: The text ids and document ids
        """
        # Do the shuffle mapping
        idx = self.shuffle_index[idx]

        # Get the beginning and end documents and offsets
        doc_index_beg, doc_index_beg_offset = self.sample_index[idx]
        doc_index_end, doc_index_end_offset = self.sample_index[idx + 1]

        document_ids = []
        sample_parts = []

        # Sample spans a single document
        if doc_index_beg == doc_index_end:
            # Add the document id
            document_ids.append(self.document_index[doc_index_beg])

            # Add the entire sample
            sample_parts.append(
                self.indexed_dataset.get(
                    self.document_index[doc_index_beg],
                    offset=doc_index_beg_offset,
                    length=doc_index_end_offset - doc_index_beg_offset + 1,
                )
            )
        # Sample spans multiple documents
        else:
            for i in range(doc_index_beg, doc_index_end + 1):
                # Add the document id
                document_ids.append(self.document_index[i])

                # Add the sample part
                offset = 0 if i > doc_index_beg else doc_index_beg_offset
                length = None if i < doc_index_end else doc_index_end_offset + 1
                sample_parts.append(
                    self.indexed_dataset.get(self.document_index[i], offset=offset, length=length)
                )

        if getattr(self.config, "add_bos", False):
            # Requires self.bos_id and self.eod_id to be set, e.g. from the tokenizer
            # (see the commented-out lines in __init__)
            sample = sample_parts[0]
            add_token = self.bos_id if sample[0] != self.bos_id else self.eod_id
            sample_parts.insert(0, numpy.array([add_token], dtype=sample.dtype))

        return (
            numpy.array(numpy.concatenate(sample_parts), dtype=numpy.int64),
            numpy.array(document_ids, dtype=numpy.int64),
        )
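
    # Illustrative sketch (hypothetical indices): if the shuffled idx maps to
    #   sample_index[idx]     == [3, 100]   # begin: 4th document-index entry, token offset 100
    #   sample_index[idx + 1] == [4, 40]    # end:   5th document-index entry, token offset 40
    # then the sample is the tail of document_index[3] from token 100 onward, concatenated with
    # the first 41 tokens of document_index[4], for sequence_length + 1 tokens in total.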

    def _build_document_sample_shuffle_indices(
        self,
    ) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
        """Build the document index, the sample index, and the shuffle index

        The document index:
            -- 1-D
            -- An ordered array of document ids

        The sample index:
            -- 2-D
            -- The document indices and offsets which mark the start of every sample

        The shuffle index:
            -- 1-D
            -- A random permutation of the index range of the sample index

        Returns:
            Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample
            index, and the shuffle index

        TODO: Explain the 80% threshold
        """
        path_to_cache = self.config.path_to_cache
        if path_to_cache is None:
            path_to_cache = os.path.join(
                self.indexed_dataset.path_prefix, "cache", f"{type(self).__name__}_indices"
            )

        get_path_to = lambda suffix: os.path.join(
            path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
        )
        path_to_description = get_path_to("description.txt")
        path_to_document_index = get_path_to("document_index.npy")
        path_to_sample_index = get_path_to("sample_index.npy")
        path_to_shuffle_index = get_path_to("shuffle_index.npy")
        cache_hit = all(
            map(
                os.path.isfile,
                [
                    path_to_description,
                    path_to_document_index,
                    path_to_sample_index,
                    path_to_shuffle_index,
                ],
            )
        )

        num_tokens_per_epoch = self._get_num_tokens_per_epoch()

        num_epochs = self._get_num_epochs(num_tokens_per_epoch)

        if not cache_hit and torch.distributed.get_rank() == 0:
            log_single_rank(
                logger,
                logging.INFO,
                f"Build and save the {type(self).__name__} {self.index_split.name} indices",
            )

            sequence_length = self.config.sequence_length

            if num_epochs == 1:
                separate_final_epoch = False
            else:
                # Get the number of samples for the last epoch
                num_samples_sans_final_epoch = (
                    (num_epochs - 1) * num_tokens_per_epoch - 1
                ) // sequence_length
                num_samples_from_final_epoch = self.num_samples - num_samples_sans_final_epoch
                num_samples_per_epoch = (num_tokens_per_epoch - 1) // sequence_length

                # num_samples_from_final_epoch should be non-negative
                assert num_samples_from_final_epoch >= 0

                # num_samples_from_final_epoch should not exceed max value
                assert num_samples_from_final_epoch <= num_samples_per_epoch + 1

                # Separate the final epoch if it falls below the threshold
                threshold = 0.80
                separate_final_epoch = num_samples_from_final_epoch < int(
                    threshold * num_samples_per_epoch
                )

                log_single_rank(
                    logger,
                    logging.DEBUG,
                    f"> num_samples_from_final_epoch: {num_samples_from_final_epoch}",
                )
                log_single_rank(logger, logging.DEBUG, f"> threshold: {threshold}")
                log_single_rank(
                    logger, logging.DEBUG, f"> num_samples_per_epoch: {num_samples_per_epoch}"
                )
                log_single_rank(
                    logger, logging.DEBUG, f"> separate_final_epoch: {separate_final_epoch}"
                )
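
            # Worked example (illustrative numbers): with num_tokens_per_epoch = 1_000_000,
            # sequence_length = 2048, num_epochs = 3, and self.num_samples = 1000:
            #   num_samples_sans_final_epoch = (2 * 1_000_000 - 1) // 2048 = 976
            #   num_samples_from_final_epoch = 1000 - 976 = 24
            #   num_samples_per_epoch        = (1_000_000 - 1) // 2048 = 488
            # Since 24 < int(0.80 * 488) = 390, separate_final_epoch is True and the final epoch
            # is shuffled separately from the preceding epochs.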

            numpy_random_state = numpy.random.RandomState(self.config.random_seed)

            os.makedirs(path_to_cache, exist_ok=True)

            # Write the description
            with open(path_to_description, "wt") as writer:
                writer.write(self.unique_description)

            # Build the document index
            log_single_rank(
                logger,
                logging.INFO,
                f"\tBuild and save the document index to {os.path.basename(path_to_document_index)}",
            )
            t_beg = time.time()
            document_index = _build_document_index(
                self.indexed_indices,
                num_epochs,
                numpy_random_state,
                separate_final_epoch,
                self.config.enable_shuffle,
            )
            numpy.save(path_to_document_index, document_index, allow_pickle=True)
            t_end = time.time()
            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

            # Build the sample index
            log_single_rank(
                logger,
                logging.INFO,
                f"\tBuild and save the sample index to {os.path.basename(path_to_sample_index)}",
            )
            t_beg = time.time()
            from core.datasets import helpers

            assert document_index.dtype == numpy.int32
            assert self.indexed_dataset.sequence_lengths.dtype == numpy.int32
            sample_index = helpers.build_sample_idx(
                self.indexed_dataset.sequence_lengths,
                document_index,
                sequence_length,
                num_epochs,
                num_tokens_per_epoch,
            )
            numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
            t_end = time.time()
            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

            # Build the shuffle index
            log_single_rank(
                logger,
                logging.INFO,
                f"\tBuild and save the shuffle index to {os.path.basename(path_to_shuffle_index)}",
            )
            t_beg = time.time()
            if separate_final_epoch:
                shuffle_index = _build_shuffle_index(
                    num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state, True
                )
            else:
                shuffle_index = _build_shuffle_index(
                    sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state, True
                )
            numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True)
            t_end = time.time()
            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

            log_single_rank(
                logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
            )
            log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")

            return document_index, sample_index, shuffle_index

        log_single_rank(
            logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"
        )

        log_single_rank(
            logger,
            logging.INFO,
            f"\tLoad the document index from {os.path.basename(path_to_document_index)}",
        )
        t_beg = time.time()
        document_index = numpy.load(path_to_document_index, allow_pickle=True, mmap_mode='r')
        t_end = time.time()
        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

        log_single_rank(
            logger,
            logging.INFO,
            f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}",
        )
        t_beg = time.time()
        sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r')
        t_end = time.time()
        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

        log_single_rank(
            logger,
            logging.INFO,
            f"\tLoad the shuffle index from {os.path.basename(path_to_shuffle_index)}",
        )
        t_beg = time.time()
        shuffle_index = numpy.load(path_to_shuffle_index, allow_pickle=True, mmap_mode='r')
        t_end = time.time()
        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

        log_single_rank(
            logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
        )
        log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")

        return document_index, sample_index, shuffle_index

    def _get_num_tokens_per_epoch(self) -> int:
        """Calculate the number of tokens in a single epoch

        Returns:
            int: The number of tokens in a single epoch
        """
        return int(numpy.sum(self.indexed_dataset.sequence_lengths[self.indexed_indices]))

    def _get_num_epochs(self, num_tokens_per_epoch: int) -> int:
        """Calculate the number of epochs

        Args:
            num_tokens_per_epoch (int): The number of tokens in a single epoch

        Returns:
            int: The number of epochs
        """
        num_epochs = 0
        num_tokens = 0
        num_tokens_requested = (self.num_samples * self.config.sequence_length) + 1
        while True:
            num_epochs += 1
            num_tokens += num_tokens_per_epoch
            if num_tokens >= num_tokens_requested:
                return num_epochs
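
# Worked example for _get_num_epochs (illustrative numbers): with num_samples = 1000,
# sequence_length = 2048, and num_tokens_per_epoch = 1_000_000, the requested token count is
# 1000 * 2048 + 1 = 2_048_001. Two epochs provide 2_000_000 tokens (< 2_048_001) and three
# provide 3_000_000 (>= 2_048_001), so the method returns 3.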

def _build_document_index(
    documents: numpy.ndarray,
    num_epochs: int,
    numpy_random_state: numpy.random.RandomState,
    separate_final_epoch: bool,
    enable_shuffle: bool = False,
) -> numpy.ndarray:
    """Build an array with length = num epochs * num documents

    Args:
        documents (numpy.ndarray): The subset of exposed document indices

        num_epochs (int): The number of epochs

        numpy_random_state (numpy.random.RandomState): The NumPy random state

        separate_final_epoch (bool): Whether to exclude the last epoch from the global shuffle

        enable_shuffle (bool): Whether to shuffle the document index. Defaults to False for
        reproducibility

    Returns:
        numpy.ndarray: The document index

    TODO: Explain separate_final_epoch
    """
    if not separate_final_epoch or num_epochs == 1:
        document_index = numpy.mgrid[0:num_epochs, 0 : len(documents)][1]
        document_index[:] = documents
        document_index = document_index.reshape(-1)
        document_index = document_index.astype(numpy.int32)
        if enable_shuffle:
            print("INFO: document_index shuffle is enabled...")
            numpy_random_state.shuffle(document_index)
        else:
            print("INFO: document_index shuffle is disabled...")
        return document_index

    doc_idx_first = _build_document_index(
        documents, num_epochs - 1, numpy_random_state, False, enable_shuffle
    )
    doc_idx_last = _build_document_index(documents, 1, numpy_random_state, False, enable_shuffle)
    return numpy.concatenate((doc_idx_first, doc_idx_last))
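
# Illustrative sketch: for documents = [7, 8, 9], num_epochs = 2, separate_final_epoch = False,
# and enable_shuffle = False, the result is the documents tiled once per epoch:
#   [7, 8, 9, 7, 8, 9]  (dtype numpy.int32)
# With enable_shuffle = True, the same multiset is returned in a seed-dependent order; with
# separate_final_epoch = True, the first num_epochs - 1 epochs and the final epoch are built
# (and, if enabled, shuffled) independently before being concatenated.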

def _build_shuffle_index(
    num_samples: int,
    total_size: int,
    numpy_random_state: numpy.random.RandomState,
    enable_shuffle: bool = False,
) -> numpy.ndarray:
    """Build the range [0, size) and shuffle

    Args:
        num_samples (int): The size of the first shuffle range [0, num_samples)

        total_size (int): The size of the entire index. If larger than 'num_samples', it defines
        the second shuffle range [num_samples, total_size)

        numpy_random_state (numpy.random.RandomState): The NumPy random state

        enable_shuffle (bool): Whether to shuffle the index ranges. Defaults to False for
        reproducibility

    Returns:
        numpy.ndarray: The shuffle index

    TODO: Explain [0, num_samples) [num_samples, total_size) split
    """
    dtype_ = numpy.uint32
    if total_size >= (numpy.iinfo(numpy.uint32).max - 1):
        dtype_ = numpy.int64

    shuffle_idx_first = numpy.arange(start=0, stop=num_samples, step=1, dtype=dtype_)
    if enable_shuffle:
        print("INFO: shuffle_index shuffle is enabled...")
        numpy_random_state.shuffle(shuffle_idx_first)
    else:
        print("INFO: shuffle_index shuffle is disabled...")
    if num_samples == total_size:
        return shuffle_idx_first

    shuffle_idx_last = numpy.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_)
    if enable_shuffle:
        print("INFO: shuffle_index shuffle is enabled...")
        numpy_random_state.shuffle(shuffle_idx_last)
    else:
        print("INFO: shuffle_index shuffle is disabled...")

    return numpy.concatenate((shuffle_idx_first, shuffle_idx_last))
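
# Illustrative sketch: _build_shuffle_index(4, 6, rng, enable_shuffle=True) permutes [0, 1, 2, 3]
# and [4, 5] independently and concatenates them, e.g. [2, 0, 3, 1, 5, 4] (the order depends on
# the random state). Samples from the separated final epoch (indices 4 and 5 here) therefore
# never mix with earlier-epoch samples.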

def _get_ltor_masks_and_position_ids(
    data: torch.Tensor,
    eod_token: int,
    reset_position_ids: bool,
    reset_attention_mask: bool,
    eod_mask_loss: bool,
):
    """Build masks and position ids for a left-to-right model.

    Args:
        data (torch.Tensor): The data tensor that holds the tokens from the dataset

        eod_token (int): ID of the token that is considered the EOD

        reset_position_ids (bool): Switch to reset the document position IDs

        reset_attention_mask (bool): Switch to reset the attention mask

        eod_mask_loss (bool): Switch to enable the EOD mask loss

    Returns:
        torch.Tensor: The attention mask

        torch.Tensor: The mask used for the loss value during training

        torch.Tensor: The position IDs of the tokens
    """
    # Extract the sequence length.
    seq_length = data.numel()

    attention_mask = torch.tril(torch.ones((seq_length, seq_length), device=data.device)).unsqueeze(
        0
    )

    # Loss mask.
    loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    # We need to clone as the ids will be modified at the document boundaries.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Find indices where the EOD token is.
        eod_index = position_ids[data == eod_token]
        # Detach indices from positions if going to modify positions.
        if reset_position_ids:
            eod_index = eod_index.clone()

        # Loop through EOD indices:
        prev_index = 0
        for j in range(eod_index.numel()):
            i = eod_index[j]
            # Mask the attention across the document boundary.
            if reset_attention_mask:
                attention_mask[0, (i + 1) :, : (i + 1)] = 0
            # Reset positions.
            if reset_position_ids:
                position_ids[(i + 1) :] -= i + 1 - prev_index
                prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = attention_mask < 0.5
    attention_mask = attention_mask.float()

    return attention_mask, loss_mask, position_ids
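
# Illustrative sketch: for data = [5, 6, EOD, 7, 8] with reset_position_ids=True,
# reset_attention_mask=True, and eod_mask_loss=True:
#   loss_mask    = [1, 1, 0, 1, 1]   # the EOD position is excluded from the loss
#   position_ids = [0, 1, 2, 0, 1]   # positions restart after the EOD token
# and positions 3-4 are blocked from attending to positions 0-2, on top of the causal
# (lower-triangular) structure. Note that the final `< 0.5` conversion returns a mask with 1.0
# at blocked positions and 0.0 at allowed ones.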