						""" | 
					
					
						
						| 
							 | 
						Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). | 
					
					
						
						| 
							 | 
						GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned | 
					
					
						
						| 
							 | 
						using a masked language modeling (MLM) loss. | 
					
					
						
						| 
							 | 
						""" | 
					
					
						
import os
import time
import logging
import argparse
import multiprocessing
from itertools import cycle

import numpy as np
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import AdamW, get_linear_schedule_with_warmup

from models import build_or_load_gen_model
from evaluator import smooth_bleu
from evaluator.CodeBLEU import calc_code_bleu
from evaluator.bleu import _bleu
from utils import get_elapse_time, load_and_cache_multi_gen_data
from configs import add_args, set_seed, set_dist

cpu_cont = multiprocessing.cpu_count()

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

# Number of DataLoader workers for the per-task training loaders
# (0 = load batches in the main process).
WORKER_NUM = 0


def get_max_trg_len_by_task(task, sub_task):
    """Maximum target sequence length used for generation on each task."""
    if task == 'summarize':
        max_target_length = 128
    elif task == 'translate':
        max_target_length = 256
    elif task == 'refine':
        if sub_task == 'small':
            max_target_length = 120
        else:
            max_target_length = 240
    elif task == 'concode':
        max_target_length = 150
    elif task == 'defect':
        max_target_length = 3  # only needs to emit 'true' or 'false'
    else:
        # Guard against typos in task names instead of returning an unbound value.
        raise ValueError('Unknown task: {}'.format(task))
    return max_target_length
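# For example: get_max_trg_len_by_task('refine', 'small') -> 120 and
# get_max_trg_len_by_task('translate', 'java-cs') -> 256; defect detection is
# generated as a single 'true'/'false' word, hence the tiny budget of 3.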
					
					
						
def get_bs(cur_task, model_tag):
    """Heuristic per-task training batch size for the given model tag."""
    task = cur_task.split('_')[0]
    sub_task = cur_task.split('_')[-1]
    if 'codet5_small' in model_tag:
        bs = 32
        if task == 'summarize' or task == 'translate' or (task == 'refine' and sub_task == 'small'):
            bs = 64
    else:
        # Defaults for larger model variants.
        bs = 28
        if task == 'translate':
            bs = 25
        elif task == 'summarize':
            bs = 40
    return bs
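# For example: get_bs('summarize_python', 'codet5_small') -> 64, while
# get_bs('translate_java-cs', 'codet5_base') -> 25; the sub_task only matters
# for refine, where the 'small' variant has shorter sequences.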
					
					
						
def eval_bleu(args, eval_data, eval_examples, model, tokenizer, split_tag, cur_task, criteria):
    eval_sampler = SequentialSampler(eval_data)
    if args.data_num == -1:
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                     num_workers=4, pin_memory=True)
    else:
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
    task = cur_task.split('_')[0]
    sub_task = cur_task.split('_')[-1]
    max_target_length = get_max_trg_len_by_task(task, sub_task)

    model.eval()
    pred_ids = []
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval bleu for {} set".format(split_tag)):
        source_ids = batch[0].to(args.device)
        source_mask = source_ids.ne(tokenizer.pad_token_id)
        with torch.no_grad():
            if args.model_type == 'roberta':
                # This model returns a list of candidate sequences per example;
                # keep only the first (top-ranked) one.
                preds = model(source_ids=source_ids, source_mask=source_mask)
                top_preds = [pred[0].cpu().numpy() for pred in preds]
            else:
                preds = model.generate(source_ids,
                                       attention_mask=source_mask,
                                       use_cache=True,
                                       num_beams=5,
                                       max_length=max_target_length,
                                       early_stopping=(task == 'summarize'))
                top_preds = list(preds.cpu().numpy())
            pred_ids.extend(top_preds)

    pred_nls = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for ids in pred_ids]
    if task == 'defect':
        target_dict = {0: 'false', 1: 'true'}
        golds = [target_dict[ex.target] for ex in eval_examples]
        eval_acc = np.mean([int(p == g) for p, g in zip(pred_nls, golds)])
        # Scale to a percentage, consistent with the em reported for other tasks.
        result = {'em': eval_acc * 100, 'bleu': 0, 'codebleu': 0}

    else:
        dev_accs = []
        predictions = []
        res_dir = os.path.join(args.res_dir, cur_task)
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        output_fn = os.path.join(res_dir, "test_{}.output".format(criteria))
        gold_fn = os.path.join(res_dir, "test_{}.gold".format(criteria))
        with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1:
            for pred_nl, gold in zip(pred_nls, eval_examples):
                dev_accs.append(pred_nl.strip() == gold.target.strip())
                if task == 'summarize':
                    # smooth_bleu expects "idx\tsentence" lines.
                    predictions.append(str(gold.idx) + '\t' + pred_nl)
                    f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n')
                    f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n')
                else:
                    f.write(pred_nl.strip() + '\n')
                    f1.write(gold.target.strip() + '\n')

        bleu, codebleu = 0.0, 0.0
        try:
            if task == 'summarize':
                (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn)
                bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
            else:
                bleu = round(_bleu(gold_fn, output_fn), 2)
                if split_tag == 'test':
                    if task in ['summarize', 'search']:
                        cur_lang = sub_task
                    elif task in ['refine', 'concode', 'clone']:
                        cur_lang = 'java'
                    elif task == 'defect':
                        cur_lang = 'c'
                    elif task == 'translate':
                        cur_lang = 'c_sharp' if sub_task == 'java-cs' else 'java'
                    codebleu = calc_code_bleu.get_codebleu(gold_fn, output_fn, cur_lang)
        except Exception:
            logger.warning("BLEU/CodeBLEU computation failed for %s, defaulting to 0", cur_task)

        result = {}
        em = np.mean(dev_accs) * 100
        result['em'] = em
        result['bleu'] = bleu
        if args.task != 'summarize' and split_tag == 'test':
            result['codebleu'] = codebleu * 100

    logger.info("***** Eval results [%s] *****", cur_task)
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(round(result[key], 4)))

    return result


def main():
    parser = argparse.ArgumentParser()
    args = add_args(parser)
    logger.info(args)
    t0 = time.time()

    set_dist(args)
    set_seed(args)
    config, model, tokenizer = build_or_load_gen_model(args)
    model.to(args.device)
    if args.n_gpu > 1:
        # Multi-GPU training.
        model = torch.nn.DataParallel(model)
    pool = multiprocessing.Pool(args.cpu_cont)
    fa = open(os.path.join(args.output_dir, 'summary.log'), 'a+')

    fa_dict = {}
    if args.do_train:
        if args.local_rank in [-1, 0] and args.data_num == -1:
            summary_fn = './tensorboard/{}'.format('/'.join(args.output_dir.split('/')[1:]))
            tb_writer = SummaryWriter(summary_fn)

        # Load and cache the training data for every task/sub-task pair.
        train_examples_data_dict = load_and_cache_multi_gen_data(args, pool, tokenizer, 'train', is_sample=False)
        train_data_list = [v[1] for k, v in train_examples_data_dict.items()]
        all_tasks = [k for k, v in train_examples_data_dict.items()]
        total_train_data_num = sum([len(v[0]) for k, v in train_examples_data_dict.items()])

        for cur_task in all_tasks:
            summary_dir = os.path.join(args.output_dir, 'summary')
            if not os.path.exists(summary_dir):
                os.makedirs(summary_dir)
            fa_dict[cur_task] = open(os.path.join(summary_dir, '{}_summary.log'.format(cur_task)), 'a+')

        train_dataloader_dict = dict()
        for train_data, cur_task in zip(train_data_list, all_tasks):
            if args.local_rank == -1:
                train_sampler = RandomSampler(train_data)
            else:
                train_sampler = DistributedSampler(train_data)
            if args.data_num == -1:
                train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                              batch_size=get_bs(cur_task, args.model_name_or_path),
                                              num_workers=WORKER_NUM, pin_memory=True)
            else:
                train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                              batch_size=get_bs(cur_task, args.model_name_or_path))

            train_dataloader_dict[cur_task] = cycle(train_dataloader)
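            # itertools.cycle turns each loader into an infinite iterator so
            # the step-based multi-task loop below can keep drawing batches
            # without epoch bookkeeping. Note that cycle() caches every item
            # from the first pass, so RandomSampler's shuffle order is frozen
            # after the first epoch (later passes replay the same batch
            # order), and all batches stay resident in memory.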
					
					
						
        # Prepare the optimizer and learning-rate schedule.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': args.weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=args.warmup_steps,
                                                    num_training_steps=args.max_steps)
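        # Standard Transformer fine-tuning recipe: weight decay is applied to
        # everything except biases and LayerNorm weights, and the learning
        # rate warms up linearly for args.warmup_steps, then decays linearly
        # to zero at args.max_steps.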
					
					
						
        # Start training.
        logger.info("***** Running training *****")
        logger.info("  Total train data num = %d", total_train_data_num)
        logger.info("  Max step = %d, Save step = %d", args.max_steps, args.save_steps)

        dev_dataset = {}
        step, global_step = 0, 0
        best_bleu_em = dict([(k, -1) for k in all_tasks])
        best_loss = dict([(k, 1e6) for k in all_tasks])
        not_bleu_em_inc_cnt = dict([(k, 0) for k in all_tasks])
        is_early_stop = dict([(k, 0) for k in all_tasks])

        # Per-task early-stopping patience (number of eval rounds without
        # improvement before the task is dropped from training).
        patience_pairs = []
        for cur_task in all_tasks:
            task = cur_task.split('_')[0]
            if task == 'summarize':
                patience_pairs.append((cur_task, 2))
            elif task == 'translate':
                patience_pairs.append((cur_task, 5))
            elif task == 'refine':
                patience_pairs.append((cur_task, 5))
            elif task == 'concode':
                patience_pairs.append((cur_task, 3))
            elif task == 'defect':
                patience_pairs.append((cur_task, 2))
        patience_dict = dict(patience_pairs)
        logger.info('Patience: %s', patience_dict)

        probs = [len(x) for x in train_data_list]
        probs = [x / sum(probs) for x in probs]
        probs = [x ** 0.7 for x in probs]
        probs = [x / sum(probs) for x in probs]
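        # Temperature-scaled task sampling: task i is drawn with probability
        # p_i ~ (n_i / sum_j n_j) ** 0.7, where n_i is its number of training
        # examples. The exponent < 1 flattens the distribution and upweights
        # low-resource tasks; e.g. sizes [100, 10] give proportional probs
        # [0.91, 0.09] but sampling probs of roughly [0.83, 0.17] after
        # scaling.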
					
					
						
        nb_tr_examples, nb_tr_steps, tr_nb, tr_loss, logging_loss = 0, 0, 0, 0, 0

        bar = tqdm(total=args.max_steps, desc="Training")
        skip_cnt = 0
        while True:
            # Sample a task for this step, then draw its next batch.
            cur_task = np.random.choice(all_tasks, 1, p=probs)[0]
            train_dataloader = train_dataloader_dict[cur_task]
            if is_early_stop[cur_task]:
                skip_cnt += 1
                if skip_cnt > 50:
                    logger.info('All tasks have early stopped at %d', step)
                    break
                continue
            else:
                skip_cnt = 0
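            # Note: this termination test is probabilistic. Fifty consecutive
            # draws of early-stopped tasks are taken as evidence that every
            # task has stopped, rather than checking
            # all(is_early_stop.values()) directly.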
					
					
						
            step += 1
            batch = next(train_dataloader)

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            source_ids, target_ids = batch

            # Attention masks simply mark the non-padding positions.
            source_mask = source_ids.ne(tokenizer.pad_token_id)
            target_mask = target_ids.ne(tokenizer.pad_token_id)

            if args.model_type == 'roberta':
                loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                   target_ids=target_ids, target_mask=target_mask)
            else:
                outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                labels=target_ids, decoder_attention_mask=target_mask)
                loss = outputs.loss
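            # Caveat (assuming build_or_load_gen_model returns a Hugging Face
            # *ForConditionalGeneration model here): outputs.loss ignores the
            # label id -100, not pad_token_id, so padded positions in
            # target_ids still contribute to the loss unless they are
            # replaced upstream.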
					
					
						
            if args.n_gpu > 1:
                loss = loss.mean()  # average the per-GPU losses from DataParallel
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()

            nb_tr_examples += source_ids.size(0)
            nb_tr_steps += 1
            loss.backward()
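            # Gradients accumulate over args.gradient_accumulation_steps
            # backward passes before each optimizer step below, so the
            # effective batch size is get_bs(...) * gradient_accumulation_steps;
            # the loss was scaled down above so the accumulated gradient is an
            # average rather than a sum.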
					
					
						
            if nb_tr_steps % args.gradient_accumulation_steps == 0:
                # Update parameters.
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                # Average loss since the last logging point.
                train_loss = round((tr_loss - logging_loss) / (global_step - tr_nb), 6)
                bar.update(1)
                bar.set_description("[{}] Train loss {}".format(step, round(train_loss, 3)))

                if args.local_rank in [-1, 0] and args.log_steps > 0 and global_step % args.log_steps == 0:
                    logging_loss = tr_loss
                    tr_nb = global_step

                if args.do_eval and args.local_rank in [-1, 0] \
                        and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save the last checkpoint.
                    if args.data_num == -1 and args.save_last_checkpoints:
                        last_output_dir = os.path.join(args.output_dir, 'checkpoint-last')
                        if not os.path.exists(last_output_dir):
                            os.makedirs(last_output_dir)
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_model_file = os.path.join(last_output_dir, "pytorch_model.bin")
                        torch.save(model_to_save.state_dict(), output_model_file)
                        logger.info("Save the last model into %s", output_model_file)
                    if global_step % 100000 == 0:
                        # Additionally keep a snapshot every 100k steps.
                        step_tag = '{}00k'.format(global_step // 100000)
                        last_output_dir = os.path.join(args.output_dir, 'checkpoint-step-{}'.format(step_tag))
                        if not os.path.exists(last_output_dir):
                            os.makedirs(last_output_dir)
                        model_to_save = model.module if hasattr(model, 'module') else model
                        output_model_file = os.path.join(last_output_dir, "pytorch_model.bin")
                        torch.save(model_to_save.state_dict(), output_model_file)
                        logger.info("Save the step-%s model into %s", step_tag, output_model_file)

                    # Evaluate dev perplexity on every still-active task.
                    if 'dev_loss' in dev_dataset:
                        eval_examples_data_dict = dev_dataset['dev_loss']
                    else:
                        eval_examples_data_dict = load_and_cache_multi_gen_data(args, pool, tokenizer, 'dev')
                        dev_dataset['dev_loss'] = eval_examples_data_dict

                    for cur_task in eval_examples_data_dict.keys():
                        if is_early_stop[cur_task]:
                            continue
                        eval_examples, eval_data = eval_examples_data_dict[cur_task]
                        eval_sampler = SequentialSampler(eval_data)
                        if args.data_num == -1:
                            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler,
                                                         batch_size=args.eval_batch_size,
                                                         num_workers=4, pin_memory=True)
                        else:
                            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler,
                                                         batch_size=args.eval_batch_size)

                        logger.info("  " + "***** Running ppl evaluation on [{}] *****".format(cur_task))
                        logger.info("  Num examples = %d", len(eval_examples))
                        logger.info("  Batch size = %d", args.eval_batch_size)

                        model.eval()
                        eval_loss, batch_num = 0, 0
                        for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval ppl"):
                            batch = tuple(t.to(args.device) for t in batch)
                            source_ids, target_ids = batch
                            source_mask = source_ids.ne(tokenizer.pad_token_id)
                            target_mask = target_ids.ne(tokenizer.pad_token_id)

                            with torch.no_grad():
                                if args.model_type == 'roberta':
                                    loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                                       target_ids=target_ids, target_mask=target_mask)
                                else:
                                    outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                                    labels=target_ids, decoder_attention_mask=target_mask)
                                    loss = outputs.loss

                            eval_loss += loss.item()
                            batch_num += 1

                        eval_loss = eval_loss / batch_num
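                        # Dev perplexity below is exp(mean per-batch loss);
                        # this matches token-level perplexity only roughly,
                        # since batches with different target-token counts are
                        # weighted equally.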
					
					
						
                        result = {'cur_task': cur_task,
                                  'global_step': global_step,
                                  'eval_ppl': round(np.exp(eval_loss), 5),
                                  'train_loss': round(train_loss, 5)}
                        for key in sorted(result.keys()):
                            logger.info("  %s = %s", key, str(result[key]))
                        logger.info("  " + "*" * 20)

                        if args.data_num == -1:
                            tb_writer.add_scalar('dev_ppl_{}'.format(cur_task),
                                                 round(np.exp(eval_loss), 5),
                                                 global_step)

                        if eval_loss < best_loss[cur_task]:
                            logger.info("  Best ppl: %s", round(np.exp(eval_loss), 5))
                            logger.info("  " + "*" * 20)
                            fa_dict[cur_task].write(
                                "[%d: %s] Best ppl changed into %.4f\n" % (global_step, cur_task, np.exp(eval_loss)))
                            best_loss[cur_task] = eval_loss

                            # Save the best-ppl checkpoint for this task.
                            output_dir = os.path.join(args.output_dir, 'checkpoint-best-ppl', cur_task)
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)
                            if args.data_num == -1 or args.always_save_model:
                                model_to_save = model.module if hasattr(model, 'module') else model
                                output_model_file = os.path.join(output_dir, "pytorch_model.bin")
                                torch.save(model_to_save.state_dict(), output_model_file)
                                logger.info("Save the best ppl model into %s", output_model_file)

                    if args.do_eval_bleu:
                        eval_examples_data_dict = load_and_cache_multi_gen_data(args, pool, tokenizer, 'dev',
                                                                                only_src=True, is_sample=True)
                        for cur_task in eval_examples_data_dict.keys():
                            if is_early_stop[cur_task]:
                                continue
                            eval_examples, eval_data = eval_examples_data_dict[cur_task]

                            # Generate on a dev sample and score it.
                            result = eval_bleu(args, eval_data, eval_examples, model, tokenizer, 'dev', cur_task,
                                               criteria='e{}'.format(global_step))
                            dev_bleu, dev_em = result['bleu'], result['em']
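                            # Model-selection score per task family: BLEU for
                            # summarization, exact match for classification-style
                            # tasks (defect, clone), BLEU + EM otherwise. Note
                            # the branches test args.task rather than cur_task,
                            # so during multi-task training the final branch
                            # may apply to every task.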
					
					
						
                            if args.task == 'summarize':
                                dev_bleu_em = dev_bleu
                            elif args.task in ['defect', 'clone']:
                                dev_bleu_em = dev_em
                            else:
                                dev_bleu_em = dev_bleu + dev_em
                            if args.data_num == -1:
                                tb_writer.add_scalar('dev_bleu_em_{}'.format(cur_task), dev_bleu_em, global_step)

                            if dev_bleu_em > best_bleu_em[cur_task]:
                                not_bleu_em_inc_cnt[cur_task] = 0
                                logger.info("  [%d: %s] Best bleu+em: %.2f (bleu: %.2f, em: %.2f)",
                                            global_step, cur_task, dev_bleu_em, dev_bleu, dev_em)
                                logger.info("  " + "*" * 20)
                                best_bleu_em[cur_task] = dev_bleu_em
                                fa_dict[cur_task].write(
                                    "[%d: %s] Best bleu+em changed into %.2f (bleu: %.2f, em: %.2f)\n" % (
                                        global_step, cur_task, best_bleu_em[cur_task], dev_bleu, dev_em))

                                # Save the best-bleu checkpoint for this task.
                                output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu', cur_task)
                                if not os.path.exists(output_dir):
                                    os.makedirs(output_dir)
                                if args.data_num == -1 or args.always_save_model:
                                    model_to_save = model.module if hasattr(model, 'module') else model
                                    output_model_file = os.path.join(output_dir, "pytorch_model.bin")
                                    torch.save(model_to_save.state_dict(), output_model_file)
                                    logger.info("Save the best bleu model into %s", output_model_file)
                            else:
                                not_bleu_em_inc_cnt[cur_task] += 1
                                logger.info("[%d %s] bleu/em does not increase for %d eval steps",
                                            global_step, cur_task, not_bleu_em_inc_cnt[cur_task])
                                if not_bleu_em_inc_cnt[cur_task] > patience_dict[cur_task]:
                                    logger.info("[%d %s] Early stop as bleu/em does not increase for %d eval steps",
                                                global_step, cur_task, not_bleu_em_inc_cnt[cur_task])
                                    is_early_stop[cur_task] = 1
                                    fa_dict[cur_task].write(
                                        "[%d %s] Early stop as bleu/em does not increase for %d eval steps, takes %s\n" %
                                        (global_step, cur_task, not_bleu_em_inc_cnt[cur_task], get_elapse_time(t0)))

                    logger.info("***** CUDA.empty_cache() *****")
                    torch.cuda.empty_cache()
                if global_step >= args.max_steps:
                    logger.info("Reach the max step: %d", args.max_steps)
                    break

        if args.local_rank in [-1, 0] and args.data_num == -1:
            tb_writer.close()
        logger.info("Finish training and take %.2fs", time.time() - t0)
        for cur_task in all_tasks:
            fa_dict[cur_task].close()

    if args.do_test:
        logger.info("  " + "***** Testing *****")
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_examples_data_dict = load_and_cache_multi_gen_data(args, pool, tokenizer, 'test', only_src=True)
        all_tasks = list(eval_examples_data_dict.keys())
        for cur_task in all_tasks:
            summary_dir = os.path.join(args.output_dir, 'summary')
            if not os.path.exists(summary_dir):
                os.makedirs(summary_dir)
            fa_dict[cur_task] = open(os.path.join(summary_dir, '{}_summary.log'.format(cur_task)), 'a+')

        for cur_task in all_tasks:
            eval_examples, eval_data = eval_examples_data_dict[cur_task]
            args.task = cur_task.split('_')[0]
            args.sub_task = cur_task.split('_')[-1]

            for criteria in ['best-bleu', 'best-ppl', 'last']:
                file = os.path.join(args.output_dir, 'checkpoint-{}/{}/pytorch_model.bin'.format(criteria, cur_task))
                model.load_state_dict(torch.load(file))
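                # Note: during training, 'checkpoint-last' is written without
                # a per-task subdirectory (only 'checkpoint-best-ppl' and
                # 'checkpoint-best-bleu' get one), so the 'last' path built
                # above may need the checkpoints rearranged before it exists.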
					
					
						
                result = eval_bleu(args, eval_data, eval_examples, model, tokenizer, 'test', cur_task, criteria)
                test_bleu, test_em = result['bleu'], result['em']
                test_codebleu = result['codebleu'] if 'codebleu' in result else 0
                result_str = "[%s %s] bleu-4: %.2f, em: %.4f, codebleu: %.4f\n" % (
                    cur_task, criteria, test_bleu, test_em, test_codebleu)
                logger.info(result_str)
                fa_dict[cur_task].write(result_str)
                fa.write(result_str)
                if args.res_fn:
                    with open(args.res_fn, 'a+') as f:
                        f.write('[Time: {}] {}\n'.format(get_elapse_time(t0), file))
                        f.write(result_str)
    logger.info("Finish and take {}".format(get_elapse_time(t0)))
    for cur_task in fa_dict:
        fa_dict[cur_task].close()
    fa.write("Finish and take {}".format(get_elapse_time(t0)))
    fa.close()


if __name__ == "__main__":
    main()