import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import time
import os
from collections import defaultdict

import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.pth_loader import CaptionDataset
import captioning.utils.eval_utils as eval_utils
import captioning.utils.misc as utils
from captioning.utils.rewards import init_scorer, get_self_critical_reward
from captioning.modules.loss_wrapper import LossWrapper

import pytorch_lightning as pl

import detectron2.utils.comm as d2comm
from detectron2.utils.env import seed_all_rng
seed_all_rng(1234)
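# seed_all_rng is a detectron2 helper that seeds the Python, NumPy, and torch RNGs
# from a single value, so data shuffling and sampling are repeatable across runs.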
class LitModel(pl.LightningModule):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # Initialize dataset
        self.dataset = CaptionDataset(opt)
        opt.vocab_size = self.dataset.vocab_size
        opt.seq_length = self.dataset.seq_length
        self.batch_size = opt.batch_size

        # Build model
        opt.vocab = self.dataset.get_vocab()
        model = models.setup(opt)
        # print(model)
        del opt.vocab

        # Wrapper with the loss inside it.
        lw_model = LossWrapper(model, opt)

        self.model = model
        self.lw_model = lw_model

        self.struc_flag = None
        self.sc_flag = None

        # If CLIP-S+Grammar is used in the reward, launch a separate CLIP-S model
        # whose parameters stay unchanged, for evaluation only.
        if getattr(self.opt, 'use_grammar', False):
            from captioning.utils.clipscore import CLIPScore
            self.val_clipscore_model = CLIPScore(
                mode=opt.clipscore_mode, use_grammar=False)
            for p in self.val_clipscore_model.parameters():
                p.requires_grad = False
        else:
            if self.lw_model.clipscore_model is not None:
                self.val_clipscore_model = self.lw_model.clipscore_model
            else:
                from captioning.utils.clipscore import CLIPScore
                self.val_clipscore_model = CLIPScore(
                    mode=opt.clipscore_mode, use_grammar=False)
                for p in self.val_clipscore_model.parameters():
                    p.requires_grad = False
        self.val_clipscore_model.eval()

        # BERTScore
        from bert_score import BERTScorer
        self.bert_scorer = BERTScorer(
            lang="en",
            # rescale_with_baseline=True,
            rescale_with_baseline=False,
            device='cpu'
        )
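    # Frozen evaluation scorers kept on the module:
    # - val_clipscore_model computes CLIP-S / RefCLIP-S on validation captions and
    #   never receives gradient updates (its parameters are frozen above).
    # - bert_scorer is created on CPU and moved to the active device lazily inside
    #   validation_step, so it does not occupy GPU memory during training.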
    def forward(self, *args, **kwargs):
        """
        This module is not meant to be called like a plain nn.Module;
        use training_step / validation_step instead.
        """
        raise NotImplementedError
    def train_dataloader(self):
        train_dataset = torch.utils.data.Subset(
            self.dataset,
            self.dataset.split_ix['train']
        )
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=self.dataset.collate_func
        )
        return train_loader

    def val_dataloader(self, split='val'):
        val_dataset = torch.utils.data.Subset(
            self.dataset,
            self.dataset.split_ix[split]
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            collate_fn=self.dataset.collate_func
        )
        return val_loader

    def test_dataloader(self):
        return self.val_dataloader('test')
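    # Note: self.batch_size is the per-GPU batch size; the launcher code at the
    # bottom of this file divides opt.batch_size by the number of visible GPUs
    # before this module is constructed.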
    def training_step(self, data, batch_idx):
        sc_flag, struc_flag = self.sc_flag, self.struc_flag

        tmp = [data['fc_feats'], data['att_feats'],
               data['labels'], data['masks'], data['att_masks']]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        if int(os.getenv('M2_cider', '0')) != 0:
            data['gts'] = data['rawgts']

        if self.opt.use_clipscore:
            clip_vis_feats = data['clip_vis_feats']
            model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks,
                                      data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag,
                                      clip_vis_feats=clip_vis_feats)
        else:
            model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks,
                                      data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
        loss = model_out['loss']

        data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
        data_time = torch.tensor(data_time)

        logger_logs = model_out.copy()
        if struc_flag or sc_flag:
            logger_logs['reward'] = model_out['reward'].mean()
            for k in ['CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward']:
                if k in model_out:
                    logger_logs[k] = model_out[k]
        if struc_flag:
            logger_logs['reward_var'] = model_out['reward'].var(1).mean()

        logger_logs['scheduled_sampling_prob'] = torch.tensor(
            self.model.ss_prob)
        logger_logs['loss'] = loss
        logger_logs['data_time'] = data_time

        # The {log: dict} / {progress_bar: dict} return keys were deprecated in
        # Lightning 0.9.1, so everything is logged through self.log(...) instead.
        for k, v in logger_logs.items():
            if k in ['reward', 'reward_var', 'data_time', 'CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward']:
                self.log('train/' + k, v, prog_bar=True)
            else:
                self.log('train/' + k, v)

        return loss
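    # validation_step below decodes a caption for each image, computes per-caption
    # CLIP-S, RefCLIP-S, grammar probability (when enabled), and BERTScore F1, and
    # returns them together with the teacher-forced loss; corpus-level language
    # metrics such as CIDEr are computed later in validation_epoch_end.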
    def validation_step(self, data, batch_idx):
        model = self.model
        crit = self.lw_model.crit

        opt = self.opt
        eval_kwargs = {'dataset': opt.input_json}
        eval_kwargs.update(vars(opt))

        # CLIPScore
        use_grammar = getattr(self.opt, 'use_grammar', False)
        joint_out = getattr(self.opt, 'joint_out', False)

        verbose = eval_kwargs.get('verbose', True)
        verbose_beam = eval_kwargs.get('verbose_beam', 0)
        verbose_loss = eval_kwargs.get('verbose_loss', 1)
        dataset = eval_kwargs.get('dataset', 'coco')
        beam_size = eval_kwargs.get('beam_size', 1)
        sample_n = eval_kwargs.get('sample_n', 1)
        remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
        # Set through an environment variable since this is effectively a global
        # configuration; it keeps the decoding code itself clean.
        os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings)

        predictions = []
        n_predictions = []

        loss = torch.tensor(0)
        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get the loss
            tmp = [data['fc_feats'], data['att_feats'],
                   data['labels'], data['masks'], data['att_masks']]
            fc_feats, att_feats, labels, masks, att_masks = tmp
            loss = crit(model(fc_feats, att_feats,
                              labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])

        # forward the model to also get generated samples for each image;
        # only keep one feature per image in case of duplicate samples
        tmp_eval_kwargs = eval_kwargs.copy()
        tmp_eval_kwargs.update({'sample_n': 1})
        seq, seq_logprobs = model(
            fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        seq = seq.data
        entropy = - (F.softmax(seq_logprobs, dim=2) *
                     seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1) + 1)
        perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) \
            / ((seq > 0).to(seq_logprobs).sum(1) + 1)

        # Print beam search results
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([utils.decode_sequence(model.vocab, _[
                    'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)

        # CLIP-S: similarity between the generated caption embedding and the
        # precomputed CLIP visual feature of the image.
        text_feat = self.val_clipscore_model.text_extract(sents, proj_norm=False)

        text_cont_feat = self.val_clipscore_model.clip_model.text_projection(text_feat)
        text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
        vis_feat = data['clip_vis_feats']

        clip_s = self.val_clipscore_model(text_feat=text_cont_feat, img_feat=vis_feat, mode='clip_s')

        gt_indices = torch.arange(0, len(data['gts']))
        data_gts = [data['gts'][_] for _ in gt_indices.tolist()]

        B = len(data_gts)

        gts = []
        gts_valid_mask = []
        max_n_refs = max([len(_gts) for _gts in data_gts])
        for i in range(len(data_gts)):
            _gts = utils.decode_sequence(model.vocab, data_gts[i])
            # pad references
            n_ref = len(_gts)
            _gts.extend([''] * (max_n_refs - n_ref))
            gts.extend(_gts)
            gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref))
        assert len(gts) == B * max_n_refs
        assert len(gts_valid_mask) == B * max_n_refs
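        # gts now holds B * max_n_refs reference strings (padded per image with '')
        # and gts_valid_mask marks which slots hold real references; RefCLIP-S below
        # is given the mask so the padded slots can be ignored.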
        ref_text = gts
        ref_text_mask = gts_valid_mask

        refclip_s = self.val_clipscore_model(
            text_feat=text_cont_feat, img_feat=vis_feat,
            ref_text=ref_text, ref_text_mask=ref_text_mask, mode='refclip_s')

        if use_grammar and not joint_out:
            with torch.no_grad():
                grammar_logit = self.lw_model.clipscore_model.grammar_score_head(text_feat.view(-1, 512))
                grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1]

        # BERTScore
        if next(self.bert_scorer._model.parameters()).device != self.device:
            self.bert_scorer._model.to(self.device)
            self.bert_scorer.device = self.device

        # [B*K] -> [B, K]
        ref_text_per_example = []
        for i in range(B):
            ref_text_list_example = []
            for k in range(max_n_refs):
                ref = ref_text[i * max_n_refs + k]
                if len(ref) > 0:
                    ref_text_list_example.append(ref)
            ref_text_per_example.append(ref_text_list_example)
        assert len(ref_text_per_example) == B

        P, R, F1 = self.bert_scorer.score(
            sents,
            ref_text_per_example,
        )
        bertscore_f1 = F1
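        # Each prediction entry below carries the decoded caption plus its per-image
        # metrics (perplexity, entropy, and, when enabled, CLIP-S, RefCLIP-S, grammar
        # probability, and BERTScore F1); entries are aggregated across ranks in
        # validation_epoch_end.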
        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent,
                     'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
            if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
                entry['CLIP-S'] = clip_s[k].item()
                entry['RefCLIP-S'] = refclip_s[k].item()

                if use_grammar and not joint_out:
                    entry['grammar_prob'] = grammar_prob[k].item()

                # BERT-S
                entry['BERT-S'] = bertscore_f1[k].item()

            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to the vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + \
                      '" vis/imgs/img' + \
                      str(len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)
            if verbose:
                print('image %s: %s' %
                      (entry['image_id'], entry['caption']))

        if sample_n > 1:
            eval_utils.eval_split_n(model, n_predictions, [
                fc_feats, att_feats, att_masks, data], eval_kwargs)

        output = {
            'loss': loss,
            'predictions': predictions,
            'n_predictions': n_predictions,
        }
        return output
    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)
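    # validation_epoch_end gathers per-rank outputs with detectron2's comm helpers;
    # only the main process runs language evaluation and metric averaging, and the
    # resulting dict is then shared back to every rank via all_gather so all
    # processes log the same values.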
    def validation_epoch_end(self, outputs, split='val'):
        outputs = d2comm.gather(outputs)
        # master node
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
            outputs = sum(outputs, [])

            opt = self.opt
            val_loss_mean = sum([_['loss'].cpu()
                                 for _ in outputs]) / len(outputs)

            predictions = sum([_['predictions'] for _ in outputs], [])
            if len(outputs[0]['n_predictions']) != 0:
                n_predictions = sum([_['n_predictions'] for _ in outputs], [])
            else:
                n_predictions = []

            lang_stats = None
            if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
                n_predictions = sorted(
                    n_predictions, key=lambda x: x['perplexity'])

            if not os.path.isdir('eval_results'):
                os.mkdir('eval_results')
            torch.save((predictions, n_predictions), os.path.join(
                'eval_results/', '.saved_pred_' + opt.id + '_' + split + '.pth'))

            if opt.language_eval:
                lang_stats = eval_utils.language_eval(
                    opt.input_json, predictions, n_predictions, vars(opt), split)

            if opt.reduce_on_plateau:
                optimizer = self.trainer.optimizers[0]
                # fall back to the validation loss when language metrics are unavailable
                if lang_stats is not None and 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss_mean)

            out = {
                'loss': val_loss_mean
            }
            if lang_stats is not None:  # lang_stats is None when language_eval is disabled
                out.update(lang_stats)

            if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
                out['CLIP-S'] = sum([p['CLIP-S'] for p in predictions]) / len(predictions)
                print('CLIP-S', out['CLIP-S'])
                out['RefCLIP-S'] = sum([p['RefCLIP-S'] for p in predictions]) / len(predictions)
                print('RefCLIP-S', out['RefCLIP-S'])

                if getattr(self.opt, 'use_grammar', False) and not getattr(self.opt, 'joint_out', False):
                    out['grammar_prob'] = sum([p['grammar_prob'] for p in predictions]) / len(predictions)
                    print('grammar_prob', out['grammar_prob'])

                out['BERT-S'] = sum([p['BERT-S'] for p in predictions]) / len(predictions)
                print('BERT-S', out['BERT-S'])
        else:
            out = {}

        out = d2comm.all_gather(out)[0]  # Only the one from the master node
        assert len(out) > 0  # make sure the head has index 0

        # must all be tensors
        out = {k: torch.tensor(v) if not torch.is_tensor(
            v) else v for k, v in out.items()}

        for k, v in out.items():
            self.log(f'{split}/{k}', v)

    def test_epoch_end(self, outputs):
        self.validation_epoch_end(outputs, 'test')
    def configure_optimizers(self):
        opt = self.opt
        model = self.model

        parameters = [p for p in model.parameters() if p.requires_grad]

        if opt.noamopt:
            # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
            optimizer = utils.get_std_opt(
                model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        elif opt.reduce_on_plateau:
            optimizer = utils.build_optimizer(parameters, opt)
            optimizer = utils.ReduceLROnPlateau(optimizer,
                                                factor=opt.reduce_on_plateau_factor,
                                                patience=opt.reduce_on_plateau_patience)
        else:
            optimizer = utils.build_optimizer(parameters, opt)
        return [optimizer], []
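    # optimizer_step below applies a manual linear learning-rate warmup over the
    # first opt.noamopt_warmup iterations when opt.use_warmup is set, then defers
    # to Lightning's default optimizer stepping.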
    def optimizer_step(self, epoch, batch_idx, optimizer,
                       optimizer_idx, *args, **kwargs):
        # warm up lr
        opt = self.opt
        iteration = self.trainer.global_step
        if opt.use_warmup and (iteration < opt.noamopt_warmup):
            opt.current_lr = opt.learning_rate * \
                (iteration + 1) / opt.noamopt_warmup
            utils.set_lr(optimizer, opt.current_lr)
        super().optimizer_step(epoch, batch_idx, optimizer,
                               optimizer_idx, *args, **kwargs)
    def state_dict(self):
        """
        Save the model state dict as well as opt and vocab
        """
        state_dict = self.model.state_dict()
        device = next(iter(state_dict.values())).device
        assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
        state_dict.update({
            '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
            '_opt': utils.serialize_to_tensor(self.opt).to(device)
        })
        return state_dict

    def load_state_dict(self, state_dict=None, strict=True):
        if '_vocab' in state_dict:
            self.model.vocab = utils.deserialize(state_dict['_vocab'])
            del state_dict['_vocab']
        if '_opt' in state_dict:
            saved_model_opt = utils.deserialize(state_dict['_opt'])
            del state_dict['_opt']
            opt = self.opt
            # Make sure the saved opt is compatible with the current opt
            need_be_same = ["caption_model",
                            "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
                        getattr(opt, checkme) in ['updown', 'topdown']:
                    continue
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme), "Command line argument and saved model disagree on '%s'" % checkme
        self.model.load_state_dict(state_dict, strict)
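# OnEpochStartCallback updates the per-epoch training state before each epoch:
# step-wise learning-rate decay (unless noamopt / reduce_on_plateau manage the LR),
# the scheduled-sampling probability, and whether self-critical or structure-loss
# training is active (which also initializes the cached CIDEr scorer).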
class OnEpochStartCallback(pl.Callback):

    def on_epoch_start(self, trainer, pl_module):
        # Update lr / training stage / scheduled sampling prob etc.
        opt = pl_module.opt
        model = pl_module.model
        epoch = trainer.current_epoch
        optimizer = trainer.optimizers[0]

        if not opt.noamopt and not opt.reduce_on_plateau:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (
                    epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate

        # Assign the scheduled sampling prob
        if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
            frac = (
                epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
            opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
                              frac, opt.scheduled_sampling_max_prob)
            model.ss_prob = opt.ss_prob

        # If self-critical training has started
        if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
            sc_flag = True
            init_scorer(opt.cached_tokens)
        else:
            sc_flag = False

        # If structure loss training has started
        if opt.structure_after != -1 and epoch >= opt.structure_after:
            struc_flag = True
            init_scorer(opt.cached_tokens)
        else:
            struc_flag = False

        pl_module.struc_flag = struc_flag
        pl_module.sc_flag = sc_flag
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save the model on keyboard interrupt
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        self._save_model(filepath)
opt = opts.parse_opt()

checkpoint_callback = ModelCheckpoint(
    filepath=opt.checkpoint_path,
    save_last=True,
    save_top_k=1,
    verbose=True,
    monitor='val/CIDEr',
    mode='max',
    prefix=opt.id,
)

verbose = True
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
    verbose = False

if verbose:
    print(opt)
    print("""
val_images_use,
save_checkpoint_every,
save_every_epoch,
save_history_ckpt will be ignored.
""")

# Lightning defines batch size as the batch size per GPU
assert opt.batch_size % torch.cuda.device_count() == 0
opt.batch_size = opt.batch_size // torch.cuda.device_count()
# If resuming from the last checkpoint
if opt.start_from is not None:
    resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
    if os.path.isfile(resume_from):
        if verbose:
            print('Loading checkpoint from', resume_from)
    else:
        print("Checkpoint not found:", resume_from)
        resume_from = None
else:
    resume_from = None
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(
    project='CLIP-ViL-COCOCaption',
    name=opt.id,
)

if verbose:
    wandb_logger.experiment.config.update(opt)

    from pathlib import Path
    import glob
    import wandb

    glob_str = "**/*.py"
    base_path = './'
    wandb.save(glob_str=glob_str, base_path=base_path)
lit = LitModel(opt)

# warning: grad_clip_mode is ignored.
trainer = pl.Trainer(
    callbacks=[
        OnEpochStartCallback(),
        pl.callbacks.LearningRateMonitor()
    ],
    default_root_dir=opt.checkpoint_path,
    resume_from_checkpoint=resume_from,
    distributed_backend='ddp',
    check_val_every_n_epoch=1,
    max_epochs=opt.max_epochs,
    gradient_clip_val=opt.grad_clip_value,
    gpus=torch.cuda.device_count(),
    checkpoint_callback=checkpoint_callback,
    log_gpu_memory='min_max',
    log_every_n_steps=opt.losses_log_every,
    profiler=True,
    flush_logs_every_n_steps=10,
    num_sanity_val_steps=0,
    # val_check_interval=0.01,
    # limit_train_batches=500,
    # fast_dev_run=True,
    precision=opt.precision,
    logger=wandb_logger
)

if os.getenv('EVALUATE', '0') == '1':
    trainer.test(lit)
else:
    trainer.fit(lit)
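# Setting EVALUATE=1 in the environment runs trainer.test(lit) (evaluation on the
# 'test' split, which also enables the CLIP-S / RefCLIP-S / BERTScore reporting
# above) instead of launching training with trainer.fit(lit).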