# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import logging
import os
import time
from dataclasses import dataclass
from typing import Dict, Tuple, Union
import numpy
import torch
# from megatron import get_tokenizer
from core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from core.datasets.indexed_dataset import MMapIndexedDataset
from core.datasets.megatron_dataset import MegatronDataset
from core.datasets.utils import Split, log_single_rank
logger = logging.getLogger(__name__)
@dataclass
class GPTDatasetConfig(BlendedMegatronDatasetConfig):
"""Configuration object for Megatron Core GPT datasets
Attributes:
return_document_ids (bool): Whether to return the document ids when querying the dataset.
reset_position_ids (bool): Option to reset the position IDs in the dataset at an interval
reset_attention_mask (bool): Option to reset the attention mask from the dataset
eod_mask_loss (bool): Option to enable the EOD mask loss
        eod_id (int): The token id of the end-of-document (EOD) token
        add_bos (bool): Option to prepend a BOS token to the start of each sample
        enable_shuffle (bool): Option to shuffle the document and shuffle indices. Default is
            False to ensure reproducibility
    """
return_document_ids: bool = False
reset_position_ids: bool = False
reset_attention_mask: bool = False
eod_mask_loss: bool = False
eod_id: int = 0
add_bos: bool = False
enable_shuffle: bool = False
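    # A hedged usage sketch (not executed): sequence_length, random_seed and path_to_cache are
    # parent-class fields from BlendedMegatronDatasetConfig that this dataset relies on below;
    # the remaining parent-class fields are omitted here and the values are illustrative only.
    #
    #   config = GPTDatasetConfig(
    #       ...,                   # remaining parent-class fields (blend/split/etc.)
    #       eod_id=2,              # hypothetical EOD token id
    #       add_bos=True,
    #       enable_shuffle=False,  # keep document/sample order deterministic
    #   )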
class GPTDataset(MegatronDataset):
"""The base GPT dataset
Args:
indexed_dataset (MMapIndexedDataset): The MMapIndexedDataset around which to build the
MegatronDataset
indexed_indices (numpy.ndarray): The set of the documents indices to expose
num_samples (int): The number of samples to draw from the indexed dataset
index_split (Split): The indexed_indices Split
config (GPTDatasetConfig): The GPT-specific container for all config sourced parameters
"""
def __init__(
self,
indexed_dataset: MMapIndexedDataset,
indexed_indices: numpy.ndarray,
num_samples: int,
index_split: Split,
config: GPTDatasetConfig,
) -> None:
super().__init__(indexed_dataset, indexed_indices, num_samples, index_split, config)
# tokenizer = get_tokenizer()
# self.bos_id = tokenizer.bos
# self.eod_id = tokenizer.eod
def _finalize(self) -> None:
"""Abstract method implementation
Load or build/cache the document, sample, and shuffle indices
"""
assert isinstance(self.config, GPTDatasetConfig)
(
self.document_index,
self.sample_index,
self.shuffle_index,
) = self._build_document_sample_shuffle_indices()
def __len__(self) -> int:
"""Abstract method implementation
Returns:
int: The length of the dataset
"""
return self.sample_index.shape[0] - 1
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""Abstract method implementation
Args:
idx (int): The index into the dataset
Returns:
            Dict[str, torch.Tensor]: The input ids, labels, attention mask, loss mask, and
                position ids wrapped in a dictionary
"""
text, _ = self._query_document_sample_shuffle_indices(idx)
text = torch.from_numpy(text)
tokens_ = text.long()
labels = tokens_[1:].contiguous()
tokens = tokens_[:-1].contiguous()
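        # For example (a sketch): if the retrieved chunk is [t0, t1, t2, t3, t4]
        # (sequence_length + 1 tokens), then tokens = [t0, t1, t2, t3] and
        # labels = [t1, t2, t3, t4], i.e. the labels are the tokens shifted left by one.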
attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
tokens,
self.config.eod_id,
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
)
return {
"input_ids": tokens,
"labels": labels,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
}
@staticmethod
def is_multimodal() -> bool:
"""Abstract method implementation
Returns:
bool: False
"""
return False
@staticmethod
def is_split_by_sequence() -> bool:
"""Abstract method implementation
Returns:
bool: True
"""
return True
def _query_document_sample_shuffle_indices(
self, idx: int
) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""Get the text (token ids) and document ids for a given index
Args:
idx (int): The index into the dataset
Returns:
Tuple[numpy.ndarray, numpy.ndarray]: The text ids and document ids
"""
# Do the shuffle mapping
idx = self.shuffle_index[idx]
# Get the beginning and end documents and offsets
doc_index_beg, doc_index_beg_offset = self.sample_index[idx]
doc_index_end, doc_index_end_offset = self.sample_index[idx + 1]
document_ids = []
sample_parts = []
# Sample spans a single document
if doc_index_beg == doc_index_end:
# Add the document id
document_ids.append(self.document_index[doc_index_beg])
# Add the entire sample
sample_parts.append(
self.indexed_dataset.get(
self.document_index[doc_index_beg],
offset=doc_index_beg_offset,
length=doc_index_end_offset - doc_index_beg_offset + 1,
)
)
# Sample spans multiple documents
else:
for i in range(doc_index_beg, doc_index_end + 1):
# Add the document id
document_ids.append(self.document_index[i])
# Add the sample part
offset = 0 if i > doc_index_beg else doc_index_beg_offset
length = None if i < doc_index_end else doc_index_end_offset + 1
sample_parts.append(
self.indexed_dataset.get(self.document_index[i], offset=offset, length=length)
)
        # Optionally prepend a BOS token (or an EOD token if the sample already starts with
        # BOS). NOTE: self.bos_id / self.eod_id must be set (e.g. via the tokenizer hookup
        # commented out in __init__) before add_bos is enabled.
        if getattr(self.config, "add_bos", False):
            sample = sample_parts[0]
            add_token = self.bos_id if sample[0] != self.bos_id else self.eod_id
            sample_parts.insert(0, numpy.array([add_token], dtype=sample.dtype))
return (
numpy.array(numpy.concatenate(sample_parts), dtype=numpy.int64),
numpy.array(document_ids, dtype=numpy.int64),
)
def _build_document_sample_shuffle_indices(
self,
) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
"""Build the document index, the sample index, and the shuffle index
The document index:
-- 1-D
-- An ordered array of document ids
The sample index:
-- 2-D
-- The document indices and offsets which mark the start of every sample
The shuffle index:
-- 1-D
-- A random permutation of index range of the sample index
Returns:
            Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample
                index, and the shuffle index
TODO: Explain the 80% threshold
"""
path_to_cache = self.config.path_to_cache
if path_to_cache is None:
path_to_cache = os.path.join(
self.indexed_dataset.path_prefix, "cache", f"{type(self).__name__}_indices"
)
get_path_to = lambda suffix: os.path.join(
path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
)
path_to_description = get_path_to("description.txt")
path_to_document_index = get_path_to("document_index.npy")
path_to_sample_index = get_path_to("sample_index.npy")
path_to_shuffle_index = get_path_to("shuffle_index.npy")
cache_hit = all(
map(
os.path.isfile,
[
path_to_description,
path_to_document_index,
path_to_sample_index,
path_to_shuffle_index,
],
)
)
num_tokens_per_epoch = self._get_num_tokens_per_epoch()
num_epochs = self._get_num_epochs(num_tokens_per_epoch)
if not cache_hit and torch.distributed.get_rank() == 0:
log_single_rank(
logger,
logging.INFO,
f"Build and save the {type(self).__name__} {self.index_split.name} indices",
)
sequence_length = self.config.sequence_length
if num_epochs == 1:
separate_final_epoch = False
else:
# Get the number of samples for the last epoch
num_samples_sans_final_epoch = (
(num_epochs - 1) * num_tokens_per_epoch - 1
) // sequence_length
num_samples_from_final_epoch = self.num_samples - num_samples_sans_final_epoch
num_samples_per_epoch = (num_tokens_per_epoch - 1) // sequence_length
# num_samples_from_final_epoch should be non-negative
assert num_samples_from_final_epoch >= 0
# num_samples_from_final_epoch should not exceed max value
assert num_samples_from_final_epoch <= num_samples_per_epoch + 1
# Separate the final epoch if it falls below the threshold
threshold = 0.80
separate_final_epoch = num_samples_from_final_epoch < int(
threshold * num_samples_per_epoch
)
log_single_rank(
logger,
logging.DEBUG,
f"> num_samples_from_final_epoch: {num_samples_from_final_epoch}",
)
log_single_rank(logger, logging.DEBUG, f"> threshold: {threshold}")
log_single_rank(
logger, logging.DEBUG, f"> num_samples_per_epoch: {num_samples_per_epoch}"
)
log_single_rank(
logger, logging.DEBUG, f"> separate_final_epoch: {separate_final_epoch}"
)
numpy_random_state = numpy.random.RandomState(self.config.random_seed)
os.makedirs(path_to_cache, exist_ok=True)
# Write the description
with open(path_to_description, "wt") as writer:
writer.write(self.unique_description)
# Build the document index
log_single_rank(
logger,
logging.INFO,
f"\tBuild and save the document index to {os.path.basename(path_to_document_index)}",
)
t_beg = time.time()
document_index = _build_document_index(
self.indexed_indices, num_epochs, numpy_random_state, separate_final_epoch, self.config.enable_shuffle
)
numpy.save(path_to_document_index, document_index, allow_pickle=True)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
# Build the sample index
log_single_rank(
logger,
logging.INFO,
f"\tBuild and save the sample index to {os.path.basename(path_to_sample_index)}",
)
t_beg = time.time()
from core.datasets import helpers
assert document_index.dtype == numpy.int32
assert self.indexed_dataset.sequence_lengths.dtype == numpy.int32
sample_index = helpers.build_sample_idx(
self.indexed_dataset.sequence_lengths,
document_index,
sequence_length,
num_epochs,
num_tokens_per_epoch,
)
numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
# Build the shuffle index
log_single_rank(
logger,
logging.INFO,
f"\tBuild and save the shuffle index to {os.path.basename(path_to_shuffle_index)}",
)
t_beg = time.time()
if separate_final_epoch:
shuffle_index = _build_shuffle_index(
num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state, True
)
else:
shuffle_index = _build_shuffle_index(
sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state, True
)
numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
)
log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")
return document_index, sample_index, shuffle_index
log_single_rank(
logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"
)
log_single_rank(
logger,
logging.INFO,
f"\tLoad the document index from {os.path.basename(path_to_document_index)}",
)
t_beg = time.time()
document_index = numpy.load(path_to_document_index, allow_pickle=True, mmap_mode='r')
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger,
logging.INFO,
f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}",
)
t_beg = time.time()
sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r')
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger,
logging.INFO,
f"\tLoad the shuffle index from {os.path.basename(path_to_shuffle_index)}",
)
t_beg = time.time()
shuffle_index = numpy.load(path_to_shuffle_index, allow_pickle=True, mmap_mode='r')
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
)
log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")
return document_index, sample_index, shuffle_index
def _get_num_tokens_per_epoch(self) -> int:
"""Calculate the number of tokens in a single epoch
Returns:
int: The number of tokens in a single epoch
"""
return int(numpy.sum(self.indexed_dataset.sequence_lengths[self.indexed_indices]))
def _get_num_epochs(self, num_tokens_per_epoch: int) -> int:
"""Calculate the number of epochs
Args:
num_tokens_per_epoch (int): The number of tokens in a single epoch
Returns:
int: The number of epochs
"""
num_epochs = 0
num_tokens = 0
num_tokens_requested = (self.num_samples * self.config.sequence_length) + 1
while True:
num_epochs += 1
num_tokens += num_tokens_per_epoch
if num_tokens >= num_tokens_requested:
return num_epochs
def _build_document_index(
documents: numpy.ndarray,
num_epochs: int,
numpy_random_state: numpy.random.RandomState,
separate_final_epoch: bool,
enable_shuffle: bool = False,
) -> numpy.ndarray:
"""Build an array with length = num epochs * num documents
Args:
documents (numpy.ndarray): the subset of exposed document indices
num_epochs (int): The number of epochs
numpy_random_state (numpy.random.RandomState): The NumPy random state
separate_final_epoch (bool): Whether to exclude the last epoch from the global shuffle
        enable_shuffle (bool): Whether to enable the shuffle. Default is False to ensure reproducibility
Returns:
numpy.ndarray: The document index
TODO: Explain separate_final_epoch
"""
if not separate_final_epoch or num_epochs == 1:
document_index = numpy.mgrid[0:num_epochs, 0 : len(documents)][1]
document_index[:] = documents
document_index = document_index.reshape(-1)
document_index = document_index.astype(numpy.int32)
if enable_shuffle:
print("INFO: document_index shuffle is enabled...")
numpy_random_state.shuffle(document_index)
else:
print("INFO: document_index shuffle is disabled...")
return document_index
doc_idx_first = _build_document_index(documents, num_epochs - 1, numpy_random_state, False, enable_shuffle)
doc_idx_last = _build_document_index(documents, 1, numpy_random_state, False, enable_shuffle)
return numpy.concatenate((doc_idx_first, doc_idx_last))
def _build_shuffle_index(
num_samples: int, total_size: int, numpy_random_state: numpy.random.RandomState,
enable_shuffle: bool = False,
) -> numpy.ndarray:
"""Build the range [0, size) and shuffle
Args:
num_samples (int): The size of the first shuffle range [0, num_samples)
total_size (int): The size of the entire index. If larger than 'num_samples', it defines
the second shuffle range [num_samples, total_size)
numpy_random_state (numpy.random.RandomState): The NumPy random state
Returns:
numpy.ndarray: The shuffle index
TODO: Explain [0, num_samples) [num_samples, total_size) split
"""
dtype_ = numpy.uint32
if total_size >= (numpy.iinfo(numpy.uint32).max - 1):
dtype_ = numpy.int64
shuffle_idx_first = numpy.arange(start=0, stop=num_samples, step=1, dtype=dtype_)
if enable_shuffle:
print("INFO: shuffle_index shuffle is enabled...")
numpy_random_state.shuffle(shuffle_idx_first)
else:
print("INFO: shuffle_index shuffle is disabled...")
if num_samples == total_size:
return shuffle_idx_first
shuffle_idx_last = numpy.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_)
if enable_shuffle:
print("INFO: shuffle_index shuffle is enabled...")
numpy_random_state.shuffle(shuffle_idx_last)
else:
print("INFO: shuffle_index shuffle is disabled...")
return numpy.concatenate((shuffle_idx_first, shuffle_idx_last))
def _get_ltor_masks_and_position_ids(
data: torch.Tensor,
eod_token: int,
reset_position_ids: bool,
reset_attention_mask: bool,
eod_mask_loss: bool,
):
"""Build masks and position id for left to right model.
Args:
data (torch.Tensor): The data tenor that holds the tokens from the dataset
eod_token (int): ID of the token to that is considered the EOD
reset_position_ids (bool): Switch to reset the document position ID's
reset_attention_mask (bool): Switch to reset the attention mask
eod_mask_loss (bool): Switch to enable the EOD mask loss
Returns:
        torch.Tensor: The attention mask to be used for attention
        torch.Tensor: The mask applied to the loss during training
        torch.Tensor: The position IDs of the tokens
"""
    # Extract the sequence length (the data is a single unbatched sequence).
seq_length = data.numel()
attention_mask = torch.tril(torch.ones((seq_length, seq_length), device=data.device)).unsqueeze(
0
)
# Loss mask.
loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    # We need to clone as the ids will be modified if position ids are reset.
    if reset_position_ids:
        position_ids = position_ids.clone()
    if reset_position_ids or reset_attention_mask:
        # Find the indices where the EOD token occurs.
        eod_index = position_ids[data == eod_token]
        # Detach the indices from the positions if we are going to modify the positions.
        if reset_position_ids:
            eod_index = eod_index.clone()
        # Loop through the EOD indices:
prev_index = 0
for j in range(eod_index.numel()):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[(i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
    # Convert the attention mask to binary; after this, 1.0 marks positions that should be
    # masked out (i.e. not attended to).
attention_mask = attention_mask < 0.5
attention_mask = attention_mask.float()
return attention_mask, loss_mask, position_ids
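
# A minimal, self-contained smoke test for the standalone helpers above. This is a sketch
# only: it assumes the module can be imported without initializing torch.distributed, and
# it uses token id 0 as a stand-in for the EOD token.
if __name__ == "__main__":
    demo_tokens = torch.tensor([5, 6, 0, 7, 8])
    attn_mask, demo_loss_mask, demo_position_ids = _get_ltor_masks_and_position_ids(
        demo_tokens,
        eod_token=0,
        reset_position_ids=True,
        reset_attention_mask=True,
        eod_mask_loss=True,
    )
    print("loss_mask:", demo_loss_mask.tolist())        # [1.0, 1.0, 0.0, 1.0, 1.0]
    print("position_ids:", demo_position_ids.tolist())  # [0, 1, 2, 0, 1]
    demo_rng = numpy.random.RandomState(1234)
    # With enable_shuffle=False the shuffle index is simply the identity permutation.
    print("shuffle_index:", _build_shuffle_index(3, 5, demo_rng, enable_shuffle=False))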