"""
Data loaders for LLM training.
"""

import random
import numpy as np
import jax
import jax.numpy as jnp
from typing import Dict, List, Optional, Tuple, Callable, Iterator
from data.dataset import Dataset


def pad_batch(
    examples: List[Dict[str, np.ndarray]],
    pad_token_id: int = 0
) -> Dict[str, np.ndarray]:
    """
    Pad batch of examples to the same length.

    Args:
        examples: List of examples
        pad_token_id: Padding token ID

    Returns:
        Padded batch
    """
    # Get maximum length
    max_length = max(example["input_ids"].shape[0] for example in examples)

    # Initialize batch
    batch = {
        "input_ids": np.full((len(examples), max_length), pad_token_id, dtype=np.int32),
        "attention_mask": np.zeros((len(examples), max_length), dtype=np.int32),
        "position_ids": np.zeros((len(examples), max_length), dtype=np.int32),
    }

    # Fill batch
    for i, example in enumerate(examples):
        length = example["input_ids"].shape[0]
        batch["input_ids"][i, :length] = example["input_ids"]
        batch["attention_mask"][i, :length] = example["attention_mask"]
        batch["position_ids"][i, :length] = example["position_ids"]

    return batch
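
# Example (illustrative): given two examples with keys "input_ids",
# "attention_mask" and "position_ids" of lengths 3 and 2,
#   batch = pad_batch([ex_a, ex_b], pad_token_id=0)
#   batch["input_ids"].shape == (2, 3)  # the shorter example is right-padded with 0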


def create_masks(
    input_ids: np.ndarray,
    pad_token_id: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create attention mask and padding mask.

    Args:
        input_ids: Input token IDs [batch_size, seq_len]
        pad_token_id: Padding token ID

    Returns:
        Tuple of (attention_mask, padding_mask)
    """
    # Create padding mask
    padding_mask = (input_ids != pad_token_id).astype(np.int32)

    # Create causal attention mask
    seq_len = input_ids.shape[1]
    causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.int32))

    # Combine padding mask and causal mask: [batch_size, 1, seq_len, seq_len]
    attention_mask = padding_mask[:, None, None, :] * causal_mask[None, None, :, :]

    # Convert to float and apply large negative value to masked positions
    attention_mask = (1.0 - attention_mask.astype(np.float32)) * -1e9

    return attention_mask, padding_mask
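
# Example (illustrative): for two right-padded sequences and pad_token_id=0,
#   ids = np.array([[5, 6, 7, 0], [8, 9, 0, 0]], dtype=np.int32)
#   attn, pad = create_masks(ids)
#   attn.shape == (2, 1, 4, 4)  # additive mask: 0.0 where allowed, -1e9 where masked
#   pad.shape == (2, 4)         # 1 for real tokens, 0 for padding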


class DataLoader:
    """
    Data loader for LLM training.

    Attributes:
        dataset: Dataset
        batch_size: Batch size
        shuffle: Whether to shuffle data
        drop_last: Whether to drop last incomplete batch
        pad_token_id: Padding token ID
        collate_fn: Function to collate examples into batch
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        drop_last: bool = False,
        pad_token_id: int = 0,
        collate_fn: Optional[Callable] = None
    ):
        """
        Initialize data loader.

        Args:
            dataset: Dataset
            batch_size: Batch size
            shuffle: Whether to shuffle data
            drop_last: Whether to drop last incomplete batch
            pad_token_id: Padding token ID
            collate_fn: Function to collate examples into batch
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pad_token_id = pad_token_id

        # Set collate function
        if collate_fn is None:
            self.collate_fn = lambda examples: pad_batch(examples, pad_token_id)
        else:
            self.collate_fn = collate_fn

    def __iter__(self) -> Iterator[Dict[str, np.ndarray]]:
        """
        Iterate over batches.

        Returns:
            Iterator over batches
        """
        # Get indices
        indices = list(range(len(self.dataset)))

        # Shuffle indices if requested
        if self.shuffle:
            random.shuffle(indices)

        # Yield batches
        batch_indices = []
        for idx in indices:
            batch_indices.append(idx)

            if len(batch_indices) == self.batch_size:
                # Get examples
                examples = [self.dataset[i] for i in batch_indices]

                # Collate examples into batch
                batch = self.collate_fn(examples)

                # Create masks
                attention_mask, padding_mask = create_masks(
                    batch["input_ids"],
                    self.pad_token_id
                )

                # Update batch
                batch["attention_mask"] = attention_mask
                batch["padding_mask"] = padding_mask

                # Convert to JAX arrays
                batch = {k: jnp.array(v) for k, v in batch.items()}

                yield batch

                # Reset batch indices
                batch_indices = []

        # Yield last batch if not empty and not dropping last
        if batch_indices and not self.drop_last:
            # Get examples
            examples = [self.dataset[i] for i in batch_indices]

            # Collate examples into batch
            batch = self.collate_fn(examples)

            # Create masks
            attention_mask, padding_mask = create_masks(
                batch["input_ids"],
                self.pad_token_id
            )

            # Update batch
            batch["attention_mask"] = attention_mask
            batch["padding_mask"] = padding_mask

            # Convert to JAX arrays
            batch = {k: jnp.array(v) for k, v in batch.items()}

            yield batch


class TPUDataLoader:
    """
    Data loader optimized for TPU v4-32: per-device batching, sequence padding to
    multiples of 128, and optional background prefetching.

    Attributes:
        dataset: Dataset
        batch_size: Batch size per device
        shuffle: Whether to shuffle data
        drop_last: Whether to drop last incomplete batch
        pad_token_id: Padding token ID
        collate_fn: Function to collate examples into batch
        prefetch_size: Number of batches to prefetch
        use_pjit: Whether to use pjit for data loading (stored but not used by this implementation)
        use_circular_buffer: Whether to prefetch batches on a background thread via a bounded queue
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        drop_last: bool = True,  # Default to True for TPU efficiency
        pad_token_id: int = 0,
        collate_fn: Optional[Callable] = None,
        prefetch_size: int = 2,
        use_pjit: bool = True,
        use_circular_buffer: bool = True,
        seed: Optional[int] = None
    ):
        """
        Initialize data loader optimized for TPU v4-32.

        Args:
            dataset: Dataset
            batch_size: Batch size per device
            shuffle: Whether to shuffle data
            drop_last: Whether to drop last incomplete batch
            pad_token_id: Padding token ID
            collate_fn: Function to collate examples into batch
            prefetch_size: Number of batches to prefetch
            use_pjit: Whether to use pjit for data loading (stored but not used by this implementation)
            use_circular_buffer: Whether to prefetch batches on a background thread via a bounded queue
            seed: Random seed for shuffling
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pad_token_id = pad_token_id
        self.prefetch_size = prefetch_size
        self.use_pjit = use_pjit
        self.use_circular_buffer = use_circular_buffer
        self.seed = seed if seed is not None else random.randint(0, 2**32 - 1)

        # Set collate function with optimized padding
        if collate_fn is None:
            self.collate_fn = lambda examples: self._optimized_pad_batch(examples, pad_token_id)
        else:
            self.collate_fn = collate_fn

        # Get number of devices
        self.num_devices = jax.device_count()
        print(f"TPUDataLoader: Using {self.num_devices} devices")

        # Compute global batch size
        self.global_batch_size = self.batch_size * self.num_devices
        print(f"TPUDataLoader: Global batch size: {self.global_batch_size}")

        # Create prefetch buffer
        self.prefetch_buffer = []

    def _optimized_pad_batch(self, examples: List[Dict[str, np.ndarray]], pad_token_id: int) -> Dict[str, np.ndarray]:
        """
        Optimized padding function for TPU v4-32.

        Args:
            examples: List of examples
            pad_token_id: Padding token ID

        Returns:
            Padded batch
        """
        # Get maximum length
        max_length = max(example["input_ids"].shape[0] for example in examples)

        # Round max_length to multiple of 128 for TPU efficiency
        max_length = ((max_length + 127) // 128) * 128

        # Initialize batch with preallocated arrays
        batch_size = len(examples)
        batch = {
            "input_ids": np.full((batch_size, max_length), pad_token_id, dtype=np.int32),
            "attention_mask": np.zeros((batch_size, max_length), dtype=np.int32),
            "position_ids": np.zeros((batch_size, max_length), dtype=np.int32),
        }

        # Fill batch one example at a time
        for i, example in enumerate(examples):
            length = example["input_ids"].shape[0]
            batch["input_ids"][i, :length] = example["input_ids"]
            batch["attention_mask"][i, :length] = 1  # Simplified mask creation
            batch["position_ids"][i, :length] = np.arange(length, dtype=np.int32)

        return batch

    def __iter__(self) -> Iterator[Dict[str, jnp.ndarray]]:
        """
        Iterate over batches with optimized data loading for TPU v4-32.

        Returns:
            Iterator over batches
        """
        # Get indices
        indices = np.arange(len(self.dataset), dtype=np.int32)

        # Shuffle indices if requested
        if self.shuffle:
            rng = np.random.RandomState(self.seed)
            rng.shuffle(indices)

        # Create batches
        num_batches = len(indices) // self.global_batch_size
        if not self.drop_last and len(indices) % self.global_batch_size > 0:
            num_batches += 1

        # Prefetch batches in background if enabled
        if self.use_circular_buffer:
            import threading
            import queue

            # Create queue for prefetched batches
            batch_queue = queue.Queue(maxsize=self.prefetch_size)

            # Define batch loading function
            def load_batches():
                for batch_idx in range(num_batches):
                    # Get batch indices
                    start_idx = batch_idx * self.global_batch_size
                    end_idx = min(start_idx + self.global_batch_size, len(indices))
                    batch_indices = indices[start_idx:end_idx]

                    # Pad or skip incomplete batches
                    if len(batch_indices) < self.global_batch_size:
                        if self.drop_last:
                            # Skip the incomplete final batch
                            continue
                        # Repeat indices cyclically so the global batch is full
                        batch_indices = np.resize(batch_indices, self.global_batch_size)

                    # Get examples
                    examples = [self.dataset[int(i)] for i in batch_indices]

                    # Collate examples into batch
                    batch = self.collate_fn(examples)

                    # Create masks
                    attention_mask, padding_mask = create_masks(
                        batch["input_ids"],
                        self.pad_token_id
                    )

                    # Update batch
                    batch["attention_mask"] = attention_mask
                    batch["padding_mask"] = padding_mask

                    # Reshape batch for devices
                    batch = {
                        k: v.reshape(self.num_devices, self.batch_size, *v.shape[1:])
                        for k, v in batch.items()
                    }

                    # Convert to JAX arrays (dtype carries over from the NumPy arrays)
                    batch = {k: jnp.asarray(v) for k, v in batch.items()}

                    # Add batch to queue
                    batch_queue.put(batch)

                # Signal end of batches
                batch_queue.put(None)

            # Start batch loading thread
            thread = threading.Thread(target=load_batches)
            thread.daemon = True
            thread.start()

            # Yield batches from queue
            while True:
                batch = batch_queue.get()
                if batch is None:
                    break
                yield batch
        else:
            # Standard batch loading
            for batch_idx in range(num_batches):
                # Get batch indices
                start_idx = batch_idx * self.global_batch_size
                end_idx = min(start_idx + self.global_batch_size, len(indices))
                batch_indices = indices[start_idx:end_idx]

                # Pad or skip incomplete batches
                if len(batch_indices) < self.global_batch_size:
                    if self.drop_last:
                        # Skip the incomplete final batch
                        continue
                    # Repeat indices cyclically so the global batch is full
                    batch_indices = np.resize(batch_indices, self.global_batch_size)

                # Get examples
                examples = [self.dataset[int(i)] for i in batch_indices]

                # Collate examples into batch
                batch = self.collate_fn(examples)

                # Create masks
                attention_mask, padding_mask = create_masks(
                    batch["input_ids"],
                    self.pad_token_id
                )

                # Update batch
                batch["attention_mask"] = attention_mask
                batch["padding_mask"] = padding_mask

                # Reshape batch for devices
                batch = {
                    k: v.reshape(self.num_devices, self.batch_size, *v.shape[1:])
                    for k, v in batch.items()
                }

                # Convert to JAX arrays (dtype carries over from the NumPy arrays)
                batch = {k: jnp.asarray(v) for k, v in batch.items()}

                yield batch
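

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). ToyDataset is a hypothetical
    # stand-in for data.dataset.Dataset: DataLoader only needs __len__ and
    # __getitem__ returning "input_ids", "attention_mask" and "position_ids".
    class ToyDataset:
        def __init__(self, num_examples: int = 8, max_len: int = 10):
            rng = np.random.RandomState(0)
            self.examples = []
            for _ in range(num_examples):
                length = int(rng.randint(2, max_len + 1))
                self.examples.append({
                    "input_ids": rng.randint(1, 100, size=length).astype(np.int32),
                    "attention_mask": np.ones(length, dtype=np.int32),
                    "position_ids": np.arange(length, dtype=np.int32),
                })

        def __len__(self) -> int:
            return len(self.examples)

        def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
            return self.examples[idx]

    # TPUDataLoader exposes the same iteration interface but adds a leading
    # per-device axis; DataLoader is enough for a quick shape check.
    loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True)
    for batch in loader:
        print({k: v.shape for k, v in batch.items()})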