File size: 14,524 Bytes
f24563f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 |
"""
Data loaders for LLM training.
"""
import random
import numpy as np
import jax
import jax.numpy as jnp
from typing import Dict, List, Optional, Tuple, Union, Callable, Iterator, Any
from data.dataset import Dataset
def pad_batch(
    examples: List[Dict[str, np.ndarray]],
    pad_token_id: int = 0
) -> Dict[str, np.ndarray]:
    """
    Pad a batch of examples to the length of the longest one.

    Args:
        examples: List of examples. Each must have a 1-D ``input_ids`` array and
            may have ``attention_mask`` and ``position_ids`` arrays of the same
            length. When those are missing, the mask defaults to ones and the
            positions default to ``0..length-1``.
        pad_token_id: Token ID written into padded ``input_ids`` slots.

    Returns:
        Dict with int32 arrays ``input_ids``, ``attention_mask`` and
        ``position_ids``, each of shape ``[len(examples), max_length]``.
        Padded slots hold ``pad_token_id`` in ``input_ids`` and 0 elsewhere.
        An empty ``examples`` list yields arrays of shape ``(0, 0)``.
    """
    num_examples = len(examples)
    # default=0 lets an empty list produce an empty batch instead of max() raising ValueError.
    max_length = max((example["input_ids"].shape[0] for example in examples), default=0)

    batch = {
        "input_ids": np.full((num_examples, max_length), pad_token_id, dtype=np.int32),
        "attention_mask": np.zeros((num_examples, max_length), dtype=np.int32),
        "position_ids": np.zeros((num_examples, max_length), dtype=np.int32),
    }

    for i, example in enumerate(examples):
        length = example["input_ids"].shape[0]
        batch["input_ids"][i, :length] = example["input_ids"]
        # Fall back to the same defaults TPUDataLoader's padding uses when keys are absent.
        attention = example.get("attention_mask")
        batch["attention_mask"][i, :length] = 1 if attention is None else attention
        positions = example.get("position_ids")
        batch["position_ids"][i, :length] = (
            np.arange(length, dtype=np.int32) if positions is None else positions
        )
    return batch
def create_masks(
    input_ids: np.ndarray,
    pad_token_id: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build an additive causal+padding attention mask and a 0/1 padding mask.

    Args:
        input_ids: Input token IDs of shape ``[batch_size, seq_len]``.
        pad_token_id: Token ID treated as padding. Every position holding this
            ID is masked, including any real token that happens to share it.

    Returns:
        Tuple ``(attention_mask, padding_mask)``:
            attention_mask: float32 array of shape
                ``[batch_size, 1, seq_len, seq_len]``. It is 0.0 where query
                ``i`` may attend to key ``j`` (``j <= i`` and key ``j`` is not
                padding) and -1e9 elsewhere, so it can be added to
                attention logits.
            padding_mask: int32 array of shape ``[batch_size, seq_len]``,
                1 for real tokens and 0 for padding.
    """
    padding_mask = (input_ids != pad_token_id).astype(np.int32)

    seq_len = input_ids.shape[1]
    # Lower-triangular: query i may only see keys 0..i.
    causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.int32))

    # Broadcast the key-side padding mask against the causal mask -> [B, 1, L, L].
    allowed = padding_mask[:, None, None, :] * causal_mask[None, None, :, :]

    # Turn the 1/0 "allowed" mask into an additive bias: 0 keeps a logit, -1e9 suppresses it.
    attention_mask = (1.0 - allowed.astype(np.float32)) * -1e9
    return attention_mask, padding_mask
class DataLoader:
    """
    Data loader for LLM training.

    Groups dataset indices into batches of ``batch_size`` and collates each
    batch with ``collate_fn``. The collated ``attention_mask`` is then replaced
    with the additive mask from ``create_masks``, and a ``padding_mask`` entry
    is added. Every field is converted to a JAX array.

    Attributes:
        dataset: Dataset (indexed by integer position and sized with ``len()``)
        batch_size: Batch size
        shuffle: Whether to shuffle data (uses the global ``random`` module state)
        drop_last: Whether to drop last incomplete batch
        pad_token_id: Padding token ID
        collate_fn: Function to collate examples into batch
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        drop_last: bool = False,
        pad_token_id: int = 0,
        collate_fn: Optional[Callable] = None
    ):
        """
        Initialize data loader.

        Args:
            dataset: Dataset
            batch_size: Batch size
            shuffle: Whether to shuffle data
            drop_last: Whether to drop last incomplete batch
            pad_token_id: Padding token ID
            collate_fn: Function to collate examples into batch; defaults to
                ``pad_batch`` with ``pad_token_id``
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pad_token_id = pad_token_id

        if collate_fn is None:
            self.collate_fn = lambda examples: pad_batch(examples, pad_token_id)
        else:
            self.collate_fn = collate_fn

    def _make_batch(self, batch_indices: List[int]) -> Dict[str, jnp.ndarray]:
        """Fetch, collate, mask and convert to JAX the examples at ``batch_indices``."""
        examples = [self.dataset[i] for i in batch_indices]
        batch = self.collate_fn(examples)

        attention_mask, padding_mask = create_masks(
            batch["input_ids"],
            self.pad_token_id
        )
        # The collated attention_mask is overwritten by the 4-D additive mask.
        batch["attention_mask"] = attention_mask
        batch["padding_mask"] = padding_mask

        return {k: jnp.array(v) for k, v in batch.items()}

    def __iter__(self) -> Iterator[Dict[str, jnp.ndarray]]:
        """
        Iterate over batches.

        Returns:
            Iterator over dicts of JAX arrays. The final short batch is yielded
            unless ``drop_last`` is set.
        """
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(indices)

        batch_indices: List[int] = []
        for idx in indices:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield self._make_batch(batch_indices)
                batch_indices = []

        if batch_indices and not self.drop_last:
            yield self._make_batch(batch_indices)
class TPUDataLoader:
    """
    Data loader intended for TPU v4-32, with optional background prefetching.

    Each yielded batch holds ``batch_size * jax.device_count()`` examples.
    Every field is reshaped so its leading axes are
    ``[num_devices, batch_size]``. When ``drop_last`` is False, the final short
    batch is topped up by cycling its own indices, so some examples appear more
    than once in that batch.

    Attributes:
        dataset: Dataset
        batch_size: Batch size per device
        shuffle: Whether to shuffle data
        drop_last: Whether to drop last incomplete batch
        pad_token_id: Padding token ID
        collate_fn: Function to collate examples into batch
        prefetch_size: Maximum number of batches queued ahead by the background thread
        use_pjit: Stored for API compatibility; not read by this class
        use_circular_buffer: Whether to build batches in a background thread
            that feeds a bounded queue
        seed: Base shuffle seed. The shuffle for epoch ``e`` uses ``(seed + e) % 2**32``.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        drop_last: bool = True,  # Default to True for TPU efficiency
        pad_token_id: int = 0,
        collate_fn: Optional[Callable] = None,
        prefetch_size: int = 2,
        use_pjit: bool = True,
        use_circular_buffer: bool = True,
        seed: Optional[int] = None
    ):
        """
        Initialize data loader optimized for TPU v4-32.

        Args:
            dataset: Dataset
            batch_size: Batch size per device
            shuffle: Whether to shuffle data
            drop_last: Whether to drop last incomplete batch
            pad_token_id: Padding token ID
            collate_fn: Function to collate examples into batch; defaults to
                ``_optimized_pad_batch``
            prefetch_size: Number of batches to prefetch
            use_pjit: Whether to use pjit for data loading (currently unused)
            use_circular_buffer: Whether to load batches in a background thread
            seed: Random seed for shuffling; a random one is drawn when None
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pad_token_id = pad_token_id
        self.prefetch_size = prefetch_size
        self.use_pjit = use_pjit
        self.use_circular_buffer = use_circular_buffer
        self.seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        # Epoch counter mixed into the shuffle seed so each pass gets a new order.
        self._epoch = 0

        if collate_fn is None:
            self.collate_fn = lambda examples: self._optimized_pad_batch(examples, pad_token_id)
        else:
            self.collate_fn = collate_fn

        self.num_devices = jax.device_count()
        print(f"TPUDataLoader: Using {self.num_devices} devices")

        self.global_batch_size = self.batch_size * self.num_devices
        print(f"TPUDataLoader: Global batch size: {self.global_batch_size}")

        # Kept for backward compatibility; not used by this class.
        self.prefetch_buffer = []

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch used to derive the next shuffle order (e.g. when resuming training)."""
        self._epoch = epoch

    def _optimized_pad_batch(self, examples: List[Dict[str, np.ndarray]], pad_token_id: int) -> Dict[str, np.ndarray]:
        """
        Pad examples to a common length rounded up to a multiple of 128.

        Args:
            examples: List of examples, each with a 1-D ``input_ids`` array
            pad_token_id: Padding token ID

        Returns:
            Dict of int32 arrays ``input_ids``, ``attention_mask`` (1 on real
            tokens) and ``position_ids`` (``0..length-1`` on real tokens), each
            of shape ``[len(examples), padded_length]``.
        """
        max_length = max(example["input_ids"].shape[0] for example in examples)
        # Round up to a multiple of 128 so sequence shapes stay in a small, TPU-friendly set.
        max_length = ((max_length + 127) // 128) * 128

        batch_size = len(examples)
        batch = {
            "input_ids": np.full((batch_size, max_length), pad_token_id, dtype=np.int32),
            "attention_mask": np.zeros((batch_size, max_length), dtype=np.int32),
            "position_ids": np.zeros((batch_size, max_length), dtype=np.int32),
        }

        for i, example in enumerate(examples):
            length = example["input_ids"].shape[0]
            batch["input_ids"][i, :length] = example["input_ids"]
            batch["attention_mask"][i, :length] = 1
            batch["position_ids"][i, :length] = np.arange(length, dtype=np.int32)
        return batch

    def _epoch_indices(self) -> np.ndarray:
        """Return this epoch's (optionally shuffled) dataset indices and advance the epoch counter."""
        indices = np.arange(len(self.dataset), dtype=np.int32)
        if self.shuffle:
            # Offset the seed per epoch: reusing one seed would repeat the same order every epoch.
            rng = np.random.RandomState((self.seed + self._epoch) % 2**32)
            rng.shuffle(indices)
        self._epoch += 1
        return indices

    def _iter_batch_indices(self, indices: np.ndarray) -> Iterator[np.ndarray]:
        """Yield chunks of exactly ``global_batch_size`` indices, padding or dropping the last chunk."""
        global_batch_size = self.global_batch_size
        for start in range(0, len(indices), global_batch_size):
            chunk = indices[start:start + global_batch_size]
            if len(chunk) < global_batch_size:
                if self.drop_last:
                    break  # only the final chunk can be short
                # np.resize repeats the chunk cyclically. Unlike chunk[:shortfall], it reaches
                # full size even when the shortfall exceeds the chunk's length.
                chunk = np.resize(chunk, global_batch_size)
            yield chunk

    def _build_batch(self, batch_indices: np.ndarray) -> Dict[str, jnp.ndarray]:
        """Fetch, collate, mask and device-reshape one global batch, returning JAX arrays."""
        examples = [self.dataset[int(i)] for i in batch_indices]
        batch = self.collate_fn(examples)

        attention_mask, padding_mask = create_masks(
            batch["input_ids"],
            self.pad_token_id
        )
        # The collated attention_mask is overwritten by the 4-D additive mask.
        batch["attention_mask"] = attention_mask
        batch["padding_mask"] = padding_mask

        # Split the leading global-batch axis into [num_devices, per-device batch].
        batch = {
            k: v.reshape(self.num_devices, self.batch_size, *v.shape[1:])
            for k, v in batch.items()
        }
        return {k: jnp.asarray(v, dtype=jnp.dtype(v.dtype)) for k, v in batch.items()}

    def _prefetch(self, batches: Iterator[Dict[str, jnp.ndarray]]) -> Iterator[Dict[str, jnp.ndarray]]:
        """
        Consume ``batches`` in a daemon thread and yield them through a bounded queue.

        An exception raised while producing is re-raised here in the consuming
        thread instead of leaving the consumer waiting forever. Closing this
        generator early (e.g. ``break`` in the training loop) signals the
        producer to stop rather than block on a full queue.
        """
        import queue
        import threading

        batch_queue = queue.Queue(maxsize=self.prefetch_size)
        stop = threading.Event()

        def put(item) -> bool:
            # Poll with a timeout so a producer stuck on a full queue still notices `stop`.
            while not stop.is_set():
                try:
                    batch_queue.put(item, timeout=0.1)
                    return True
                except queue.Full:
                    pass
            return False

        def produce() -> None:
            try:
                for batch in batches:
                    if not put(("batch", batch)):
                        return
            except BaseException as exc:  # forwarded to the consumer below, not swallowed
                put(("error", exc))
            else:
                put(("done", None))

        thread = threading.Thread(target=produce, daemon=True)
        thread.start()
        try:
            while True:
                kind, payload = batch_queue.get()
                if kind == "done":
                    return
                if kind == "error":
                    raise payload
                yield payload
        finally:
            stop.set()

    def __iter__(self) -> Iterator[Dict[str, jnp.ndarray]]:
        """
        Iterate over device-reshaped global batches for one epoch.

        Returns:
            Iterator over dicts of JAX arrays with leading axes
            ``[num_devices, batch_size]``.
        """
        indices = self._epoch_indices()
        batches = (self._build_batch(chunk) for chunk in self._iter_batch_indices(indices))
        if self.use_circular_buffer:
            yield from self._prefetch(batches)
        else:
            yield from batches
|