import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset, concatenate_datasets
from tokenizers import Tokenizer
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer, WordLevelTrainer
from tokenizers.pre_tokenizers import ByteLevel, Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers import decoders
from torch.cuda.amp import autocast, GradScaler
import time
from torch.utils.tensorboard import SummaryWriter
from itertools import islice
from config import get_weights_file_path, get_config
from tqdm import tqdm
from pathlib import Path
import warnings
from dataset import BilingualDataset
from model import build_gpt
g = torch.Generator()
g.manual_seed(23)
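# The decoding helpers below call `causal_mask`, which is not imported above.
# A minimal sketch under the usual convention (True where a position may attend,
# False for future positions); if dataset.py already exports an equivalent helper,
# import it from there instead of redefining it here.
def causal_mask(size):
    # Ones strictly above the diagonal mark future positions; invert so the
    # allowed (current and past) positions are True.
    future = torch.triu(torch.ones((1, size, size), dtype=torch.int), diagonal=1)
    return future == 0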
def greedy_decode(model, text, mask, tokenizer, max_len, device):
    # Special-token names are assumed to be '<sos>'/'<eos>'; adjust them to match
    # the tokens the tokenizer was actually trained with.
    sos_idx = tokenizer.token_to_id('<sos>')
    eos_idx = tokenizer.token_to_id('<eos>')

    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(text).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break
        
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)
        
        out = model.decode(decoder_input, decoder_mask)
        
        prob = model.project(out[:,-1])
        _, next_word = torch.max(prob, dim=1)
        
        
        decoder_input = torch.cat([decoder_input, torch.empty(1,1).type_as(text).fill_(next_word.item()).to(device)],dim=1)
        if next_word == eos_idx:
            break
        
    return decoder_input.squeeze(0)
def generate_text(
    model, text, mask, tokenizer, max_len, device, 
    temperature=0.7, top_k=50
):
    # Assumed special-token name; see the note in greedy_decode above.
    eos_idx = tokenizer.token_to_id('<eos>')
    # Start with the input text as initial decoder input
    decoder_input = text.to(device)
    if decoder_input.dim() == 1:
       decoder_input = decoder_input.unsqueeze(0)
    # Print the initial prompt
    prompt_text = tokenizer.decode(text.squeeze(0).tolist())
    print(prompt_text, end="", flush=True)
    while len(decoder_input[0]) < max_len - 3:
        # Apply causal mask based on current decoder_input length
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)
        # Get model output
        out = model.decode(decoder_input, decoder_mask)
        logits = model.project(out[:, -1])  # Get logits for last token
        # Sampling: temperature + top-k
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)
        # Decode and print token
        word = tokenizer.decode([next_token.item()])
        print(word, end="", flush=True)
        # Append next token
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        if next_token.item() == eos_idx:
            break
    print()
    return decoder_input.squeeze(0)
def generate_text_(model, text, m, tokenizer, max_len, device, temperature=0.7, top_k=50):
    # Assumed special-token names; see the note in greedy_decode above.
    sos_idx = tokenizer.token_to_id('<sos>')
    eos_idx = tokenizer.token_to_id('<eos>')
    pad_idx = tokenizer.token_to_id('<pad>')

    # Encode input and add <sos> at the beginning
    input_tokens = [sos_idx] + tokenizer.encode(text).ids

    # Truncate if too long
    input_tokens = input_tokens[:max_len - 1]  # Leave room for <eos>
    
    # Convert to tensor
    decoder_input = torch.tensor(input_tokens, device=device).unsqueeze(0)
    
    # Generate until max_len
    for _ in range(max_len - len(input_tokens)):
        # Create causal mask for what we've generated so far
        decoder_mask = causal_mask(decoder_input.size(1)).to(device)
        
        # Get model output
        out = model.decode(decoder_input, decoder_mask)
        logits = model.project(out[:, -1])
        
        # Apply sampling
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)
        
        # Print the generated word
        word = tokenizer.decode([next_token.item()])
        print(word, end="", flush=True)
        
        # Append to input (next_token is already shaped (1, 1), so no unsqueeze is needed)
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        
        if next_token.item() == eos_idx:
            break
            
    return decoder_input.squeeze(0)
def run_validation(model, validation_ds, tokenizer, max_len, device, print_msg, global_state, writer, num_examples=2):
    model.eval()

    count = 0
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            input_tokens = batch['input']
            print("TOKENIZED INPUT : ", input_tokens)

            input_tokens = torch.as_tensor(input_tokens)
            mask = causal_mask(input_tokens.size(-1))
            model_output = generate_text(model, input_tokens, mask, tokenizer, max_len, device)
            print_msg("Model Output Embedding : ")
            print_msg(str(model_output.tolist()))
            model_out_text = tokenizer.decode(model_output.detach().cpu().tolist())

            print_msg(f'SOURCE : {input_tokens}')
            print_msg(f'PREDICTED : {model_out_text}')

            if count == num_examples:
                break
            
def get_all_sentences(ds):
    for item in ds:
        yield item['text']
def get_or_build_tokenizer_(config, ds):
    tokenizer_path = Path(config['tokenizer_file'])
    if not tokenizer_path.exists():
        tokenizer = Tokenizer(WordLevel(unk_token="<unk>"))
        tokenizer.pre_tokenizer = Whitespace()
        # Special-token names are assumed ('<sos>'/'<eos>'/'<unk>'/'<pad>'); the original
        # listed additional special tokens that are not recoverable here.
        trainer = WordLevelTrainer(special_tokens=["<sos>", "<eos>", "<unk>", "<pad>"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer
def get_or_build_tokenizer(config, ds):
    tokenizer_path = Path(config['tokenizer_file'])
    if not tokenizer_path.exists():
        # Define tokenizer with BPE model
        # Special-token names below ('<sos>', '<eos>', '<unk>', '<pad>') are assumed;
        # the original listed additional special tokens that are not recoverable here.
        tokenizer = Tokenizer(BPE(unk_token="<unk>"))
        # ByteLevel pre-tokenizer and decoder
        tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
        tokenizer.decoder = decoders.ByteLevel()
        # Optional: Add post-processing for special tokens
        # (the hard-coded ids must match the order of special_tokens passed to the trainer below)
        tokenizer.post_processor = TemplateProcessing(
            single="<sos> $A <eos>",
            pair="<sos> $A <eos> <sos> $B <eos>",
            special_tokens=[
                ("<sos>", 0),
                ("<eos>", 1),
            ],
        )
        # Trainer
        trainer = BpeTrainer(
            vocab_size=30000,
            min_frequency=2,
            special_tokens=["<sos>", "<eos>", "<unk>", "<pad>"]
        )
        # Train from dataset
        tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
        # Save as single .json file
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer
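# A quick round-trip check of the tokenizer above might look like this (a sketch,
# assuming the tokenizer has already been built or loaded via get_or_build_tokenizer):
#     ids = tokenizer.encode("hello world").ids   # <sos>/<eos> added by the post-processor
#     text = tokenizer.decode(ids)                # ByteLevel decoder restores the spacing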
def get_ds(config):
    # Alternative loaders kept for reference:
    # ds_raw = load_dataset("json", data_files={'train': config['train'], 'test': config['test']})
    # ds_raw = load_dataset("stas/openwebtext-10k")
    ds_raw = load_dataset("json", data_files='./dataset/openwebtext_500k_docs.jsonl', split="train", streaming=True)
    ds_test = load_dataset("json", data_files='./dataset/openwebtext_test.jsonl', split="train", streaming=True)
    tokenizer = get_or_build_tokenizer(config, ds_raw)
    train_ds = BilingualDataset(ds_raw, tokenizer, config['seq_len'])
    val_ds = BilingualDataset(ds_test, tokenizer, config['seq_len'])
    train_dataloader = DataLoader(train_ds, num_workers=6, prefetch_factor=2, pin_memory=True, batch_size=config['batch_size'])
    val_dataloader = DataLoader(val_ds, batch_size=1)

    return train_dataloader, val_dataloader, tokenizer
def get_model(config, vocab_size):
    # model = build_transformer(vocab_src_len,vocab_tgt_len,config['seq_len'],config['seq_len'],config['d_model'], config['N'] , config['h'], config['d_ff'])
    model = build_gpt( vocab_size, config['seq_len'], config['d_model'], config['N'] , config['h'], config['d_ff'],config['dropout'])
    return model
def validate_model(val_dataloader, model, device, loss_fn, vocab_size):
    total_loss = 0
    model.eval()
    i = 0
    with torch.no_grad():
        for batch in val_dataloader:
            input_tokens = batch['input'].to(device, non_blocking=True)
            label = batch['label'].to(device, non_blocking=True)
            decoder_output = model.decode(input_tokens)
            project_output = model.project(decoder_output)
            # accumulate a plain float rather than a tensor
            total_loss += loss_fn(
                        project_output.view(-1, vocab_size),
                        label.view(-1)
                    ).item()
            i += 1
        print(f"Validation loss : {total_loss / i:.4f}")
            
            
def train_model(config):
    #Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device : {device}")
    # Enable TF32 (optional, speeds up matmul)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
    train_dataloader , val_dataloader, tokenizer = get_ds(config)
    print(tokenizer.get_vocab_size())
    model = get_model(config, tokenizer.get_vocab_size()).to(device)
    # TensorBoard
    writer = SummaryWriter(config['experiment_name'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    scaler = GradScaler()  # <- added scaler for mixed precision
    initial_epoch = 0
    global_step = 0
    tqdm_state = {'n':0}
    
    model_filename = None
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f"Preloading Model {model_filename}")
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        initial_epoch = state['epoch'] if 'mid-' in model_filename else state['epoch'] + 1 
        global_step = state['global_step']
        tqdm_state = state['tqdm_state']  if 'mid-' in model_filename else {'n':0}
    else:
        print("No Model to preload. Setting from scratch.")
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer.token_to_id('<pad>'),  # assumed pad-token name; see the tokenizer setup above
        label_smoothing=0.05
    ).to(device)
    e = 0
    
    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            model.train()
            batch_iterator = tqdm(islice(train_dataloader,tqdm_state['n'],None), desc=f'Processing epoch {epoch:02d}',initial=tqdm_state['n'] ,total=140000)#total=217013)
            e = epoch
            if model_filename and 'elapsed' in tqdm_state and "mid-" in model_filename:
                batch_iterator.start_t = time.time() - tqdm_state['elapsed']
            # total_len = len(batch_iterator)
            for batch in batch_iterator:
                # print(len(batch_iterator))
                # torch.cuda.empty_cache()
                
                input_tokens = batch['input'].to(device,non_blocking=True)
                label = batch['label'].to(device,non_blocking=True)
                optimizer.zero_grad(set_to_none=True)
                # 🔥 Mixed precision forward pass
                with autocast(dtype=torch.float16):
                    decoder_output = model.decode(input_tokens)
                    project_output = model.project(decoder_output)  # (B, Seq_len, tgt_vocab_size)
                    loss = loss_fn(
                        project_output.view(-1, tokenizer.get_vocab_size()), 
                        label.view(-1)
                    )
                if global_step%10 ==0:
                    batch_iterator.set_postfix({f"loss": f"{loss.item():6.3f}"})
                    writer.add_scalar("train loss", loss.item(), global_step)
                    writer.flush()
                if global_step % 10000 == 0 and global_step != 0:
                    validate_model(val_dataloader, model, device, loss_fn, tokenizer.get_vocab_size())
                    model.train()  # validate_model switches to eval mode; switch back before resuming training
                # 🔥 Mixed precision backward pass
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                global_step += 1
                tqdm_state = {'n': batch_iterator.n,'elapsed':batch_iterator.format_dict["elapsed"]}
                # if()
            tqdm_state['n'] = 0
            tqdm_state.pop('elapsed', None)  # avoid a KeyError if the epoch loop ran no batches
            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step,
                'tqdm_state':tqdm_state
            }, model_filename)
            validate_model(val_dataloader, model, device, loss_fn, tokenizer.get_vocab_size())
    except KeyboardInterrupt:
        print("You are stoping training : ... ")
        model_filename = get_weights_file_path(config, f'mid-{e:02d}{input("Checkpoint Name: ")}')
        torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step,
            'tqdm_state':tqdm_state
        }, model_filename)
        
if __name__ == "__main__":
    warnings.filterwarnings('ignore')
    config = get_config("./openweb.config.json")
    train_model(config)
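# For reference, a sketch of the config keys this script reads (key names are taken from
# the code above; the values shown are placeholders, not the project's actual settings):
#
# openweb.config.json
# {
#     "tokenizer_file": "tokenizer.json",
#     "seq_len": 512,
#     "batch_size": 8,
#     "d_model": 512,
#     "N": 6,
#     "h": 8,
#     "d_ff": 2048,
#     "dropout": 0.1,
#     "lr": 3e-4,
#     "num_epochs": 1,
#     "preload": null,
#     "model_folder": "weights",
#     "experiment_name": "runs/gpt"
# }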