import argparse
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from tqdm import tqdm

import wandb

from config import ModelArgs
from model import Llama
# topk_sampling is assumed to live in inference.py alongside greedy_decode,
# since the generation step in the training loop below calls it.
from inference import greedy_decode, topk_sampling
from data import prepare_dataset
from tokenizer import Tokenizer

torch.set_float32_matmul_precision('high')

# The GradScaler is only enabled for float16; bfloat16 training does not need loss scaling.
scaler = torch.amp.GradScaler(enabled=(ModelArgs.dtype == 'float16'))

save_checkpoint_iter = 50
total_iters = 10000
eval_iters = 50      # evaluate every `eval_iters` training steps
eval_check = 100     # number of validation batches per evaluation
warmup_iters = 700
min_lr = 0.1 * ModelArgs.max_lr
lr_decay_iters = 10000
total_batch_size = 524288            # desired effective batch size, in tokens
micro_batch_size = ModelArgs.batch_size

# Accumulate gradients until the tokens seen per optimizer step
# (micro_batch_size * block_size per GPU, summed over all GPUs) add up to total_batch_size.
gradient_accumulation_steps = total_batch_size // (
    micro_batch_size * ModelArgs.block_size * torch.cuda.device_count()
)
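# Worked example of the arithmetic above (the concrete values are hypothetical,
# not taken from config.py): with micro_batch_size = 16, block_size = 256 and
# 2 GPUs, each micro-step covers 16 * 256 * 2 = 8192 tokens, so
# 524288 // 8192 = 64 micro-steps are accumulated per optimizer step.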
class Trainer:

    def __init__(self, model_args):

        def setup(rank=None, world_size=None):
            # os.environ['MASTER_ADDR'] = 'localhost'
            # os.environ['MASTER_PORT'] = '12355'
            init_process_group("nccl")
            # torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

        self.model_args = model_args
        self.tokenizer = Tokenizer().ready_tokenizer()
        setup()

    def cleanup(self):
        destroy_process_group()

    def _save_snapshot(self, model, optimizer, epoch, step, save_dir):
        snapshot = {}
        snapshot["MODEL_STATE"] = model.module.state_dict()
        snapshot["OPTIMIZER_STATE"] = optimizer.state_dict()
        snapshot["EPOCHS_RUN"] = epoch
        snapshot["STEP_RUN"] = step
        torch.save(snapshot, os.path.join(save_dir, "snapshot.pt"))
        print(f"Epoch: {epoch} | step {step} | Training snapshot saved at snapshot.pt")
# Warmup-only schedule (kept for reference; the training loop uses get_lr below)
def warmup_fn(step):
    if step < 2000:
        return step / 2000  # LR increases linearly over the first 2000 steps
    return 1.0


# Learning rate scheduler (cosine decay with linear warmup), adapted from
# https://github.com/karpathy/nanoGPT/blob/master/train.py
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return ModelArgs.max_lr * (it + 1) / (warmup_iters + 1)
    # 2) if it > lr_decay_iters, return the minimum learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to the minimum learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (ModelArgs.max_lr - min_lr)
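# Shape of the resulting schedule, given the constants above: the learning rate
# ramps linearly from ~0 to max_lr over the first 700 steps, then follows half a
# cosine down to min_lr (0.1 * max_lr) at step 10000, and stays at that floor
# for any later step.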
def train():
    # Trainer.__init__ runs setup(), which initialises the NCCL process group.
    trainer = Trainer(ModelArgs)
    tokenizer = trainer.tokenizer

    device = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(device)
    print(f"Start running DDP on rank {device}.")

    if device == 0:
        # Initialise the wandb run on the main process only
        wandb.init(
            # entity = 'rajceo2031',
            project='Llama-DDP-Pretrain-10-billion-tokens',
            # config = CFG,
            # save_code = True,
            # group = 'ANN',
            # job_type = 'train'
        )
        print("wandb initialized")

    model = Llama(
        embeddings_dims=ModelArgs.embeddings_dims,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout,
        device=device,
    )
    print(f"Model on device {device} is ready")

    optimizer = optim.AdamW(
        model.parameters(),
        lr=ModelArgs.max_lr,
        betas=(ModelArgs.beta_1, ModelArgs.beta_2),
        weight_decay=ModelArgs.weight_decay_optim,
        eps=ModelArgs.eps,
    )
    # model = torch.compile(model)
    model = model.to(device)
    model = DDP(model, device_ids=[device])
    model.eval()

    world_size = torch.cuda.device_count()
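    # DDP all-reduces (averages) gradients across ranks during backward, so every
    # rank steps with identical gradients; world_size is reused further down to
    # turn SUM-reduced losses into cross-GPU means for logging.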
    def estimate_loss(val_loader, val_iterator, device):
        """Average the validation loss over `eval_check` batches on this rank."""
        out = {}
        model.eval()
        for split in ['val']:
            print(f"Starting with {split} evaluation...")
            total_loss = 0.0
            total_batches = 0
            for _ in range(eval_check):
                try:
                    batch = next(val_iterator)
                except StopIteration:
                    # Restart the validation loader once it is exhausted
                    val_iterator = iter(val_loader)
                    batch = next(val_iterator)

                idx = batch['input_ids'].to(device)
                targets = batch['labels'].to(device)

                with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16):
                    logits = model(idx)
                    batch_size, block_size, vocab_size = logits.shape
                    # Flatten (batch, time, vocab) -> (batch * time, vocab) for cross-entropy
                    logits = logits.view(batch_size * block_size, vocab_size)
                    targets = targets.view(batch_size * block_size)
                    loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

                total_loss += loss.item()
                total_batches += 1

            out[split] = total_loss / total_batches if total_batches > 0 else 0.0

        model.train()
        return out
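    # estimate_loss() only averages this rank's validation batches; the caller
    # below SUM-reduces the result to rank 0 and divides by world_size to report
    # a mean across all GPUs.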
    model.train()
    count = 0

    train_dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
    val_loader = prepare_dataset('val', device, ModelArgs.batch_size)
    print("Both loaders are ready")

    epochs = ModelArgs.epochs
    train_loader_length = 0
    train_data_iterator = iter(train_dataloader)
    val_data_iterator = iter(val_loader)
    token_count = 0
    if device == 0:
        train_loader_length = len(train_dataloader)

    for step in tqdm(range(total_iters)):

        if device == 0:
            print("Step: ", step, "/", total_iters)
            print("Total batches: ", len(train_dataloader))
            print("Total gradient accumulation steps: ", gradient_accumulation_steps)
            print("Total tokens processed: ", token_count)
        # Evaluation runs on every rank so that the reduce below has all participants.
        if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
            losses = estimate_loss(val_loader, val_data_iterator, 'cuda')
            avg_val_loss = losses['val']
            print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {avg_val_loss:.4f}")

            # Sum the per-rank averages onto rank 0, then divide by the number of GPUs.
            avg_val_loss = torch.tensor([losses['val']]).to(device)
            torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
            if device == 0:
                all_gpus_avg_val_loss = avg_val_loss / world_size
                print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")
                wandb.log({
                    "All_GPUs_Val_losses": all_gpus_avg_val_loss.item(),
                    "val_step_loss": losses['val'],
                })

        if step % save_checkpoint_iter == 0 and device == 0 and step != 0:
            print(f"Saving the model checkpoint for step: {step}")
            # The original call passed `step` as the save directory; route the
            # snapshot to the configured checkpoint directory instead.
            trainer._save_snapshot(model, optimizer, None, step, ModelArgs.save_checkpoint_dir)
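            # Resuming from a snapshot (sketch; the load path and map_location are
            # assumptions, mirroring the keys written by _save_snapshot above):
            #   snapshot = torch.load(os.path.join(ModelArgs.save_checkpoint_dir, "snapshot.pt"),
            #                         map_location=f"cuda:{device}")
            #   model.module.load_state_dict(snapshot["MODEL_STATE"])
            #   optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])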
        accumulated_loss = 0.0
        optimizer.zero_grad(set_to_none=True)

        for micro_step in range(gradient_accumulation_steps):
            try:
                batch = next(train_data_iterator)
            except StopIteration:
                # Restart the training loader once it is exhausted
                train_data_iterator = iter(train_dataloader)
                batch = next(train_data_iterator)

            idx = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)
            token_count += idx.numel()  # count tokens, not just sequences

            with torch.autocast(device_type=ModelArgs.device, dtype=torch.bfloat16):
                logits = model(idx)
                batch_size, block_size, vocab_size = logits.shape
                # Flatten (batch, time, vocab) -> (batch * time, vocab) for cross-entropy
                logits = logits.view(batch_size * block_size, vocab_size)
                targets = targets.view(batch_size * block_size)
                loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

            # Scale the loss so that the gradients summed over all micro-batches
            # equal the gradient of the average loss over the full effective batch.
            loss = loss / gradient_accumulation_steps
            accumulated_loss += loss.detach()

            # Only all-reduce gradients across GPUs on the last micro-step;
            # intermediate micro-steps just accumulate locally.
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            scaler.scale(loss).backward()

            # Debug check for parameters that never received a gradient
            unused_params = find_unused_parameters(model)
            if unused_params:
                print(f"Unused parameters: {unused_params}")

            if device == 0 and micro_step % 10 == 0:
                print("Micro Batch: ", micro_step)
                print("Step: ", step, "/", total_iters)
                print("Total batches: ", len(train_dataloader))
                print("Total gradient accumulation steps: ", gradient_accumulation_steps)
                print("Total tokens processed: ", token_count)
        # Set the learning rate for this step
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        if ModelArgs.clip != 0.0:
            # Unscale first so clipping operates on the true gradient values
            scaler.unscale_(optimizer)
            # Compute gradient norms before clipping
            total_norm_before = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2)
                             for p in model.parameters() if p.grad is not None]), 2
            )
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)
            # Compute gradient norms after clipping
            total_norm_after = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2)
                             for p in model.parameters() if p.grad is not None]), 2
            )
            if device == 0 and step != 0:
                print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
                print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

        scaler.step(optimizer)
        scaler.update()
        torch.cuda.synchronize()

        # Average the accumulated training loss across GPUs for logging
        torch.distributed.reduce(accumulated_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
        if device == 0:
            wandb.log({
                "Learning Rate": lr,
                "All_GPUs_Train_losses": accumulated_loss.item() / world_size,
                "Step": step,
            })
        # Periodically sample text on the main process only
        if device == 0 and step % 5 == 0:
            prompt = "Once upon a time"
            for _ in range(3):
                generated_text = topk_sampling(model, prompt, max_length=50, top_k=50,
                                               temperature=1.0, device=device)
                print(f"Step: {step} | Generated Text: {generated_text}")

    if device == 0:
        wandb.finish()
    trainer.cleanup()
world_size = torch.cuda.device_count()
print(f"World size: {world_size}")
def parse_args():
    parser = argparse.ArgumentParser(description="Model Training Arguments")

    # Add arguments for each field in ModelArgs
    parser.add_argument("--epochs", type=int, default=ModelArgs.epochs, help="Number of training epochs.")
    parser.add_argument("--block_size", type=int, default=ModelArgs.block_size, help="Block size for the model.")
    parser.add_argument("--batch_size", type=int, default=ModelArgs.batch_size, help="Batch size for training.")
    # parser.add_argument("--inference", type=lambda x: (str(x).lower() == 'true'), default=ModelArgs.inference, help="Whether to run in inference mode.")
    parser.add_argument("--embeddings_dims", type=int, default=ModelArgs.embeddings_dims, help="Embedding dimensions.")
    parser.add_argument("--attn_dropout", type=float, default=ModelArgs.attn_dropout, help="Attention dropout rate.")
    parser.add_argument("--no_of_heads", type=int, default=ModelArgs.no_of_heads, help="Number of attention heads.")
    parser.add_argument("--dropout", type=float, default=ModelArgs.dropout, help="Dropout rate.")
    parser.add_argument("--val_epochs", type=int, default=ModelArgs.val_epochs, help="Number of validation epochs.")
    parser.add_argument("--max_lr", type=float, default=ModelArgs.max_lr, help="Maximum learning rate.")
    parser.add_argument("--no_of_decoder_layers", type=int, default=ModelArgs.no_of_decoder_layers, help="Number of decoder layers.")
    parser.add_argument("--weight_decay_optim", type=float, default=ModelArgs.weight_decay_optim, help="Weight decay for the optimizer.")
    parser.add_argument("--beta_1", type=float, default=ModelArgs.beta_1, help="Beta1 for the AdamW optimizer.")
    parser.add_argument("--beta_2", type=float, default=ModelArgs.beta_2, help="Beta2 for the AdamW optimizer.")
    parser.add_argument("--clip", type=float, default=ModelArgs.clip, help="Gradient clipping value.")
    parser.add_argument("--device", type=str, default=ModelArgs.device, help="Device to run the model on (e.g., 'cuda' or 'cpu').")
    parser.add_argument("--no_kv_heads", type=int, default=ModelArgs.no_kv_heads, help="Number of key/value heads.")
    parser.add_argument("--vocab_size", type=int, default=ModelArgs.vocab_size, help="Vocabulary size.")
    parser.add_argument("--eps", type=float, default=ModelArgs.eps, help="Epsilon value for numerical stability.")
    parser.add_argument("--dtype", type=str, default=ModelArgs.dtype, help="Data type for tensors (e.g., 'float16' or 'bfloat16').")
    parser.add_argument("--save_checkpoint_dir", type=str, default=ModelArgs.save_checkpoint_dir, help="Directory to save model checkpoints.")
    parser.add_argument("--prompt", type=str, default=ModelArgs.prompt, help="Prompt for test generation during training.")

    # Additional training-loop arguments
    parser.add_argument("--save_checkpoint_iter", type=int, default=ModelArgs.save_checkpoint_iter, help="Save a checkpoint every N iterations.")
    parser.add_argument("--total_iters", type=int, default=ModelArgs.total_iters, help="Total number of training iterations.")
    parser.add_argument("--eval_iters", type=int, default=ModelArgs.eval_iters, help="Evaluate every N training iterations.")
    parser.add_argument("--eval_check", type=int, default=ModelArgs.eval_check, help="Number of validation batches per evaluation.")
    parser.add_argument("--warmup_iters", type=int, default=ModelArgs.warmup_iters, help="Number of warmup iterations for learning rate scheduling.")
    parser.add_argument("--min_lr", type=float, default=ModelArgs.min_lr, help="Minimum learning rate.")
    parser.add_argument("--lr_decay_iters", type=int, default=ModelArgs.lr_decay_iters, help="Number of iterations over which the learning rate decays.")
    parser.add_argument("--total_batch_size", type=int, default=ModelArgs.total_batch_size, help="Total effective batch size (in tokens) across all devices.")
    parser.add_argument("--micro_batch_size", type=int, default=ModelArgs.micro_batch_size, help="Micro batch size per device.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=ModelArgs.gradient_accumulation_steps, help="Number of gradient accumulation steps.")

    args = parser.parse_args()
    return args
def initialize_model_args(args):
    # Create a ModelArgs instance from the parsed arguments
    model_args = ModelArgs(
        epochs=args.epochs,
        block_size=args.block_size,
        batch_size=args.batch_size,
        # inference=args.inference,
        embeddings_dims=args.embeddings_dims,
        attn_dropout=args.attn_dropout,
        no_of_heads=args.no_of_heads,
        dropout=args.dropout,
        val_epochs=args.val_epochs,
        max_lr=args.max_lr,
        no_of_decoder_layers=args.no_of_decoder_layers,
        weight_decay_optim=args.weight_decay_optim,
        beta_1=args.beta_1,
        beta_2=args.beta_2,
        clip=args.clip,
        device=args.device,
        no_kv_heads=args.no_kv_heads,
        vocab_size=args.vocab_size,
        eps=args.eps,
        dtype=args.dtype,
        save_checkpoint_dir=args.save_checkpoint_dir,
        prompt=args.prompt,
        save_checkpoint_iter=args.save_checkpoint_iter,
        total_iters=args.total_iters,
        eval_iters=args.eval_iters,
        eval_check=args.eval_check,
        warmup_iters=args.warmup_iters,
        min_lr=args.min_lr,
        lr_decay_iters=args.lr_decay_iters,
        total_batch_size=args.total_batch_size,
        micro_batch_size=args.micro_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )
    return model_args
if __name__ == "__main__":
    args = parse_args()
    model_args = initialize_model_args(args)
    # Note: train() reads the module-level ModelArgs directly, so CLI overrides
    # only take effect where they are reflected in that config.
    train()
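# Launch sketch: init_process_group("nccl") and the LOCAL_RANK lookup in train()
# assume the script is started with torchrun, one process per GPU. The file name
# and flag values below are illustrative, not tuned recommendations:
#
#   torchrun --standalone --nproc_per_node=2 train.py --batch_size 16 --max_lr 3e-4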