"""
Fine-tune and evaluate sequence-to-sequence models for code generation tasks
(summarize, translate, refine, concode, defect). Models are built via
build_or_load_gen_model; training minimizes the sequence-to-sequence cross-entropy
loss, and evaluation reports dev-set perplexity together with BLEU, exact match and,
where applicable, CodeBLEU computed over generated predictions.
"""
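
# Illustrative invocation (a sketch only: the script name and flag spellings below are
# assumptions inferred from the attributes read off `args` in this file; the authoritative
# definitions live in configs.add_args):
#
#   python run_gen.py --do_train --do_eval --do_eval_bleu --do_test \
#       --task summarize --sub_task python --model_type codet5 \
#       --data_dir data --output_dir saved_models --res_dir results \
#       --num_train_epochs 15 --train_batch_size 32 --eval_batch_size 32 \
#       --learning_rate 5e-5 --beam_size 10 --patience 2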

import os
import logging
import argparse
import math
import numpy as np
from tqdm import tqdm
import multiprocessing
import time

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import AdamW, get_linear_schedule_with_warmup

from models import build_or_load_gen_model
from evaluator import smooth_bleu
from evaluator.CodeBLEU import calc_code_bleu
from evaluator.bleu import _bleu
from utils import get_filenames, get_elapse_time, load_and_cache_gen_data
from configs import add_args, set_seed, set_dist

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

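
# eval_ppl_epoch: dev-set perplexity, computed as exp of the mean per-batch loss.
# Models of type 'roberta' are called with the project-specific signature
# (source_ids / source_mask / target_ids / target_mask) and return the loss directly;
# other model types go through the standard HuggingFace seq2seq forward pass with labels.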
def eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer):
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                 num_workers=4, pin_memory=True)

    logger.info("  " + "***** Running ppl evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    model.eval()
    eval_loss, batch_num = 0, 0
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval ppl"):
        batch = tuple(t.to(args.device) for t in batch)
        source_ids, target_ids = batch
        source_mask = source_ids.ne(tokenizer.pad_token_id)
        target_mask = target_ids.ne(tokenizer.pad_token_id)

        with torch.no_grad():
            if args.model_type == 'roberta':
                loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                   target_ids=target_ids, target_mask=target_mask)
            else:
                outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                labels=target_ids, decoder_attention_mask=target_mask)
                loss = outputs.loss

        eval_loss += loss.item()
        batch_num += 1
    eval_loss = eval_loss / batch_num
    eval_ppl = round(np.exp(eval_loss), 5)
    return eval_ppl

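
# eval_bleu_epoch: runs generation on the given split (beam search via model.generate for
# seq2seq model types, the project-specific forward pass for 'roberta'), writes predictions,
# references and sources to test_{criteria}.output/.gold/.src under args.res_dir, and
# reports exact match plus (smoothed) BLEU, with CodeBLEU added for the concode task.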
def eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, split_tag, criteria):
    logger.info("  ***** Running bleu evaluation on {} data *****".format(split_tag))
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_sampler = SequentialSampler(eval_data)
    if args.data_num == -1:
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                     num_workers=4, pin_memory=True)
    else:
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    model.eval()
    pred_ids = []
    bleu, codebleu = 0.0, 0.0
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval bleu for {} set".format(split_tag)):
        source_ids = batch[0].to(args.device)
        source_mask = source_ids.ne(tokenizer.pad_token_id)
        with torch.no_grad():
            if args.model_type == 'roberta':
                preds = model(source_ids=source_ids, source_mask=source_mask)

                top_preds = [pred[0].cpu().numpy() for pred in preds]
            else:
                preds = model.generate(source_ids,
                                       attention_mask=source_mask,
                                       use_cache=True,
                                       num_beams=args.beam_size,
                                       early_stopping=args.task == 'summarize',
                                       max_length=args.max_target_length)
                top_preds = list(preds.cpu().numpy())
            pred_ids.extend(top_preds)

    pred_nls = [tokenizer.decode(id, skip_special_tokens=True, clean_up_tokenization_spaces=False) for id in pred_ids]

    output_fn = os.path.join(args.res_dir, "test_{}.output".format(criteria))
    gold_fn = os.path.join(args.res_dir, "test_{}.gold".format(criteria))
    src_fn = os.path.join(args.res_dir, "test_{}.src".format(criteria))

    if args.task in ['defect']:
        target_dict = {0: 'false', 1: 'true'}
        golds = [target_dict[ex.target] for ex in eval_examples]
        eval_acc = np.mean([int(p == g) for p, g in zip(pred_nls, golds)])
        result = {'em': eval_acc * 100, 'bleu': 0, 'codebleu': 0}

        with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1, open(src_fn, 'w') as f2:
            for pred_nl, gold in zip(pred_nls, eval_examples):
                f.write(pred_nl.strip() + '\n')
                f1.write(target_dict[gold.target] + '\n')
                f2.write(gold.source.strip() + '\n')
            logger.info("Save the predictions into %s", output_fn)
    else:
        dev_accs, predictions = [], []
        with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1, open(src_fn, 'w') as f2:
            for pred_nl, gold in zip(pred_nls, eval_examples):
                dev_accs.append(pred_nl.strip() == gold.target.strip())
                if args.task in ['summarize']:
                    # smooth_bleu.computeMaps expects "<id>\t<text>" lines
                    predictions.append(str(gold.idx) + '\t' + pred_nl)
                    f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n')
                    f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n')
                    f2.write(str(gold.idx) + '\t' + gold.source.strip() + '\n')
                else:
                    f.write(pred_nl.strip() + '\n')
                    f1.write(gold.target.strip() + '\n')
                    f2.write(gold.source.strip() + '\n')

        if args.task == 'summarize':
            (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn)
            bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
        else:
            bleu = round(_bleu(gold_fn, output_fn), 2)
            if args.task in ['concode', 'translate', 'refine']:
                codebleu = calc_code_bleu.get_codebleu(gold_fn, output_fn, args.lang)

        result = {'em': np.mean(dev_accs) * 100, 'bleu': bleu}
        if args.task == 'concode':
            result['codebleu'] = codebleu * 100

    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(round(result[key], 4)))

    return result


def main():
    parser = argparse.ArgumentParser()
    args = add_args(parser)
    logger.info(args)
    t0 = time.time()

    set_dist(args)
    set_seed(args)
    config, model, tokenizer = build_or_load_gen_model(args)
    model.to(args.device)
    if args.n_gpu > 1:
        # multi-gpu training
        model = torch.nn.DataParallel(model)
    pool = multiprocessing.Pool(args.cpu_cont)
    args.train_filename, args.dev_filename, args.test_filename = get_filenames(args.data_dir, args.task, args.sub_task)
    fa = open(os.path.join(args.output_dir, 'summary.log'), 'a+')

    if args.do_train:
        if args.local_rank in [-1, 0] and args.data_num == -1:
            summary_fn = '{}/{}'.format(args.summary_dir, '/'.join(args.output_dir.split('/')[1:]))
            tb_writer = SummaryWriter(summary_fn)

        # Prepare the training data loader
        train_examples, train_data = load_and_cache_gen_data(args, args.train_filename, pool, tokenizer, 'train')
        train_sampler = RandomSampler(train_data) if args.local_rank == -1 else DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size,
                                      num_workers=4, pin_memory=True)

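        # Optimizer and LR schedule: weight decay is applied to all parameters except biases
        # and LayerNorm weights; the schedule warms up linearly for args.warmup_steps and then
        # decays linearly. Note that num_train_optimization_steps below counts dataloader
        # batches, so when args.gradient_accumulation_steps > 1 the scheduler horizon is longer
        # than the number of optimizer steps actually taken and the LR decays more slowly.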
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': args.weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        num_train_optimization_steps = args.num_train_epochs * len(train_dataloader)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=args.warmup_steps,
                                                    num_training_steps=num_train_optimization_steps)

        # Start training
        train_example_num = len(train_data)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", train_example_num)
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Batch num = %d", math.ceil(train_example_num / args.train_batch_size))
        logger.info("  Num epoch = %d", args.num_train_epochs)

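        # Early-stopping bookkeeping: not_loss_dec_cnt counts epochs without a ppl improvement
        # and not_bleu_em_inc_cnt counts epochs without a bleu/em improvement; training stops
        # once both exceed args.patience. When BLEU evaluation is disabled, not_bleu_em_inc_cnt
        # starts at a large value so that early stopping is driven by ppl alone.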
        dev_dataset = {}
        global_step, best_bleu_em, best_ppl = 0, -1, 1e6
        not_loss_dec_cnt, not_bleu_em_inc_cnt = 0, 0 if args.do_eval_bleu else 1e6

        for cur_epoch in range(args.start_epoch, int(args.num_train_epochs)):
            bar = tqdm(train_dataloader, total=len(train_dataloader), desc="Training")
            nb_tr_examples, nb_tr_steps, tr_loss = 0, 0, 0
            model.train()
            for step, batch in enumerate(bar):
                batch = tuple(t.to(args.device) for t in batch)
                source_ids, target_ids = batch
                source_mask = source_ids.ne(tokenizer.pad_token_id)
                target_mask = target_ids.ne(tokenizer.pad_token_id)

                if args.model_type == 'roberta':
                    loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                       target_ids=target_ids, target_mask=target_mask)
                else:
                    outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                    labels=target_ids, decoder_attention_mask=target_mask)
                    loss = outputs.loss

                if args.n_gpu > 1:
                    loss = loss.mean()  # average over GPUs under DataParallel
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                tr_loss += loss.item()

                nb_tr_examples += source_ids.size(0)
                nb_tr_steps += 1
                loss.backward()

                if nb_tr_steps % args.gradient_accumulation_steps == 0:
                    # Update parameters
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                    global_step += 1
                    train_loss = round(tr_loss * args.gradient_accumulation_steps / (nb_tr_steps + 1), 4)
                    bar.set_description("[{}] Train loss {}".format(cur_epoch, round(train_loss, 3)))

            if args.do_eval:
                # Eval model with dev dataset (loss / ppl)
                if 'dev_loss' in dev_dataset:
                    eval_examples, eval_data = dev_dataset['dev_loss']
                else:
                    eval_examples, eval_data = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev')
                    dev_dataset['dev_loss'] = eval_examples, eval_data

                eval_ppl = eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer)
                result = {'epoch': cur_epoch, 'global_step': global_step, 'eval_ppl': eval_ppl}
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                logger.info("  " + "*" * 20)
                if args.data_num == -1:
                    tb_writer.add_scalar('dev_ppl', eval_ppl, cur_epoch)

                # Save the last checkpoint
                if args.save_last_checkpoints:
                    last_output_dir = os.path.join(args.output_dir, 'checkpoint-last')
                    if not os.path.exists(last_output_dir):
                        os.makedirs(last_output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model
                    output_model_file = os.path.join(last_output_dir, "pytorch_model.bin")
                    torch.save(model_to_save.state_dict(), output_model_file)
                    logger.info("Save the last model into %s", output_model_file)

                if eval_ppl < best_ppl:
                    not_loss_dec_cnt = 0
                    logger.info("  Best ppl: %s", eval_ppl)
                    logger.info("  " + "*" * 20)
                    fa.write("[%d] Best ppl changed into %.4f\n" % (cur_epoch, eval_ppl))
                    best_ppl = eval_ppl

                    # Save the best-ppl checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-best-ppl')
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    if args.always_save_model:
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_model_file = os.path.join(output_dir, "pytorch_model.bin")
                        torch.save(model_to_save.state_dict(), output_model_file)
                        logger.info("Save the best ppl model into %s", output_model_file)
                else:
                    not_loss_dec_cnt += 1
                    logger.info("Ppl does not decrease for %d epochs", not_loss_dec_cnt)
                    if all([x > args.patience for x in [not_bleu_em_inc_cnt, not_loss_dec_cnt]]):
                        early_stop_str = "[%d] Early stop as not_bleu_em_inc_cnt=%d, and not_loss_dec_cnt=%d\n" % (
                            cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt)
                        logger.info(early_stop_str)
                        fa.write(early_stop_str)
                        break
                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
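                # Periodic BLEU/EM evaluation on a sampled dev subset; the checkpoint with the
                # best bleu+em is saved under checkpoint-best-bleu (when checkpoint saving is
                # enabled) and the same patience rule as above applies to the bleu/em counter.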
                if args.do_eval_bleu:
                    eval_examples, eval_data = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev',
                                                                       only_src=True, is_sample=True)

                    result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'dev', 'e%d' % cur_epoch)
                    dev_bleu, dev_em = result['bleu'], result['em']
                    if args.task in ['summarize']:
                        dev_bleu_em = dev_bleu
                    elif args.task in ['defect']:
                        dev_bleu_em = dev_em
                    else:
                        dev_bleu_em = dev_bleu + dev_em
                    if args.data_num == -1:
                        tb_writer.add_scalar('dev_bleu_em', dev_bleu_em, cur_epoch)

                    if dev_bleu_em > best_bleu_em:
                        not_bleu_em_inc_cnt = 0
                        logger.info("  [%d] Best bleu+em: %.2f (bleu: %.2f, em: %.2f)",
                                    cur_epoch, dev_bleu_em, dev_bleu, dev_em)
                        logger.info("  " + "*" * 20)
                        best_bleu_em = dev_bleu_em
                        fa.write("[%d] Best bleu+em changed into %.2f (bleu: %.2f, em: %.2f)\n" % (
                            cur_epoch, best_bleu_em, dev_bleu, dev_em))

                        # Save the best-bleu checkpoint
                        output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu')
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        if args.data_num == -1 or args.always_save_model:
                            model_to_save = model.module if hasattr(model, 'module') else model
                            output_model_file = os.path.join(output_dir, "pytorch_model.bin")
                            torch.save(model_to_save.state_dict(), output_model_file)
                            logger.info("Save the best bleu model into %s", output_model_file)
                    else:
                        not_bleu_em_inc_cnt += 1
                        logger.info("Bleu does not increase for %d epochs", not_bleu_em_inc_cnt)
                        fa.write(
                            "[%d] Best bleu+em (%.2f) has not improved for %d epochs, cur bleu+em: %.2f (bleu: %.2f, em: %.2f)\n" % (
                                cur_epoch, best_bleu_em, not_bleu_em_inc_cnt, dev_bleu_em, dev_bleu, dev_em))
                        if all([x > args.patience for x in [not_bleu_em_inc_cnt, not_loss_dec_cnt]]):
                            stop_early_str = "[%d] Early stop as not_bleu_em_inc_cnt=%d, and not_loss_dec_cnt=%d\n" % (
                                cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt)
                            logger.info(stop_early_str)
                            fa.write(stop_early_str)
                            break
            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

        if args.local_rank in [-1, 0] and args.data_num == -1:
            tb_writer.close()
        logger.info("Finish training and take %s", get_elapse_time(t0))

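    # Final test: reload the checkpoint-best-bleu weights and report test BLEU / EM
    # (and CodeBLEU where the result dict carries it).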
    if args.do_test:
        logger.info("  " + "***** Testing *****")
        logger.info("  Batch size = %d", args.eval_batch_size)

        for criteria in ['best-bleu']:
            file = os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(criteria))
            logger.info("Reload model from {}".format(file))
            model.load_state_dict(torch.load(file))
            eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, pool, tokenizer, 'test',
                                                               only_src=True, is_sample=False)
            result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'test', criteria)
            test_bleu, test_em = result['bleu'], result['em']
            test_codebleu = result['codebleu'] if 'codebleu' in result else 0
            result_str = "[%s] bleu-4: %.2f, em: %.4f, codebleu: %.4f\n" % (criteria, test_bleu, test_em, test_codebleu)
            logger.info(result_str)
            fa.write(result_str)
            if args.res_fn:
                with open(args.res_fn, 'a+') as f:
                    f.write('[Time: {}] {}\n'.format(get_elapse_time(t0), file))
                    f.write(result_str)
    logger.info("Finish and take {}".format(get_elapse_time(t0)))
    fa.write("Finish and take {}".format(get_elapse_time(t0)))
    fa.close()


if __name__ == "__main__":
    main()