# Based on Llama from Meta (https://github.com/meta-llama/llama/blob/main/llama/model.py)

import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from tokenizers import Tokenizer
from pathlib import Path
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from datasets import Dataset
from torch.utils.data import DataLoader
import wandb
from tqdm import tqdm
from functools import partial
import tiktoken
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

torch.manual_seed(1337)
torch.cuda.manual_seed(1337)

# import wandb
# wandb.login()
# from torch.utils.tensorboard import SummaryWriter

from datasets import load_dataset, concatenate_datasets
# data = {}
# texts = []

# with open('data/input.txt', 'r') as f:
#     texts.append(f.readlines())

# # print(texts)
# # print(len(texts[0]))
# data = {
#     "text": texts[0]
# }

# fw_train = Dataset.from_dict(data)
# print(fw_train)
# fw_train = load_dataset("karpathy/tiny_shakespeare", split="train", trust_remote_code=True)
# print(fw_train['text'])
# text = fw_train['text'][0].split("\n")
# print(text)
# filtered_lines = [line for line in text if line != '']
# print(len(filtered_lines))
# use name="sample-10BT" to use the 10BT sample
tinystories = True
fw = False
fw_train = None
fw_test = None

if(tinystories):
    fw_train = load_dataset("roneneldan/TinyStories", split="train")
    fw_test = load_dataset("roneneldan/TinyStories", split="validation")
    print(fw_train)
    print(fw_test)

if(fw):
    fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
    fw_train = fw_train.train_test_split(test_size=0.01)
    print(fw_train)

# Select only a subset of rows from the dataset
# fw_train = fw_train.select(range(1000000))
# alpaca = load_dataset("yahma/alpaca-cleaned", split='train')
# dolly = load_dataset("llm-wizard/dolly-15k-instruction-alpaca-format", split='train')
# merged_dataset = concatenate_datasets([alpaca, dolly])
# dataset = load_dataset("swype/instruct", split='train', trust_remote_code=True)
# print(fw_train)

# Split the dataset into training and validation sets
# fw_train = fw_train.train_test_split(test_size=0.01)
# print(fw_train)

# Access the splits
# train_dataset = train_val_split['train']
# val_dataset = train_val_split['test']
# train_dataset = fw_train.train_test_split(test_size=0.2)
def setup(rank=None, world_size=None):
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '12355'
    init_process_group("nccl")
    # torch.cuda.set_device(int(os.environ['LOCAL_RANK']))


def cleanup():
    destroy_process_group()
class ModelArgs:
    # Hyperparameters
    epochs = 4
    block_size = 512
    batch_size = 64
    embeddings_dims = 512
    attn_dropout = 0.1
    no_of_heads = 8
    dropout = 0.1
    # epochs = 100
    val_epochs = 2
    max_lr = 6e-4
    no_of_decoder_layers = 8  # IMP: needs to be thoroughly calculated
    weight_decay_optim = 0.1
    beta_1 = 0.9
    beta_2 = 0.95
    clip = 1.0
    device = 'cuda'
    no_kv_heads = 2
    vocab_size = 50304  # powers of 2 so nice!
    eps = 1e-5
    dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
    # dtype = 'bfloat16'
def _save_snapshot(model, optimizer, scheduler, epoch, step):
    snapshot = {
        "MODEL_STATE": model.module.state_dict(),
        "OPTIMIZER_STATE": optimizer.state_dict(),
        # "SCHEDULER_STATE": scheduler.state_dict(),
        "EPOCHS_RUN": epoch,
        "STEP_RUN": step
    }
    torch.save(snapshot, f"snapshot_{step}.pt")
    print(f"Epoch: {epoch} | Step: {step} | Snapshot saved.")


def _load_snapshot(snapshot_path, model, optimizer, scheduler):
    snapshot = torch.load(snapshot_path)
    model.load_state_dict(snapshot["MODEL_STATE"])
    optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])
    # scheduler.load_state_dict(snapshot["SCHEDULER_STATE"])  # Load scheduler state
    epoch = snapshot["EPOCHS_RUN"]
    step = snapshot["STEP_RUN"]
    print(f"Resuming from Epoch {epoch}, Step {step}")
    return epoch, step
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token='...')
# tokenizer.pad_token = tokenizer.eos_token
# if tokenizer.pad_token is None:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# print("ADDED THE TOKENS: ", tokenizer.pad_token_id)
# tokenizer.bos_token = "[INST]"
# tokenizer.eos_token = "[/INST]"
# model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")


def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        max_length=ModelArgs.block_size,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
def prepare_dataset(split, device, batch_size):
    print("Device is: ", device)
    # alpaca_prompt = '''
    # ### Instruction:
    # {}
    # ### Response:
    # {}
    # '''
    # Load a subset of the C4 dataset with a glob pattern for specific training files
    # dataset = load_dataset("allenai/c4", data_files=["en/c4-train.00001-of-01024.json.gz"], trust_remote_code=True)

    # Initialize tokenizer
    # tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # generator = torch.Generator(device=device)

    def collate_fn(batch):
        # Extract text data
        texts = [item["text"] for item in batch]

        # Set the pad token if it isn't set already
        # if tokenizer.pad_token is None:
        #     tokenizer.pad_token = tokenizer.eos_token

        # outputs = []
        # texts = []
        # for item in batch:
        #     instruction = item['prompt']
        #     # input = item['input']
        #     output = item['completion']
        #     # out = alpaca_prompt.format(instruction, output)
        #     texts.append(instruction)
        #     outputs.append(output)

        # Tokenize text data
        input_encodings = tokenizer(texts, max_length=ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
        # output_encodings = tokenizer(outputs, max_length=ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
        # input_encodings["labels"] = tokenizer(outputs, max_length=ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
        # out = {"input": input_encodings}
        # input_encodings['input_ids'][:, input_encodings["attention_mask"] == 0] = -100

        input_encodings["labels"] = input_encodings["input_ids"].clone()         # Use `input_ids` as labels
        input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:]  # Shift labels left so position i predicts token i + 1
        input_encodings["labels"][:, -1] = tokenizer.eos_token_id                # Make the final label the EOS token
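        # A tiny, hypothetical illustration of the shift above (made-up token ids, not
        # produced by the real tokenizer): if input_ids is [t0, t1, t2, t3], labels becomes
        # [t1, t2, t3, eos], so the logit at position i is trained to predict token i + 1.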
        # Return tokenized input tensors
        # return out
        return input_encodings

    # Create DistributedSampler for proper shuffling and partitioning across processes
    # dist_sampler = DistributedSampler(fw_train["text"], shuffle=True)

    # Create DataLoader with custom collate_fn
    # print(fw_dataset)
    data_loader = None
    if(tinystories):
        if(split == 'train'):
            data_loader = DataLoader(
                fw_train,
                # generator=generator,
                batch_size=batch_size,
                sampler=DistributedSampler(fw_train, shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
        elif(split == 'val'):
            data_loader = DataLoader(
                fw_test,
                batch_size=batch_size,
                sampler=DistributedSampler(fw_test, shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
    elif(fw):
        if(split == 'train'):
            data_loader = DataLoader(
                fw_train['train'],
                batch_size=batch_size,
                sampler=DistributedSampler(fw_train['train'], shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
        elif(split == 'val'):
            data_loader = DataLoader(
                fw_train['test'],
                batch_size=batch_size,
                # generator=generator,
                sampler=DistributedSampler(fw_train["test"]),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
    return data_loader
class Normalization(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)

    def forward(self, x):
        x = self.rmsnorm_layer(x)
        return x
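# For reference, torch.nn.RMSNorm normalizes as
#     y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight
# i.e. LayerNorm without the mean subtraction or bias, matching the Llama reference model.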
# import numpy as np


class RotaryEmbeddings(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        batch_size: int = ModelArgs.batch_size
    ):
        super().__init__()
        self.embeddings_dims = embeddings_dims
        self.block_size = block_size
        self.batch_size = batch_size
        self.theta = 0
        self.device = device

        # self.d_model = embeddings_dims
        # self.i = torch.arange(0, embeddings_dims, dtype=torch.float32)
        # # self.pos = torch.arange(0, block_size, dtype=torch.float32)
        # self.exp = ((2 * self.i)) / self.d_model
        # self.theta = 10000 ** self.exp
        # # print(self.theta.shape)
        # self.x_reshaped = torch.randn(batch_size, block_size, embeddings_dims, dtype=torch.float32, device=device)
        # self.cos = torch.cos((self.i / self.theta))
        # self.sin = torch.sin((self.i / self.theta))
        # self.even = self.sin[::2]
        # self.odd = self.cos[1::2]
        # # self.block = torch.empty((odd.size(0) + even.size(0),), dtype=self.even.dtype)
        # self.x_reshaped[..., :, ::2] = self.even
        # self.x_reshaped[..., :, 1::2] = self.odd

    def apply_rope(self, seq):
        batch_size, seq_len, embeds_dims = seq.shape

        # Standard RoPE: every token position gets its own rotation angle per
        # (even, odd) feature pair, with frequencies 10000 ** (-2i / d).
        positions = torch.arange(seq_len, dtype=torch.float32, device=self.device).unsqueeze(1)      # (seq_len, 1)
        dims = torch.arange(0, embeds_dims, 2, dtype=torch.float32, device=self.device).unsqueeze(0)  # (1, embeds_dims // 2)
        theta = 10000 ** (-dims / embeds_dims)
        angles = positions * theta  # (seq_len, embeds_dims // 2): one angle per position and pair

        x_reshaped = seq.view(batch_size, seq_len, embeds_dims // 2, 2)

        cos_angles = torch.cos(angles)
        sin_angles = torch.sin(angles)

        # Rotate each (even, odd) pair by its position-dependent angle.
        out = torch.stack(
            [
                x_reshaped[..., 0] * cos_angles - x_reshaped[..., 1] * sin_angles,
                x_reshaped[..., 1] * cos_angles + x_reshaped[..., 0] * sin_angles,
            ],
            dim=-1,
        )
        out = out.view(batch_size, seq_len, embeds_dims)
        return out
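    # Written out per pair: with angle a = pos * 10000 ** (-2 * i / d) for pair index i,
    #     x_even' = x_even * cos(a) - x_odd * sin(a)
    #     x_odd'  = x_even * sin(a) + x_odd * cos(a)
    # which is a 2-D rotation of each (even, odd) feature pair.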
    def forward(self, x):
        # print("X shape: ", x.shape)
        # print("X is: ", x)
        # B, T, C = x.shape
        # print("MATRIX:", x)
        # if(x > self.block_size or x < self.block_size):
        #     matrix = self.init_matrix(x)
        #     return matrix
        # else:
        #     matrix = self.init_matrix(self.block_size)
        #     return matrix
        # if(ModelArgs.inference):
        res = self.apply_rope(x)
        return res
        # else:
        #     return self.x_reshaped
class RotaryAttentionHead(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        no_of_heads: int = ModelArgs.no_of_heads,
        attn_dropout: int = ModelArgs.attn_dropout
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
        self.rope = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
        self.dropout = nn.Dropout(p=attn_dropout)
        self.device = device

    def forward(self, x):
        # print(x.shape)
        # print("X is: ", x)
        batch, block_size, embeddings_dims = x.shape
        query = self.query(x)
        # print(query)
        key = self.key(x)
        values = self.value(x)
        # matrix = self.rotary_matrix(block_size)
        rotary_q = self.rope(query)
        rotary_k = self.rope(key)

        masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
        # Causal attention over the rotated queries/keys: (B, T, head) @ (B, head, T) -> (B, T, T)
        weights = rotary_q @ rotary_k.transpose(-2, -1)
        weights_masked = weights.masked_fill(masked == 0, float('-inf'))
        scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
        scaled_weights = F.softmax(scaled_weights, dim=-1)
        value = scaled_weights @ values
        out = self.dropout(value)
        return out
# # import numpy as np
# class RotaryEmbeddings(nn.Module):
#     def __init__(
#         self,
#         device,
#         embeddings_dims: int = ModelArgs.embeddings_dims,
#         block_size: int = ModelArgs.block_size,
#         batch_size: int = ModelArgs.batch_size
#     ):
#         super().__init__()
#         self.embeddings_dims = embeddings_dims
#         self.block_size = block_size
#         self.batch_size = batch_size
#         self.theta = 0
#         # def init_matrix(self, seq_len):
#         #     self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False)
#         #     for pos in range(seq_len):
#         #         for j in range(1, self.embeddings_dims // 2):
#         #             self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
#         #             self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
#         #             self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
#         #             self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
#         #             self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
#         #     return self.matrix
#         self.device = device
#
#     def init_matrix(self, seq_len):
#         self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device=self.device)
#         positions = torch.arange(0, seq_len, 2, dtype=torch.float32, device=self.device).unsqueeze(1)
#         # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
#         theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
#         angles = positions * theta
#         cos_angles = torch.cos(angles)
#         sin_angles = torch.sin(angles)
#         indices = torch.arange(seq_len, dtype=torch.int64, device=self.device)
#         # print(indices)
#         # print(indices.shape)
#         # print(indices[::2])
#         even_indices = indices[::2]
#         odd_indices = indices[1::2]
#         self.matrix[:, even_indices, even_indices] = cos_angles
#         self.matrix[:, odd_indices, odd_indices] = sin_angles
#         self.matrix[:, odd_indices, even_indices] = -sin_angles
#         self.matrix[:, even_indices, odd_indices] = cos_angles
#         return self.matrix
#
#     def forward(self, x):
#         # B,T,C = x.shape
#         # print("MATRIX:",x)
#         if(x > self.block_size or x < self.block_size):
#             matrix = self.init_matrix(x)
#             return matrix
#         else:
#             matrix = self.init_matrix(self.block_size)
#             return matrix


# class RotaryAttentionHead(nn.Module):
#     def __init__(
#         self,
#         device,
#         embeddings_dims: int = ModelArgs.embeddings_dims,
#         no_of_heads: int = ModelArgs.no_of_heads,
#         attn_dropout: int = ModelArgs.attn_dropout
#     ):
#         super().__init__()
#         self.head_size = embeddings_dims // no_of_heads
#         self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
#         self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
#         self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device=device)
#         self.rotary_matrix = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
#         self.dropout = nn.Dropout(p=attn_dropout)
#         self.device = device
#
#     def forward(self, x):
#         # print(x.shape)
#         batch, block_size, embeddings_dims = x.shape
#         query = self.query(x)
#         # print(query)
#         key = self.key(x)
#         values = self.value(x)
#         matrix = self.rotary_matrix(block_size)
#         # print(matrix.shape)
#         # print(query.shape)
#         masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
#         rotary_query = matrix @ query.permute(1,2,0)  # (B,T,C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
#         rotary_key = matrix @ key.permute(1,2,0)      # (B,T,C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
#         weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)  # (B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
#         weights_masked = weights.masked_fill(masked == 0, float('-inf'))
#         scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
#         scaled_weights = F.softmax(scaled_weights, dim=-1)
#         value = scaled_weights @ values
#         out = self.dropout(value)
#         return out
class MQA(nn.Module):
    def __init__(
        self,
        device,
        no_of_q_heads: int,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
    ):
        super().__init__()
        # self.no_of_q_heads = no_of_heads // no_of_kv_heads
        # self.no_of_q_heads = no_of_q_heads
        self.no_of_kv_heads = 2  # I want to have a kv for each pair of query heads
        self.head_size = embeddings_dims // no_of_q_heads
        # self.kv_head_size = (embeddings_dims // self.no_of_kv_heads) * 2
        self.rotary = RotaryEmbeddings(embeddings_dims=self.head_size, device=device)
        # self.rotary_k = RotaryEmbeddings(embeddings_dims=self.kv_head_size, device=device)
        # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device=device)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device=device)
        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
        self.device = device
        self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, device=self.device) for _ in range(self.no_of_kv_heads)])

    def scaled_dot_product(self, q, k, v, block_size):
        # masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
        q = self.rotary(q)
        masked_table = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device=self.device))
        # rotary_query = matrix @ q.permute(1,2,0)  # (B,T,C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
        # rotary_key = matrix @ k.permute(1,2,0)    # (B,T,C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
        # print("Query: ", q.shape)
        # print("Keys: ", k.shape)
        # weights = q.permute(2,0,1) @ k.permute(2,0,1).transpose(-2, -1)
        # weights = q @ k.permute(2,1,0)
        weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
        masked_values = weights.masked_fill(masked_table[:block_size, :block_size] == 0, float('-inf'))
        weights_normalized = nn.functional.softmax(masked_values, dim=-1)  # Normalize along the key dimension so each token's attention sums to 1
        weights_normalized = self.dropout(weights_normalized)
        out = weights_normalized @ v
        return out

    def forward(self, x):
        # print("MQA: ", x.shape)
        batch, block_size, embeddings_dims = x.shape
        # query = self.query(x)
        # matrix = self.rotary_matrix(block_size)
        key = self.key(x)
        values = self.value(x)
        # print("Keys: ", key.shape)
        # print("Values: ", values.shape)
        # rotary_value = self.rotary(values)
        rotary_key = self.rotary(key)

        # A single shared (key, value) pair is attended to by each query head, then concatenated.
        multi_query_concat = torch.cat([self.scaled_dot_product(query(x), rotary_key, values, block_size) for query in self.multi_query], dim=-1)
        # print("Multi query: ", multi_query_concat.shape)
        linear_layer = self.linear_layer(multi_query_concat)
        # out = self.dropout(linear_layer)
        return linear_layer
class GQA(nn.Module):
    def __init__(
        self,
        device,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        # no_of_q_heads: int = ModelArgs.no_of_heads,
        mqa_heads: int = ModelArgs.no_kv_heads
    ):
        super().__init__()
        # self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
        # self.head_dim = embeddings_dims // self.no_kv_heads
        self.dropout = nn.Dropout(p=ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device=device)
        self.device = device
        self.mqa = nn.ModuleList([MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device=self.device, block_size=block_size) for _ in range(self.no_of_q_heads)])
        # self.mqa = MQA(no_of_q_heads=self.no_of_q_heads, device=self.device, embeddings_dims=embeddings_dims, block_size=block_size)

    def forward(self, x):
        batch, block_size, embeddings_dims = x.shape
        # res = self.mqa(x)
        grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)
        linear_layer = self.linear_layer(grouped_query_concat)  # Each MQA group is already projected back to embeddings_dims; this layer mixes the concatenated groups
        out = self.dropout(linear_layer)
        return out
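# Shape arithmetic with the defaults (embeddings_dims=512, no_of_heads=8, no_kv_heads=2):
# 8 // 2 = 4 MQA groups per GQA block; inside each group the per-head size is 512 // 4 = 128,
# the two query heads concatenate to 256 and are projected back to 512, and the four
# group outputs concatenate to 2048 before the final 2048 -> 512 projection.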
class Swish(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        swish = x * self.sig(x)
        return swish


class SWiGLU(nn.Module):
    def __init__(
        self,
        device,
        block_size: int = ModelArgs.block_size,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        super().__init__()
        self.hidden_dims = int(2 * (4 * embeddings_dims) / 3)
        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device=device)
        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device=device)
        self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device=device)

    def forward(self, x):
        swish_res = self.swish(self.linear_layer1(x))
        x_V = self.linear_layer2(x)
        res = torch.mul(swish_res, x_V)
        out = self.linear_layer3(res)
        return out
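# SwiGLU(x) = W3(Swish(W1 x) * (W2 x)). The hidden width is 2/3 of the usual 4x FFN
# expansion so the parameter count stays comparable to a plain FFN (Shazeer, 2020):
# with embeddings_dims=512 this gives int(2 * (4 * 512) / 3) = 1365 hidden units.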
class FFN(nn.Module):
    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 dropout=ModelArgs.dropout
                 ):
        super().__init__()
        # self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device=device)
        self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.swiglue(x)
        # x = self.linear_layer(x)
        x = self.dropout(x)
        return x


class DecoderLayer(nn.Module):
    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 dropout=ModelArgs.dropout,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 ):
        super().__init__()
        self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device=device)
        self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2, device=device)
        # self.norm = Normalization(embeddings_dims=embeddings_dims)
        self.norm1 = Normalization(embeddings_dims=embeddings_dims)
        self.norm2 = Normalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # Pre-norm residual blocks, as in Llama
        x = x + self.gqa(self.norm1(x))
        x = x + self.feedforward_network(self.norm2(x))
        return x
class Llama(nn.Module):
    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 dropout=ModelArgs.dropout
                 ):
        super().__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device=device)
        self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device=device) for _ in range(no_of_decoder_layers)])
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device=device)
        self.dropout = nn.Dropout(p=dropout)
        # self.norm = Normalization(embeddings_dims)

        # Weight tying between the token embedding and the output projection
        self.embeddings.weight = self.linear_layer.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        x = self.embeddings(x)
        x = self.dropout(x)
        x = self.decoder(x)
        # x = self.norm(x)
        x = self.linear_layer(x)
        # out = self.norm(x)
        return x
# Adapted from Andrej Karpathy's GitHub (nanoGPT-style sampling loop)
def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated_tokens = []
    ModelArgs.inference = True

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model.module(input_ids)
            logits = outputs[:, -1, :]

            probs = F.softmax(logits, dim=-1)

            # Top-k filtering
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)

            # Apply temperature scaling
            # probs = probs / temperature

            # Sample from the top-k distribution
            next_token = torch.multinomial(top_k_probs, num_samples=1)

            # generated_tokens.append(next_token.item())

            xcol = torch.gather(top_k_indices, -1, next_token)
            input_ids = torch.cat([input_ids, xcol], dim=1)  # dim=1 is the sequence dimension

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
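# Hypothetical usage once the DDP-wrapped model exists (kept commented out here):
# sample = topk_sampling(model, "Once upon a time", device=0, max_length=50, top_k=50)
# print(sample)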
def beam_search(model, tokenizer, prompt, beam_width=5, max_length=50, temperature=1.0):
    device = next(model.module.parameters()).device
    input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
    beam_scores = torch.zeros(beam_width, device=device)
    beam_sequences = input_ids.repeat(beam_width, 1)

    for _ in range(max_length):
        outputs = model(beam_sequences)
        logits = outputs[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, beam_width, dim=-1)

        # Expand beams
        beam_scores = beam_scores.unsqueeze(-1) + torch.log(top_probs)
        beam_scores = beam_scores.view(-1)
        top_indices = top_indices.view(-1)

        # Select top beams
        beam_scores, top_beams = torch.topk(beam_scores, beam_width)
        beam_sequences = torch.cat([beam_sequences[top_beams // beam_width], top_indices[top_beams].unsqueeze(-1)], dim=-1)

    # Return the best sequence
    best_sequence = beam_sequences[0]
    return tokenizer.decode(best_sequence, skip_special_tokens=True)
# device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
# ModelArgs.device = device
model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
model = model.to(ModelArgs.device)

# Print a summary of the architecture
# !pip install torchinfo
from torchinfo import summary

# idx, targets = get_batch('test')
idx = torch.randint(
    low=0,
    high=ModelArgs.vocab_size,
    size=(ModelArgs.batch_size, ModelArgs.block_size),
    dtype=torch.long
)
# sample_idx = random.randint(range(len(train_dataset)))
# idx, targets = train_dataset[0]
idx = idx.to(ModelArgs.device)
# targets = targets.to(ModelArgs.device)

summary(model=model,
        input_data=idx,
        # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])


def find_unused_parameters(model):
    unused = []
    for name, param in model.named_parameters():
        if param.grad is None:
            unused.append(name)
    return unused
def greedy_decode(
    model,
    tokenizer,
    prompt,
    device,
    max_length=50,
    repetition_penalty=1.2,
    context_window=10,
    temperature=1.0,
    eos_token_id=None,
):
    # model.eval()
    # device = next(model.parameters()).device
    input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
    generated_tokens = []

    eos_token_id = eos_token_id or tokenizer.eos_token_id  # Use EOS token if provided

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model.module(input_ids)
            logits = outputs[:, -1, :]  # Get logits for the last token

            # Apply temperature scaling
            # if temperature != 1.0:
            #     logits = logits / temperature

            # Apply repetition penalty
            # if repetition_penalty != 1.0 and len(generated_tokens) > 0:
            #     for token in set(generated_tokens[-context_window:]):  # Penalize recent tokens
            #         logits[0, token] /= repetition_penalty

            # Greedy selection
            next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
            generated_tokens.append(next_token.item())

            # Stop if EOS token is generated
            # if next_token.item() == eos_token_id:
            #     break

            # Append the new token to the input
            input_ids = torch.cat([input_ids, next_token], dim=1)

    # Decode the generated tokens
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)


def save_to_file(text):
    with open('generations.txt', 'a') as f:
        f.writelines(text + "\n\n")
# Train the model
# writer = SummaryWriter(log_dir="runs/experiment")

from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR


# Warmup phase for 2000 steps
def warmup_fn(step):
    if step < 2000:
        return step / 2000  # LR gradually increases
    return 1.0


def trapezoidal_lr_scheduler(optimizer, max_lr, total_steps, warmup_steps, plateau_steps, decay_steps):
    """
    Trapezoidal learning rate scheduler:
    - Increases linearly for `warmup_steps` steps.
    - Remains constant for `plateau_steps` steps.
    - Decreases linearly for `decay_steps` steps.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup
            return float(step) / float(max(1, warmup_steps))
        elif step < warmup_steps + plateau_steps:
            # Constant plateau
            return 1.0
        else:
            # Linear decay
            decay_step = step - (warmup_steps + plateau_steps)
            return max(0.0, float(decay_steps - decay_step) / float(max(1, decay_steps)))
    return LambdaLR(optimizer, lr_lambda)
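# Hypothetical usage (this script instead drives the learning rate manually via get_lr below);
# the step counts are illustrative, not tuned values:
# scheduler = trapezoidal_lr_scheduler(optimizer, max_lr=6e-4, total_steps=10000,
#                                      warmup_steps=700, plateau_steps=8000, decay_steps=1300)
# ...then call scheduler.step() once per optimizer step.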
torch.set_float32_matmul_precision('high')

scaler = torch.amp.GradScaler(enabled=(ModelArgs.dtype == 'float16'))

save_checkpoint_iter = 50
total_iters = 10000
eval_iters = 50
eval_check = 100
warmup_iters = 700
min_lr = 0.1 * ModelArgs.max_lr
lr_decay_iters = 10000
total_batch_size = 524288
micro_batch_size = ModelArgs.batch_size
gradient_accumulation_steps = total_batch_size // (micro_batch_size * (ModelArgs.block_size * torch.cuda.device_count()))
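# Worked example of the accumulation arithmetic (depends on the visible GPU count):
# with batch_size=64 and block_size=512 on 2 GPUs, each micro step covers
# 64 * 512 * 2 = 65,536 tokens, so 524,288 // 65,536 = 8 micro steps are accumulated
# per optimizer step to reach the ~0.5M-token effective batch.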
# Learning rate decay scheduler (cosine with warmup), from https://github.com/karpathy/nanoGPT/blob/master/train.py
def get_lr(it):
    # 1) Linear warmup for warmup_iters steps
    if it < warmup_iters:
        return ModelArgs.max_lr * (it + 1) / (warmup_iters + 1)
    # 2) If it > lr_decay_iters, return the minimum learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) In between, use cosine decay down to the minimum learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (ModelArgs.max_lr - min_lr)
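# A few sample values with max_lr=6e-4, warmup_iters=700, lr_decay_iters=10000, min_lr=6e-5:
# get_lr(0)     ~= 6e-4 * 1 / 701   ~= 8.6e-7   (start of warmup)
# get_lr(699)   ~= 6e-4 * 700 / 701 ~= 5.99e-4  (end of warmup)
# get_lr(10000) == min_lr == 6e-5               (floor after the cosine decay)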
def train():
    setup()
    device = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(int(device))
    # torch.set_default_device('cuda')
    # train_dataloader = prepare_dataset(ModelArgs.batch_size)
    # rank = torch.distributed.get_rank()
    print(f"Start running DDP on rank {device}.")
    # # create model and move it to GPU with id rank
    # device_id = rank % torch.cuda.device_count()
    # CFG = ModelArgs()

    if(device == 0):
        # Initialise the run
        wandb.init(
            # entity = 'rajceo2031',
            project='Llama-DDP-Pretrain-10-billion-tokens',
            # config = CFG,
            # save_code = True,
            # group = 'ANN',
            # job_type = 'train'
        )
        print("wandb initialized")

    model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
    print(f"Model on device {device} is ready")

    # Wrap model with DDP after moving to GPU
    # model = DDP(model, device_ids=[device])
    # optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim, eps=1e-8)
    # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=4000, T_mult=1, eta_min=1e-5)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(None, T_max=30000, eta_min=1e-6)
    # _load_snapshot('/kaggle/input/models/snapshot2.pt', model.module, None, None)
    optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim, eps=ModelArgs.eps)

    # model = torch.compile(model)
    model = model.to(device)
    model = DDP(model, device_ids=[device])

    # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25000, eta_min=1e-6)  # with the prev optim snapshot
    # new_scheduler = trapezoidal_lr_scheduler(optimizer, ModelArgs.max_lr, total_steps, warmup_steps, plateau_steps, decay_steps)
    # warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_fn)
    # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)
    # Cosine decay after warmup
    # new_scheduler = CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)
    # Combine both schedulers
    # scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, new_scheduler], milestones=[2000])

    # Reset learning rate to 1e-4
    # for param_group in optimizer.param_groups:
    #     param_group['lr'] = ModelArgs.max_lr
    # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2000, T_mult=1, eta_min=1e-6)
    # print("Old optimizer with new lr ready")
    # optimizer = torch.optim.AdamW(params=model.parameters(), lr=ModelArgs.max_lr)

    # Create DataLoader with collate_fn
    # train_loader = DataLoader(train_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
    # val_loader = DataLoader(val_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
    # print("Loader is ready")
    # print(train_loader)
    # print(next(iter(train_loader)))

    # for X, y in train_loader:
    #     print(X.shape)
    #     print(y.shape)

    # alpaca_prompt = '''
    # ### Instruction:
    # {instruction}
    # ### Input:
    # {input}
    # ### Response:
    # '''

    # Only create a progress bar for rank 0
    # eval_epoch_iterator = range(eval_iters)
    # train_epoch_iterator = range(total_iters)
    # if device == 0:
    #     train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training")

    # train_epoch_iterator = range(ModelArgs.epochs)
    # if device == 0:  # Ensure tqdm only runs on rank 0
    #     train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training Progress", position=0, leave=True)

    # lr_scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=total_steps - initial_iters)

    model.eval()
    world_size = torch.cuda.device_count()
    def estimate_loss(val_loader, val_iterator, device):
        out = {}
        # train_loader = prepare_dataset('train', ModelArgs.batch_size)
        # val_loader_iterator = iter(val_loader)
        loader = None
        epoch_loss = None
        epoch_losses = []
        model.eval()  # disable dropout while estimating the validation loss
        # print("Starting the eval...")
        for split in ['val']:
            print(f"Starting with {split} evaluation...")
            # losses = torch.zeros(ModelArgs.val_epochs)
            # if(split == 'train'):
            #     loader = train_loader
            # if(split == 'val'):
            #     loader = val_loader
            for step in range(eval_check):
                try:
                    batch = next(val_iterator)
                except StopIteration:
                    val_iterator = iter(val_loader)  # restart the iterator once the loader is exhausted
                    batch = next(val_iterator)

                total_loss = 0
                # loader.sampler.set_epoch(step)
                total_batches = 0
                # for batch in loader:  # Loop through DataLoader batches
                idx = batch['input_ids']
                targets = batch['labels']
                idx = idx.to(device)
                targets = targets.to(device)

                with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16):
                    logits = model(idx)
                    batch_size, block_size, embeddings_dims = logits.shape
                    logits = logits.view(batch_size * block_size, embeddings_dims)  # Flatten tokens
                    targets = targets.view(batch_size * block_size)

                    loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)
                    total_loss += loss.item()
                    total_batches += 1

                # Compute the mean loss for this evaluation step
                epoch_loss = total_loss / total_batches if total_batches > 0 else 0.0
                epoch_losses.append(epoch_loss)
                # print(f"Epoch {epoch + 1}/{ModelArgs.val_epochs}: Loss = {epoch_loss:.4f}")

            # Compute the mean loss across all evaluation steps
            out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0
            epoch_loss = None
            epoch_losses = []

        model.train()
        return out
    # model = model.to(rank)
    model.train()
    count = 0

    train_dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
    val_loader = prepare_dataset('val', device, ModelArgs.batch_size)
    # for step in tqdm(range(total_iters)):
    # for epoch in range(ModelArgs.epochs):
    #     torch.cuda.synchronize()
    #     train_dataloader.sampler.set_epoch(epoch)
    #     val_loader.sampler.set_epoch(epoch)
    print("Both loaders are ready")
    epochs = ModelArgs.epochs

    # train_step_iterator = range(len(train_dataloader))
    # if device == 0:  # Only create a progress bar on rank 0
    #     train_step_iterator = tqdm(train_step_iterator, desc="Training Progress", position=0, leave=True)

    # Print progress on rank 0
    train_loader_length = 0
    train_data_iterator = iter(train_dataloader)
    val_data_iterator = iter(val_loader)
    token_count = 0
    if(device == 0):
        train_loader_length = len(train_dataloader)
        # print("Total batches: ", train_loader_length)
        # print("Length of train: ", len(train_dataloader))
        # print("Length of val: ", len(val_loader))
    # for step, batch in enumerate(train_dataloader):
    for step in tqdm(range(total_iters)):
        # print("Dataloader things: ", batch)
        # print("Total batches: ", len(train_dataloader))

        if(device == 0):
            # if(step % 100 == 0):
            #     if(step == train_loader_length):
            #         break
            print("Step : ", step, "/", total_iters)
            print('Total batches: ', len(train_dataloader))
            print("Total gradient accumulation steps: ", gradient_accumulation_steps)
            print("Total tokens processed: ", token_count)

        # all_gpus_avg_train_loss = None
        # all_gpus_avg_val_loss = None

        # Every once in a while evaluate the loss on the val set
        if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
            losses = estimate_loss(val_loader, val_data_iterator, 'cuda')
            # avg_train_loss = losses['train']
            avg_val_loss = losses['val']
            # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            # if device == 0:  # Only print on the main process
            print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {losses['val']:.4f}")
            # print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs} | Step: {step} | Train Loss: {losses['train']:.4f}")

            # Aggregate the average loss across all GPUs
            # avg_train_loss = torch.Tensor([losses['train']]).to(device)
            avg_val_loss = torch.Tensor([losses['val']]).to(device)
            # torch.distributed.reduce(avg_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)

            if device == 0:
                # all_gpus_avg_train_loss = avg_train_loss / world_size
                # print(f"All_GPUs_Train_losses: {all_gpus_avg_train_loss.item():.4f}")
                all_gpus_avg_val_loss = avg_val_loss / world_size
                print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")

                # writer.add_scalar("All_GPUs_Train_losses", all_gpus_avg_train_loss.item(), global_step=step)
                # writer.add_scalar("All_GPUs_Val_losses", all_gpus_avg_val_loss.item(), global_step=step)
                # writer.add_scalar("training_step_loss", losses['train'], global_step=step)
                # writer.add_scalar("val_step_loss", losses['val'], global_step=step)
                # writer.add_scalar("GPU", device, global_step=step)
                # writer.add_scalar("Epoch", epoch, global_step=step)

                wandb.log({
                    # "Learning Rate": optimizer.param_groups[0]['lr'],
                    # "All_GPUs_Train_losses": all_gpus_avg_train_loss,
                    "All_GPUs_Val_losses": all_gpus_avg_val_loss.item(),
                    # "training_step_loss": losses['train'],
                    "val_step_loss": losses['val'],
                    # "Step": step,
                    # "Epoch": epoch
                })

        # Loading a checkpoint
        # if(os.path.exists('snapshot.pt')):
        #     model, optimizer = _load_snapshot(model=model, optimizer=optimizer, epoch=epoch, step=step, snapshot_path='snapshot.pt')

        # if(step % save_checkpoint_iter == 0 and device == 0 and step != 0):
        #     _save_snapshot(epoch=epoch, model=model, optimizer=optimizer, step=step)
        if step % save_checkpoint_iter == 0 and device == 0 and step != 0:
            print(f"Saving the model checkpoint for step: {step}")
            _save_snapshot(model, optimizer, None, None, step)
        accumulated_loss = 0.0

        optimizer.zero_grad(set_to_none=True)
        for micro_step in range(gradient_accumulation_steps):
            try:
                batch = next(train_data_iterator)
            except StopIteration:
                train_data_iterator = iter(train_dataloader)
                batch = next(train_data_iterator)
            # print(batch)
            # batch = {k: v.to(self.local_rank) for k, v in batch.items()}
            idx = batch['input_ids'].to(device)
            # idx, targets = get_batch(split='train')
            # print(f"Starting the train step: {step}...")
            # print("Idx: ", idx)
            # print("Targets: ", targets)
            targets = batch['labels'].to(device)
            token_count += idx.numel()  # count tokens, not just sequences

            with torch.autocast(device_type=ModelArgs.device, dtype=torch.bfloat16):
                logits = model(idx)
                batch_size, block_size, embeddings_dims = logits.shape
                # print(logits.shape)
                # print(targets)
                logits = logits.view(batch_size * block_size, embeddings_dims)
                targets = targets.view(batch_size * block_size)
                loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

                # Scale the loss so that, summed over all micro steps, the gradients match
                # those of one large batch averaged over the full effective batch size.
                loss = loss / gradient_accumulation_steps
                accumulated_loss += loss.detach()

            # Only synchronize gradients across GPUs on the last micro step
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            scaler.scale(loss).backward()

            # Check for unused parameters
            unused_params = find_unused_parameters(model)
            if unused_params:
                print(f"Unused parameters: {unused_params}")

            if(device == 0):
                if(micro_step % 10 == 0):
                    # if(step == train_loader_length):
                    #     break
                    print("Micro Batch : ", micro_step)
                    print("Step : ", step, "/", total_iters)
                    print('Total batches: ', len(train_dataloader))
                    print("Total gradient accumulation steps: ", gradient_accumulation_steps)
                    print("Total tokens processed: ", token_count)
            # count += 1
        lr = get_lr(step)
        for params in optimizer.param_groups:
            params['lr'] = lr

        if(ModelArgs.clip != 0.0):
            scaler.unscale_(optimizer)  # unscale so we clip the true gradients, not the scaled ones

            # Compute gradient norms before clipping
            total_norm_before = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters() if p.grad is not None]), 2
            )

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)

            # Compute gradient norms after clipping
            total_norm_after = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters() if p.grad is not None]), 2
            )

            if(device == 0 and step != 0):
                print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
                print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

        scaler.step(optimizer)
        scaler.update()
        # optimizer.step()
        # new_scheduler.step()

        torch.cuda.synchronize()
        torch.distributed.reduce(loss, dst=0, op=torch.distributed.ReduceOp.SUM)

        if(device == 0):
            wandb.log({
                "Learning Rate": lr,
                "All_GPUs_Train_losses": accumulated_loss.item(),
                # "All_GPUs_Val_losses": all_gpus_avg_val_loss,
                # "training_step_loss": losses['train'],
                # "val_step_loss": losses['val'],
                "Step": step,
                # "Epoch": epoch
            })

        # print(loss.item())
        # if(step % 100 == 0):
        #     print(f'Step : {step} | GPU: {device} Loss: {loss.item()}')
        # if device == 0:
        #     print("loss: ", loss.item())
        #     train_epoch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})

        # if step != 0 and (step % eval_iters == 0 or step == total_steps - 1):
        #     loss_values = estimate_loss()
        #     print("Train Loss at {} steps : {}".format(step, loss.item()), "Val Loss at {} steps : {}".format(step, loss_values['val']))

        # Add after a training step:
        # unused_params = find_unused_parameters(model)
        # print("Unused parameters:", unused_params)

        if device == 0 and step % 5 == 0:
            count = 3
            while(count):  # Only generate text on the main process
                # print("Generating text...")
                # alpaca_prompt = '''
                # ### Instruction:
                # {}
                # ### Input:
                # {}
                # ### Response:
                # '''
                # prompt = alpaca_prompt.format("You are a helpful assistant.", "Say a joke.", "")
                prompt = "Once upon a time"
                generated_text = topk_sampling(model, prompt, max_length=50, top_k=50, temperature=1.0, device=device)

                # generated_text = greedy_decode(
                #     model,
                #     tokenizer,
                #     "Once upon a time",
                #     max_length=40,
                #     repetition_penalty=1.2,
                #     context_window=10,
                #     temperature=0.7,  # Lower temperature for more deterministic output
                #     device=device
                # )
                # generated_text = beam_search(model, tokenizer, "Once upon a time ", beam_width=5, max_length=50, temperature=0.6)

                print(f" Step: {step} | Generated Text: {generated_text}")
                # model.train()
                # save_to_file(generated_text)
                count -= 1

        # if step != 0:
        #     train_step_iterator.set_postfix({"Train loss": f"{all_gpus_avg_train_loss.item():.4f} | Val Loss : {all_gpus_avg_val_loss.item():.4f}"})

    # Cleanup
    if device == 0:
        # writer.close()
        wandb.finish()
    cleanup()
world_size = torch.cuda.device_count()
print(f"World size: {world_size}")
train()