"""
Main training script for LLM training on TPU v4-32.
Optimized for 128K token context length and 30-day training.
"""

import os
import sys
import time
import json
import argparse
import logging
import threading
import queue
from typing import Dict, Any, Optional, List, Tuple

import jax
import jax.numpy as jnp
import flax
import tensorflow as tf
import numpy as np
import sentencepiece as spm
from functools import partial

logger = logging.getLogger(__name__)

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    logger.warning("Weights & Biases not available. WandB logging will be disabled.")
    WANDB_AVAILABLE = False

from model.llm import LLM, LLMConfig
from data.tokenizer import SentencePieceTokenizer
from data.dataset import TextDataset, load_jsonl_dataset, StreamingDataset
from data.dataloader import TPUDataLoader
from training.trainer import Trainer, TrainingState, TrainingConfig as TrainerConfig
from training.optimizer import create_adamw_optimizer, create_lion_optimizer
from training.scheduler import create_linear_warmup_cosine_decay_schedule
from parallelism.data_parallel import DataParallel
from parallelism.tensor_parallel import TensorParallel
from config import Config, TrainingConfig, get_model_config
from utils.checkpoint import save_checkpoint, load_checkpoint
from utils.logging import setup_logger, log_metrics, create_summary_writer, log_metrics_to_tensorboard


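# Example invocation (illustrative only; the script path, bucket, and file names
# below are placeholders, not files shipped with this repository):
#
#   python train.py \
#       --model_size 7b \
#       --train_file gs://my-bucket/train.jsonl \
#       --tokenizer_file gs://my-bucket/tokenizer.model \
#       --batch_size 32 \
#       --max_seq_length 131072 \
#       --output_dir output/7b-128k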
def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Train LLM on TPU v4-32")

    # Model
    parser.add_argument("--model_size", type=str, default="7b", choices=["7b", "13b", "70b", "175b", "600b"],
                        help="Model size")

    # Optimization
    parser.add_argument("--learning_rate", type=float, default=3e-4,
                        help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of steps to accumulate gradients")
    parser.add_argument("--max_steps", type=int, default=100000,
                        help="Maximum number of training steps")
    parser.add_argument("--warmup_steps", type=int, default=1000,
                        help="Number of warmup steps")

    # Data
    parser.add_argument("--train_file", type=str, required=True,
                        help="Path to training file or HuggingFace dataset name")
    parser.add_argument("--eval_file", type=str, default="",
                        help="Path to evaluation file or HuggingFace dataset name")
    parser.add_argument("--tokenizer_file", type=str, required=True,
                        help="Path to tokenizer file")
    parser.add_argument("--max_seq_length", type=int, default=131072,
                        help="Maximum sequence length (default: 128K tokens)")
    parser.add_argument("--use_streaming", action="store_true", default=True,
                        help="Use streaming dataset for efficient training")
    parser.add_argument("--streaming_buffer_size", type=int, default=10000,
                        help="Buffer size for streaming dataset")
    parser.add_argument("--text_column", type=str, default="text",
                        help="Name of text column in dataset")
    parser.add_argument("--preprocessing_num_workers", type=int, default=16,
                        help="Number of workers for dataset preprocessing")

    # Parallelism
    parser.add_argument("--parallelism_type", type=str, default="data", choices=["data", "tensor"],
                        help="Type of parallelism")
    parser.add_argument("--tensor_parallel_size", type=int, default=8,
                        help="Number of tensor parallel devices")

    # Memory and speed optimizations
    parser.add_argument("--use_flash_attention", action="store_true", default=True,
                        help="Use flash attention for efficiency")
    parser.add_argument("--use_gradient_checkpointing", action="store_true", default=True,
                        help="Use gradient checkpointing to save memory")

    # Long-context (RoPE) settings
    parser.add_argument("--use_rope_scaling", action="store_true", default=True,
                        help="Use RoPE scaling for longer contexts")
    parser.add_argument("--rope_scaling_factor", type=float, default=0.5,
                        help="Scaling factor for RoPE frequencies")

    # Reasoning layers
    parser.add_argument("--use_reasoning_layer", action="store_true", default=True,
                        help="Use additional reasoning layers")
    parser.add_argument("--num_reasoning_layers", type=int, default=None,
                        help="Number of additional reasoning layers (overrides model config)")

    # Output and checkpointing
    parser.add_argument("--output_dir", type=str, default="output",
                        help="Output directory")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Number of steps between logging")
    parser.add_argument("--save_steps", type=int, default=1000,
                        help="Number of steps between checkpoints")
    parser.add_argument("--eval_steps", type=int, default=1000,
                        help="Number of steps between evaluations")

    # Monitoring
    parser.add_argument("--use_wandb", action="store_true", default=True,
                        help="Use Weights & Biases for logging")
    parser.add_argument("--wandb_project", type=str, default="llm-training",
                        help="Weights & Biases project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="Weights & Biases entity name")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Weights & Biases run name")
    parser.add_argument("--log_memory_usage", action="store_true", default=True,
                        help="Log memory usage during training")
    parser.add_argument("--profile_steps", type=int, default=100,
                        help="Number of steps between profiling")

    # Miscellaneous
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--resume_from_checkpoint", type=str, default="",
                        help="Path to checkpoint to resume from")

    return parser.parse_args()


def create_config(args):
    """Create training configuration."""
    model_config = get_model_config(args.model_size)

    # Override the number of reasoning layers if given on the command line.
    if args.num_reasoning_layers is not None:
        model_config.num_reasoning_layers = args.num_reasoning_layers

    # Propagate command-line switches into the model config.
    model_config.use_flash_attention = args.use_flash_attention
    model_config.use_gradient_checkpointing = args.use_gradient_checkpointing
    model_config.use_rope_scaling = args.use_rope_scaling
    model_config.rope_scaling_factor = args.rope_scaling_factor
    model_config.use_reasoning_layer = args.use_reasoning_layer

    config = TrainingConfig(
        output_dir=args.output_dir,
        model_config=model_config,

        # Optimization
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,

        # Data
        train_file=args.train_file,
        eval_file=args.eval_file,
        tokenizer_file=args.tokenizer_file,
        max_seq_length=args.max_seq_length,

        # Parallelism
        parallelism_type=args.parallelism_type,
        tensor_parallel_size=args.tensor_parallel_size,

        # Memory and speed optimizations
        use_flash_attention=args.use_flash_attention,
        use_gradient_checkpointing=args.use_gradient_checkpointing,

        # Long-context (RoPE) settings
        use_rope_scaling=args.use_rope_scaling,
        rope_scaling_factor=args.rope_scaling_factor,

        # Reasoning layers
        use_reasoning_layer=args.use_reasoning_layer,
        num_reasoning_layers=args.num_reasoning_layers if args.num_reasoning_layers is not None else model_config.num_reasoning_layers,
        reasoning_intermediate_size=model_config.reasoning_intermediate_size,

        # Logging and checkpointing
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,

        seed=args.seed
    )

    return config


def setup_parallelism(config):
    """Set up parallelism."""
    if config.parallelism_type == "data":
        return DataParallel()
    elif config.parallelism_type == "tensor":
        return TensorParallel(num_tp=config.tensor_parallel_size)
    else:
        raise ValueError(f"Parallelism type {config.parallelism_type} not supported")


def create_model(config):
    """Create model."""
    return LLM(config.model_config)


def create_optimizer(config, num_train_steps):
    """Create optimizer."""
    # Learning rate schedule: linear warmup followed by cosine decay.
    lr_schedule = create_linear_warmup_cosine_decay_schedule(
        learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps,
        decay_steps=num_train_steps - config.warmup_steps,
        final_learning_rate_factor=0.1
    )

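    # Sketch of the schedule above (the exact form is whatever
    # create_linear_warmup_cosine_decay_schedule implements in training/scheduler.py):
    #   step <  warmup_steps: lr rises linearly from 0 to the peak learning rate
    #   step >= warmup_steps: lr follows a half-cosine from the peak down to
    #                         final_learning_rate_factor * peak over decay_steps
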
    if config.optimizer == "adamw":
        return create_adamw_optimizer(
            learning_rate=lr_schedule,
            weight_decay=config.weight_decay,
            b1=config.adam_beta1,
            b2=config.adam_beta2,
            eps=config.adam_epsilon
        )
    elif config.optimizer == "lion":
        return create_lion_optimizer(
            learning_rate=lr_schedule,
            weight_decay=config.weight_decay,
            b1=config.adam_beta1,
            b2=config.adam_beta2
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer} not supported")


def create_train_state(config, model, optimizer, rng):
    """Create training state."""
    # A single dummy token is enough to trace the parameter shapes.
    dummy_input = jnp.ones((1, 1), dtype=jnp.int32)

    # Initialize parameters, keeping a separate RNG stream for dropout.
    params_rng, dropout_rng = jax.random.split(rng)
    params = model.init(params_rng, dummy_input)

    return TrainingState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        dropout_rng=dropout_rng,
        loss_scale=1.0
    )


def load_tokenizer(config):
    """Load tokenizer."""
    return SentencePieceTokenizer(config.tokenizer_file)


def load_dataset(config, tokenizer):
    """Load dataset with streaming support for efficient training."""
    if config.use_streaming:
        logger.info(f"Loading streaming dataset from {config.train_file}")
        train_dataset = StreamingDataset(
            tokenizer=tokenizer,
            dataset_path=config.train_file,
            max_seq_length=config.max_seq_length,
            streaming=True,
            buffer_size=config.streaming_buffer_size,
            seed=config.seed,
            text_column=config.text_column,
            preprocessing_num_workers=config.preprocessing_num_workers
        )
        logger.info("Streaming dataset loaded successfully")
    else:
        logger.info(f"Loading standard dataset from {config.train_file}")
        train_dataset = load_jsonl_dataset(
            file_path=config.train_file,
            tokenizer=tokenizer,
            max_length=config.max_seq_length
        )
        logger.info(f"Dataset loaded with {len(train_dataset)} examples")

    # Load the evaluation dataset if one was provided.
    eval_dataset = None
    if config.eval_file:
        if config.use_streaming:
            logger.info(f"Loading streaming evaluation dataset from {config.eval_file}")
            eval_dataset = StreamingDataset(
                tokenizer=tokenizer,
                dataset_path=config.eval_file,
                max_seq_length=config.max_seq_length,
                streaming=False,  # load the evaluation split eagerly
                buffer_size=config.streaming_buffer_size,
                seed=config.seed,
                text_column=config.text_column,
                preprocessing_num_workers=config.preprocessing_num_workers
            )
            logger.info("Streaming evaluation dataset loaded successfully")
        else:
            logger.info(f"Loading standard evaluation dataset from {config.eval_file}")
            eval_dataset = load_jsonl_dataset(
                file_path=config.eval_file,
                tokenizer=tokenizer,
                max_length=config.max_seq_length
            )
            logger.info(f"Evaluation dataset loaded with {len(eval_dataset)} examples")

    return train_dataset, eval_dataset


def create_data_loaders(config, train_dataset, eval_dataset, tokenizer):
    """Create data loaders."""
    train_loader = TPUDataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True,
        pad_token_id=tokenizer.pad_token_id
    )

    eval_loader = None
    if eval_dataset is not None:
        eval_loader = TPUDataLoader(
            dataset=eval_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            drop_last=False,
            pad_token_id=tokenizer.pad_token_id
        )

    return train_loader, eval_loader


def main():
    """Main function optimized for TPU v4-32."""
    args = parse_args()

    # Report the TPU topology that JAX sees.
    print("TPU Configuration:")
    print(f"Number of TPU devices: {jax.device_count()}")
    print(f"TPU devices: {jax.devices()}")
    print(f"JAX process index: {jax.process_index()}")
    print(f"JAX process count: {jax.process_count()}")
    print(f"JAX local devices: {jax.local_devices()}")
    print(f"JAX local device count: {jax.local_device_count()}")

    config = create_config(args)

    os.makedirs(config.output_dir, exist_ok=True)

    logger = setup_logger(
        name="tpu_train",
        log_file=os.path.join(config.output_dir, "train.log")
    )

    logger.info(f"Configuration: {config}")

    # Initialize Weights & Biases if requested and available.
    if args.use_wandb and WANDB_AVAILABLE:
        logger.info("Initializing Weights & Biases")
        wandb_run_name = args.wandb_run_name or f"{args.model_size}-{time.strftime('%Y%m%d-%H%M%S')}"
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=wandb_run_name,
            config={
                "model_size": args.model_size,
                "learning_rate": args.learning_rate,
                "batch_size": args.batch_size,
                "gradient_accumulation_steps": args.gradient_accumulation_steps,
                "max_steps": args.max_steps,
                "warmup_steps": args.warmup_steps,
                "max_seq_length": args.max_seq_length,
                "parallelism_type": args.parallelism_type,
                "tensor_parallel_size": args.tensor_parallel_size,
                "use_flash_attention": args.use_flash_attention,
                "use_gradient_checkpointing": args.use_gradient_checkpointing,
                "use_rope_scaling": args.use_rope_scaling,
                "rope_scaling_factor": args.rope_scaling_factor,
                "use_reasoning_layer": args.use_reasoning_layer,
                "num_reasoning_layers": args.num_reasoning_layers,
                "use_streaming": args.use_streaming,
                "streaming_buffer_size": args.streaming_buffer_size,
                "text_column": args.text_column,
                "preprocessing_num_workers": args.preprocessing_num_workers,
                "seed": args.seed,
            }
        )
        logger.info(f"Weights & Biases initialized with run name: {wandb_run_name}")
    elif args.use_wandb and not WANDB_AVAILABLE:
        logger.warning("Weights & Biases not available. Install wandb package to enable logging.")
    else:
        logger.info("Weights & Biases logging disabled.")

						    logger.info(f"Training on TPU v4-32 with {jax.device_count()} devices") | 
					
					
						
						| 
							 | 
						    logger.info(f"Model size: {args.model_size} ({config.model_config.hidden_size} hidden size, " | 
					
					
						
						| 
							 | 
						               f"{config.model_config.num_hidden_layers} layers)") | 
					
					
						
						| 
							 | 
						    logger.info(f"Max sequence length: {args.max_seq_length} tokens") | 
					
					
						
						| 
							 | 
						    logger.info(f"Batch size: {args.batch_size} per device, {args.batch_size * jax.device_count()} global") | 
					
					
						
						| 
							 | 
						    logger.info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}") | 
					
					
						
						| 
							 | 
						    logger.info(f"Effective batch size: {args.batch_size * jax.device_count() * args.gradient_accumulation_steps}") | 
					
					
						
						| 
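    # For example, with 32 visible devices, a per-device batch of 32 and no
    # gradient accumulation, the global batch is 1024 sequences, or about 134M
    # tokens per step at the 131072-token maximum sequence length (illustrative
    # numbers; the real device count depends on the pod slice).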
    logger.info(f"Learning rate: {args.learning_rate}")
    logger.info(f"Warmup steps: {args.warmup_steps}")
    logger.info(f"Max steps: {args.max_steps}")
    logger.info(f"Parallelism type: {args.parallelism_type}")
    logger.info(f"Tensor parallel size: {args.tensor_parallel_size}")
    logger.info(f"Using streaming dataset: {args.use_streaming}")
    logger.info(f"Using flash attention: {args.use_flash_attention}")
    logger.info(f"Using gradient checkpointing: {args.use_gradient_checkpointing}")
    logger.info(f"Using RoPE scaling: {args.use_rope_scaling}")
    logger.info(f"RoPE scaling factor: {args.rope_scaling_factor}")
    logger.info(f"Using reasoning layer: {args.use_reasoning_layer}")
    logger.info(f"Number of reasoning layers: {config.model_config.num_reasoning_layers}")
    logger.info(f"Random seed: {args.seed}")
    logger.info(f"Output directory: {args.output_dir}")
    logger.info(f"Logging steps: {args.logging_steps}")
    logger.info(f"Save steps: {args.save_steps}")
    logger.info(f"Eval steps: {args.eval_steps}")
    logger.info(f"Profile steps: {args.profile_steps}")
    logger.info(f"Using Weights & Biases: {args.use_wandb and WANDB_AVAILABLE}")
    logger.info(f"Logging memory usage: {args.log_memory_usage}")

    # Rough parameter count: embeddings + transformer layers + optional
    # reasoning layers + final norm + output projection.
    param_count = (
        # Token embeddings
        config.model_config.vocab_size * config.model_config.hidden_size +
        # Transformer layers: attention (q, k, v, o), MLP, and layer norms
        config.model_config.num_hidden_layers * (
            4 * config.model_config.hidden_size * config.model_config.hidden_size +
            2 * config.model_config.hidden_size * config.model_config.intermediate_size +
            4 * config.model_config.hidden_size
        ) +
        # Optional reasoning layers
        (config.model_config.num_reasoning_layers if config.model_config.use_reasoning_layer else 0) * (
            4 * config.model_config.hidden_size * config.model_config.hidden_size +
            2 * config.model_config.hidden_size * config.model_config.reasoning_intermediate_size +
            4 * config.model_config.hidden_size
        ) +
        # Final layer norm
        config.model_config.hidden_size +
        # Output projection
        config.model_config.hidden_size * config.model_config.vocab_size
    )

    logger.info(f"Approximate parameter count: {param_count / 1e9:.2f} billion parameters")

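    # Worked example of the estimate above with illustrative values only
    # (hidden_size=4096, num_hidden_layers=32, intermediate_size=11008,
    # vocab_size=32000, reasoning layers disabled; these are not the actual
    # values defined in config.py):
    #   embeddings:  32000 * 4096                       ~= 0.13B
    #   per layer:   4*4096^2 + 2*4096*11008 + 4*4096   ~= 0.157B  (x32 ~= 5.03B)
    #   output head: 4096 * 32000                       ~= 0.13B
    #   total:       roughly 5.3B parameters
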
    # Rough memory estimate for parameters, optimizer state, and activations.
    bytes_per_param = 2 if config.dtype == jnp.bfloat16 else 4  # bfloat16 vs float32
    model_size_gb = param_count * bytes_per_param / 1e9
    optimizer_size_gb = model_size_gb * 2  # optimizer keeps roughly two extra copies (e.g. Adam moments)
    activation_size_gb = model_size_gb * 0.2  # rough activation estimate with checkpointing
    total_memory_gb = model_size_gb + optimizer_size_gb + activation_size_gb

    logger.info("Estimated memory requirements:")
    logger.info(f"  Model parameters: {model_size_gb:.2f} GB")
    logger.info(f"  Optimizer states: {optimizer_size_gb:.2f} GB")
    logger.info(f"  Activations: {activation_size_gb:.2f} GB")
    logger.info(f"  Total: {total_memory_gb:.2f} GB")

    # TPU v4 provides 32 GB of HBM per device.
    tpu_memory_gb = 32 * jax.device_count()
    logger.info(f"Available TPU memory: {tpu_memory_gb:.2f} GB")
    if total_memory_gb > tpu_memory_gb * 0.9:  # keep ~10% headroom
        logger.warning(f"Memory requirements ({total_memory_gb:.2f} GB) may exceed available TPU memory ({tpu_memory_gb:.2f} GB)")
        logger.warning("Consider enabling gradient checkpointing and using a smaller batch size")

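    # Continuing the illustrative 5.3B-parameter example (assumed values, not a
    # measured configuration): bfloat16 parameters ~= 10.6 GB, optimizer state
    # ~= 21.2 GB, activations ~= 2.1 GB, i.e. roughly 34 GB in total across the pod.
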
    # Seed the PRNG for reproducibility.
    rng = jax.random.PRNGKey(config.seed)

    start_time = time.time()

    # Set up the parallelism strategy (data or tensor parallel).
    parallel = setup_parallelism(config)

    model = create_model(config)
    logger.info(f"Model created in {time.time() - start_time:.2f} seconds")

    optimizer = create_optimizer(config, config.max_steps)

    state_start_time = time.time()
    state = create_train_state(config, model, optimizer, rng)
    logger.info(f"Training state created in {time.time() - state_start_time:.2f} seconds")

    # Shard the parameters across devices according to the parallelism strategy.
    shard_start_time = time.time()
    state = state.replace(params=parallel.shard_params(state.params))
    logger.info(f"Parameters sharded in {time.time() - shard_start_time:.2f} seconds")

    # Optionally resume from an existing checkpoint.
    if args.resume_from_checkpoint:
        checkpoint_start_time = time.time()
        state, step = load_checkpoint(args.resume_from_checkpoint, state)
        logger.info(f"Checkpoint loaded from step {step} in {time.time() - checkpoint_start_time:.2f} seconds")

    tokenizer_start_time = time.time()
    tokenizer = load_tokenizer(config)
    logger.info(f"Tokenizer loaded in {time.time() - tokenizer_start_time:.2f} seconds")

    dataset_start_time = time.time()
    train_dataset, eval_dataset = load_dataset(config, tokenizer)
    logger.info(f"Datasets loaded in {time.time() - dataset_start_time:.2f} seconds")

    dataloader_start_time = time.time()
    train_loader, eval_loader = create_data_loaders(
        config,
        train_dataset,
        eval_dataset,
        tokenizer
    )
    logger.info(f"Data loaders created in {time.time() - dataloader_start_time:.2f} seconds")

    # TensorBoard summary writer for metric logging.
    summary_writer = create_summary_writer(
        os.path.join(config.output_dir, "tensorboard")
    )

    trainer_config = TrainerConfig(
        model_config=config.model_config,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_steps=config.warmup_steps,
        max_steps=config.max_steps,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=config.max_grad_norm,
        adam_beta1=config.adam_beta1,
        adam_beta2=config.adam_beta2,
        adam_epsilon=config.adam_epsilon,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        eval_steps=config.eval_steps,
        output_dir=config.output_dir,
        seed=config.seed,
        dtype=config.dtype,

        # TPU-specific settings
        use_pjit=True,  # partition computation across the TPU mesh
        use_scan=True,  # scan over layers to reduce compile time
        use_remat=config.model_config.use_gradient_checkpointing,  # rematerialize activations
        use_sharded_optim=True,  # shard optimizer state across devices
        profile_steps=args.profile_steps,  # steps between profiler captures
        async_checkpointing=True,  # write checkpoints without blocking training
    )

    trainer = Trainer(
        config=trainer_config,
        model=model,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        state=state,
        parallel=parallel,  # parallelism strategy created above
    )

    logger.info(f"Total initialization time: {time.time() - start_time:.2f} seconds")

    # Rough wall-clock estimate, assuming about 5 minutes per optimizer step.
    steps_per_day = 24 * 60 * 60 / (5 * 60)
    estimated_days = config.max_steps / steps_per_day
    logger.info(f"Estimated training time: {estimated_days:.2f} days for {config.max_steps} steps")

    try:
        train_start_time = time.time()
        trainer.train()
        train_duration = time.time() - train_start_time
        logger.info(f"Training completed in {train_duration / 3600:.2f} hours")
        logger.info(f"Average training speed: {config.max_steps / train_duration:.2f} steps/second")
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        raise


if __name__ == "__main__":
    main()