""" |
|
|
Data loaders for LLM training. |
|
|
""" |
|
|
|
|
|
import random |
|
|
import numpy as np |
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
from typing import Dict, List, Optional, Tuple, Union, Callable, Iterator, Any |
|
|
from data.dataset import Dataset |
|
|
|
|
|
|
|
|
def pad_batch(
    examples: List[Dict[str, np.ndarray]],
    pad_token_id: int = 0
) -> Dict[str, np.ndarray]:
    """
    Pad a batch of examples to a common length.

    Args:
        examples: List of example dicts, each with "input_ids",
            "attention_mask", and "position_ids" arrays of equal length
        pad_token_id: Token ID used to fill padded positions

    Returns:
        Dict of [batch_size, max_length] arrays, right-padded with
        pad_token_id (input_ids) or zeros (masks and positions)
    """
    max_length = max(example["input_ids"].shape[0] for example in examples)

    batch = {
        "input_ids": np.full((len(examples), max_length), pad_token_id, dtype=np.int32),
        "attention_mask": np.zeros((len(examples), max_length), dtype=np.int32),
        "position_ids": np.zeros((len(examples), max_length), dtype=np.int32),
    }

    # Copy each example into the left-aligned slice of its row; positions
    # past the example's length keep their padding values.
    for i, example in enumerate(examples):
        length = example["input_ids"].shape[0]
        batch["input_ids"][i, :length] = example["input_ids"]
        batch["attention_mask"][i, :length] = example["attention_mask"]
        batch["position_ids"][i, :length] = example["position_ids"]

    return batch
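

# A minimal sketch (not part of the loader API) of what pad_batch returns;
# the toy examples below are illustrative assumptions.
def _pad_batch_example() -> Dict[str, np.ndarray]:
    examples = [
        {
            "input_ids": np.array([5, 6, 7], dtype=np.int32),
            "attention_mask": np.ones(3, dtype=np.int32),
            "position_ids": np.arange(3, dtype=np.int32),
        },
        {
            "input_ids": np.array([8, 9], dtype=np.int32),
            "attention_mask": np.ones(2, dtype=np.int32),
            "position_ids": np.arange(2, dtype=np.int32),
        },
    ]
    batch = pad_batch(examples, pad_token_id=0)
    # Rows are padded to the longest example: the second row becomes
    # [8, 9, 0] with attention_mask [1, 1, 0].
    assert batch["input_ids"].shape == (2, 3)
    return batch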


def create_masks(
    input_ids: np.ndarray,
    pad_token_id: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create the combined attention mask and the padding mask.

    Args:
        input_ids: Input token IDs [batch_size, seq_len]
        pad_token_id: Padding token ID

    Returns:
        Tuple of (attention_mask, padding_mask): attention_mask is an
        additive float mask of shape [batch_size, 1, seq_len, seq_len]
        (0 for visible positions, -1e9 for masked ones); padding_mask is
        a binary int mask of shape [batch_size, seq_len]
    """
    # 1 for real tokens, 0 for padding. Note this treats every occurrence
    # of pad_token_id as padding.
    padding_mask = (input_ids != pad_token_id).astype(np.int32)

    # Lower-triangular causal mask: position i may attend to positions <= i.
    seq_len = input_ids.shape[1]
    causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.int32))

    # Combine: [batch, 1, 1, seq] * [1, 1, seq, seq] -> [batch, 1, seq, seq].
    attention_mask = padding_mask[:, None, None, :] * causal_mask[None, None, :, :]

    # Convert the binary mask to an additive mask for use inside softmax.
    attention_mask = (1.0 - attention_mask.astype(np.float32)) * -1e9

    return attention_mask, padding_mask
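

# A minimal sketch of the shapes and semantics of create_masks; the toy
# input below is an illustrative assumption (0 is the pad token).
def _create_masks_example() -> None:
    input_ids = np.array([[5, 6, 0]], dtype=np.int32)
    attention_mask, padding_mask = create_masks(input_ids, pad_token_id=0)

    # Binary padding mask: 1 for real tokens, 0 for padding.
    assert padding_mask.tolist() == [[1, 1, 0]]

    # Additive mask is [batch, 1, seq, seq]: 0 where attention is
    # allowed, -1e9 where it is blocked (future or padded positions).
    assert attention_mask.shape == (1, 1, 3, 3)
    assert attention_mask[0, 0, 1, 0] == 0.0   # token 1 attends to token 0
    assert attention_mask[0, 0, 0, 1] == -1e9  # but never to the future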


class DataLoader:
    """
    Data loader for LLM training.

    Attributes:
        dataset: Dataset to draw examples from
        batch_size: Number of examples per batch
        shuffle: Whether to shuffle the data each epoch
        drop_last: Whether to drop the last incomplete batch
        pad_token_id: Padding token ID
        collate_fn: Function that collates a list of examples into a batch
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        drop_last: bool = False,
        pad_token_id: int = 0,
        collate_fn: Optional[Callable] = None
    ):
        """
        Initialize the data loader.

        Args:
            dataset: Dataset to draw examples from
            batch_size: Number of examples per batch
            shuffle: Whether to shuffle the data each epoch
            drop_last: Whether to drop the last incomplete batch
            pad_token_id: Padding token ID
            collate_fn: Function that collates a list of examples into a
                batch; defaults to pad_batch
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pad_token_id = pad_token_id

        if collate_fn is None:
            self.collate_fn = lambda examples: pad_batch(examples, pad_token_id)
        else:
            self.collate_fn = collate_fn
    def __iter__(self) -> Iterator[Dict[str, jnp.ndarray]]:
        """
        Iterate over batches for one epoch.

        Yields:
            Batches as dicts of jnp arrays, including the additive
            "attention_mask" and binary "padding_mask" from create_masks
        """
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(indices)

        def make_batch(batch_indices: List[int]) -> Dict[str, jnp.ndarray]:
            examples = [self.dataset[i] for i in batch_indices]
            batch = self.collate_fn(examples)

            # Replace the collated 2D mask with the additive
            # [batch, 1, seq, seq] attention mask; keep the binary
            # per-token mask separately as "padding_mask".
            attention_mask, padding_mask = create_masks(
                batch["input_ids"],
                self.pad_token_id
            )
            batch["attention_mask"] = attention_mask
            batch["padding_mask"] = padding_mask

            return {k: jnp.array(v) for k, v in batch.items()}

        batch_indices = []
        for idx in indices:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield make_batch(batch_indices)
                batch_indices = []

        # Emit the final partial batch unless drop_last is set.
        if batch_indices and not self.drop_last:
            yield make_batch(batch_indices)
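

# A minimal usage sketch of DataLoader. It only needs len() and integer
# indexing from its dataset, so a plain list of example dicts stands in
# for a real Dataset here; the toy data is an illustrative assumption.
def _data_loader_example() -> None:
    toy_dataset = [
        {
            "input_ids": np.arange(1, n + 1, dtype=np.int32),
            "attention_mask": np.ones(n, dtype=np.int32),
            "position_ids": np.arange(n, dtype=np.int32),
        }
        for n in (3, 5, 4, 2)
    ]
    loader = DataLoader(toy_dataset, batch_size=2, shuffle=False)
    for batch in loader:
        # Values are jnp arrays; "attention_mask" is the additive
        # [batch, 1, seq, seq] mask and "padding_mask" the binary one.
        assert batch["input_ids"].shape[0] == 2
        assert batch["attention_mask"].ndim == 4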


class TPUDataLoader:
    """
    Data loader optimized for TPU v4-32, with per-device batching and
    optional background prefetching.

    Attributes:
        dataset: Dataset to draw examples from
        batch_size: Batch size per device
        shuffle: Whether to shuffle the data
        drop_last: Whether to drop the last incomplete batch
        pad_token_id: Padding token ID
        collate_fn: Function that collates a list of examples into a batch
        prefetch_size: Number of batches to prefetch
        use_pjit: Whether to use pjit for data loading (stored but not
            used by this loader yet)
        use_circular_buffer: Whether to load batches on a background
            thread through a bounded queue
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        drop_last: bool = True,
        pad_token_id: int = 0,
        collate_fn: Optional[Callable] = None,
        prefetch_size: int = 2,
        use_pjit: bool = True,
        use_circular_buffer: bool = True,
        seed: Optional[int] = None
    ):
        """
        Initialize the TPU data loader.

        Args:
            dataset: Dataset to draw examples from
            batch_size: Batch size per device
            shuffle: Whether to shuffle the data
            drop_last: Whether to drop the last incomplete batch
            pad_token_id: Padding token ID
            collate_fn: Function that collates a list of examples into a
                batch; defaults to _optimized_pad_batch
            prefetch_size: Number of batches to prefetch
            use_pjit: Whether to use pjit for data loading
            use_circular_buffer: Whether to load batches on a background
                thread
            seed: Random seed for shuffling
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pad_token_id = pad_token_id
        self.prefetch_size = prefetch_size
        self.use_pjit = use_pjit
        self.use_circular_buffer = use_circular_buffer
        self.seed = seed if seed is not None else random.randint(0, 2**32 - 1)

        if collate_fn is None:
            self.collate_fn = lambda examples: self._optimized_pad_batch(examples, pad_token_id)
        else:
            self.collate_fn = collate_fn

        # The global batch is split evenly across all visible devices.
        self.num_devices = jax.device_count()
        print(f"TPUDataLoader: Using {self.num_devices} devices")

        self.global_batch_size = self.batch_size * self.num_devices
        print(f"TPUDataLoader: Global batch size: {self.global_batch_size}")

        self.prefetch_buffer = []
    def _optimized_pad_batch(self, examples: List[Dict[str, np.ndarray]], pad_token_id: int) -> Dict[str, np.ndarray]:
        """
        Padding function tuned for TPU: pads to a multiple of 128.

        Bucketing sequence lengths to multiples of 128 keeps the set of
        distinct batch shapes small, so XLA recompiles far less often.

        Args:
            examples: List of example dicts containing "input_ids"
            pad_token_id: Padding token ID

        Returns:
            Padded batch with freshly built attention_mask and
            position_ids arrays
        """
        max_length = max(example["input_ids"].shape[0] for example in examples)

        # Round up to the next multiple of 128, e.g. 300 -> 384.
        max_length = ((max_length + 127) // 128) * 128

        batch_size = len(examples)
        batch = {
            "input_ids": np.full((batch_size, max_length), pad_token_id, dtype=np.int32),
            "attention_mask": np.zeros((batch_size, max_length), dtype=np.int32),
            "position_ids": np.zeros((batch_size, max_length), dtype=np.int32),
        }

        # Unlike pad_batch, this rebuilds the mask and positions from
        # scratch, so examples only need to provide "input_ids".
        for i, example in enumerate(examples):
            length = example["input_ids"].shape[0]
            batch["input_ids"][i, :length] = example["input_ids"]
            batch["attention_mask"][i, :length] = 1
            batch["position_ids"][i, :length] = np.arange(length, dtype=np.int32)

        return batch
    def __iter__(self) -> Iterator[Dict[str, jnp.ndarray]]:
        """
        Iterate over device-sharded batches for one epoch.

        Yields:
            Batches as dicts of jnp arrays with a leading device axis,
            i.e. shapes of [num_devices, batch_size, ...]
        """
        indices = np.arange(len(self.dataset), dtype=np.int32)

        if self.shuffle:
            # Note: the same seed is reused each epoch, so the shuffled
            # order is identical across epochs.
            rng = np.random.RandomState(self.seed)
            rng.shuffle(indices)

        num_batches = len(indices) // self.global_batch_size
        if not self.drop_last and len(indices) % self.global_batch_size > 0:
            num_batches += 1

        def make_batch(batch_idx: int) -> Optional[Dict[str, jnp.ndarray]]:
            start_idx = batch_idx * self.global_batch_size
            end_idx = min(start_idx + self.global_batch_size, len(indices))
            batch_indices = indices[start_idx:end_idx]

            if len(batch_indices) < self.global_batch_size:
                if self.drop_last:
                    return None
                # Fill the short final batch by cycling through its own
                # indices (np.resize repeats the array), so every device
                # slot is occupied even when few examples remain.
                batch_indices = np.resize(batch_indices, self.global_batch_size)

            examples = [self.dataset[int(i)] for i in batch_indices]
            batch = self.collate_fn(examples)

            attention_mask, padding_mask = create_masks(
                batch["input_ids"],
                self.pad_token_id
            )
            batch["attention_mask"] = attention_mask
            batch["padding_mask"] = padding_mask

            # Add a leading device axis for pmap-style consumption:
            # [global_batch, ...] -> [num_devices, batch_size, ...].
            batch = {
                k: v.reshape(self.num_devices, self.batch_size, *v.shape[1:])
                for k, v in batch.items()
            }
            return {k: jnp.asarray(v) for k, v in batch.items()}

        if self.use_circular_buffer:
            # Produce batches on a daemon thread; the bounded queue acts
            # as the prefetch buffer and None marks the end of the epoch.
            batch_queue = queue.Queue(maxsize=self.prefetch_size)

            def load_batches():
                for batch_idx in range(num_batches):
                    batch = make_batch(batch_idx)
                    if batch is not None:
                        batch_queue.put(batch)
                batch_queue.put(None)

            thread = threading.Thread(target=load_batches, daemon=True)
            thread.start()

            while True:
                batch = batch_queue.get()
                if batch is None:
                    break
                yield batch
        else:
            for batch_idx in range(num_batches):
                batch = make_batch(batch_idx)
                if batch is not None:
                    yield batch
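

# A minimal consumption sketch for TPUDataLoader. The `dataset` argument
# and the pmapped `train_step` mentioned below are illustrative
# assumptions; the point is that every batch already carries a leading
# device axis, ready for jax.pmap-style code.
def _tpu_data_loader_example(dataset: Dataset) -> None:
    loader = TPUDataLoader(dataset, batch_size=8, seed=0)
    for batch in loader:
        # Shapes are [num_devices, per-device batch, padded seq length],
        # with the sequence axis padded to a multiple of 128 by
        # _optimized_pad_batch.
        assert batch["input_ids"].shape[0] == jax.device_count()
        assert batch["input_ids"].shape[1] == 8
        assert batch["input_ids"].shape[2] % 128 == 0
        # A pmapped update would consume the batch directly, e.g.:
        #   loss = jax.pmap(train_step)(replicated_params, batch)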
|
|
|