import warnings

from pytorch_lightning import LightningModule

from fengshen.models import transformer_utils

import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F

from dataclasses import dataclass
from typing import Optional, Tuple

from transformers.file_utils import *
from transformers.modeling_outputs import *
from transformers.models.bart import *
from transformers.models.bart.modeling_bart import BartClassificationHead

_CONFIG_FOR_DOC = "BartConfig"


# ------------------------ ZZ: CBart addition ------------------------

def _reorder_buffer(attn_cache, new_order):
    for k, input_buffer_k in attn_cache.items():
        if input_buffer_k is not None:
            attn_cache[k] = input_buffer_k.index_select(0, new_order)
    return attn_cache


def _make_linear_from_emb(emb):
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer

BART_GENERATION_EXAMPLE = r"""
    Summarization example::

        >>> from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
        >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
        >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
        >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
        >>> # Generate Summary
        >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
        >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

    Mask filling example::

        >>> from transformers import BartTokenizer, BartForConditionalGeneration
        >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        >>> TXT = "My friends are <mask> but they eat too many carbs."
        >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
        >>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
        >>> logits = model(input_ids).logits
        >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
        >>> probs = logits[0, masked_index].softmax(dim=0)
        >>> values, predictions = probs.topk(5)
        >>> tokenizer.decode(predictions).split()
"""

@dataclass
class CBartLMOutput(ModelOutput):
    """
    Base class for CBart-specific language model outputs.

    Args:
        loss: total loss, i.e. ``encoder_loss * loss_weight + decoder_loss`` when both are available.
        encoder_loss: token-level classification (or regression) loss of the encoder head.
        decoder_loss: masked language modeling loss of the decoder.
        encoder_logits: token-level predictions of the encoder classification head.
        logits: prediction scores of the decoder language modeling head.
        past_key_values, decoder_hidden_states, decoder_attentions, encoder_last_hidden_state,
        encoder_hidden_states, encoder_attentions: same as in ``Seq2SeqLMOutput``.
    """
    loss: Optional[torch.FloatTensor] = None
    encoder_loss: Optional[torch.FloatTensor] = None
    decoder_loss: Optional[torch.FloatTensor] = None
    encoder_logits: torch.FloatTensor = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

class BartForTextInfill(BartPretrainedModel):
    """
    BART model for text infilling.

    During training, the encoder predicts a token-level edit label for each input
    token (e.g. replace or insert), while the decoder reconstructs the original
    input. Compared with ``BartForConditionalGeneration``, this class adds a
    classification head on top of the encoder together with a corresponding
    encoder loss.
    """
    base_model_prefix = "model"
    authorized_missing_keys = [r"final_logits_bias",
                               r"encoder\.version", r"decoder\.version"]

    def __init__(self, config: BartConfig):
        super().__init__(config)
        base_model = BartModel(config)
        self.model = base_model
        self.register_buffer("final_logits_bias", torch.zeros(
            (1, self.model.shared.num_embeddings)))
        # print( config.encoder_loss_type, config.num_labels)

        # add a new attribute into BartConfig class (revise BartConfig)
        self.encoder_loss_type = config.encoder_loss_type
        self.num_labels = config.num_labels
        if self.encoder_loss_type == 0:  # 0 is classification loss, 1 is regression loss
            # add a classification module for the encoder
            self.classification_head = BartClassificationHead(
                config.d_model, config.d_model, config.num_labels, config.classif_dropout,
            )
        else:
            # add a regression module for the encoder
            self.classification_head = BartClassificationHead(
                config.d_model, config.d_model, 1, config.classif_dropout,
            )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)
        self.loss_weight = config.loss_weight
        self.register_buffer("label_weights", torch.zeros((self.num_labels)))

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        old_num_tokens = self.model.shared.num_embeddings
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self.model.shared = new_embeddings
        self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens),
                                     device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        encoder_labels=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
        **unused,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
            with labels in ``[0, ..., config.vocab_size]``.

        Returns:

        Conditional generation example::

            # Mask filling only works for bart-large
            from transformers import BartTokenizer, BartForConditionalGeneration
            tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
            TXT = "My friends are <mask> but they eat too many carbs."
            model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
            input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
            logits = model(input_ids).logits
            masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
            probs = logits[0, masked_index].softmax(dim=0)
            values, predictions = probs.topk(5)
            tokenizer.decode(predictions).split()
            # ['good', 'great', 'all', 'really', 'very']
        """
        if "lm_labels" in unused:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = unused.pop("lm_labels")
        if "decoder_cached_states" in unused:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = unused.pop("decoder_cached_states")
        return_dict = return_dict if return_dict is not None else False
        if labels is not None:
            use_cache = False

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # logits and loss for the encoder
        # last hidden state
        encoder_last_hidden_state = outputs['encoder_last_hidden_state']
        # eos_mask = input_ids.eq(self.config.eos_token_id)
        # if len(torch.unique(eos_mask.sum(1))) > 1:
        #     raise ValueError("All examples must have the same number of <eos> tokens.")
        # sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
        encoder_logits = self.classification_head(encoder_last_hidden_state)
        encoder_loss = None
        if encoder_labels is not None:
            # classification loss
            if self.encoder_loss_type == 0:
                # ZZ: seems like MSE loss does not support weighting, so only CEL has weighting applied for now
                loss_fct = nn.CrossEntropyLoss(weight=self.label_weights)
                encoder_loss = loss_fct(
                    encoder_logits.view(-1, self.config.num_labels), encoder_labels.view(-1))
            # regression loss
            else:
                encoder_logits = encoder_logits.view(
                    encoder_logits.size(0), -1)
                encoder_logits = torch.sigmoid(
                    encoder_logits) * self.num_labels - 0.5
                loss_fct = nn.MSELoss(reduction='none')
                _loss = loss_fct(encoder_logits, encoder_labels)
                encoder_loss = torch.mean(_loss[encoder_labels >= 0])
                # encoder_loss =_loss[encoder_labels>=0]
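                # The sigmoid rescaling above maps the regression output into
                # (-0.5, num_labels - 0.5), so rounding recovers an integer label
                # in [0, num_labels - 1]; positions whose encoder_labels are
                # negative are excluded from the MSE loss.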

        # logits and loss for the decoder
        lm_logits = F.linear(
            outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # TODO(SS): do we need to ignore pad tokens in labels?
            masked_lm_loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        loss = None
        if masked_lm_loss is not None and encoder_loss is not None:
            loss = encoder_loss * self.loss_weight + masked_lm_loss

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CBartLMOutput(
            loss=loss,
            encoder_loss=encoder_loss,
            decoder_loss=masked_lm_loss,
            encoder_logits=encoder_logits,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
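
    # Illustrative call pattern (an assumption added for clarity, not taken from
    # the original training scripts): `encoder_labels` holds one edit label per
    # source token (e.g. copy / replace / insert) and `labels` holds the target
    # token ids for the decoder, so a training step looks roughly like
    #
    #   out = model(input_ids, attention_mask=attention_mask,
    #               decoder_input_ids=decoder_input_ids,
    #               encoder_labels=encoder_labels, labels=labels)
    #   out.loss.backward()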
    def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"
        encoder_outputs, past_key_values = past
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            # change this to avoid caching (presumably for debugging)
            "use_cache": use_cache,
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        if cur_len == 1:
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_ids) -> None:
        """force one of token_ids to be generated by setting prob of all other tokens to 0"""
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        all_but_token_ids_mask = torch.tensor(
            [x for x in range(self.config.vocab_size) if x not in token_ids],
            dtype=torch.long,
            device=next(self.parameters()).device,
        )
        assert len(
            scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
        scores[:, all_but_token_ids_mask] = -float("inf")

    @staticmethod
    def _reorder_cache(past, beam_idx):
        ((enc_out, enc_mask), past_key_values) = past
        reordered_past = []
        for layer_past in past_key_values:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
        new_enc_out = enc_out if enc_out is None else enc_out.index_select(
            0, beam_idx)
        new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(
            0, beam_idx)
        past = ((new_enc_out, new_enc_mask), reordered_past)
        return past

    def get_encoder(self):
        return self.model.encoder

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.model.shared)  # make it on the fly

    def get_encoder_logits(self, input_ids, attention_mask=None):
        # print(input_ids, attention_mask)
        # encoder_outputs = self.model.get_encoder_outputs(
        #     self,
        #     input_ids,
        #     attention_mask=attention_mask,
        #     output_attentions=None,
        #     output_hidden_states=None,
        #     return_dict=None,
        # )
        encoder_outputs = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # logits and loss for the encoder
        # last hidden state
        encoder_last_hidden_state = encoder_outputs['last_hidden_state']
        encoder_logits = self.classification_head(encoder_last_hidden_state)
        # classification
        if self.encoder_loss_type == 0:
            # probs = torch.softmax(encoder_logits,dim=-1)
            pass
        # regression
        else:
            encoder_logits = encoder_logits.view(encoder_logits.size(0), -1)
            encoder_logits = torch.sigmoid(
                encoder_logits) * self.num_labels - 0.5
        return encoder_outputs, encoder_logits

class CBartLightning(LightningModule):

    @staticmethod
    def add_module_specific_args(parent_args):
        parser = parent_args.add_argument_group("CBart specific parameters")
        parser.add_argument('--num_labels', type=int, default=3)
        parser.add_argument('--encoder_loss_type', type=int, default=0)
        parser.add_argument('--loss_weight', type=float, default=1.0)
        parser.add_argument('--label_weights', type=float, nargs='+', default=[1.0, 1.0, 1.0])
        parser.add_argument('--masked_lm', type=float, default=0)
        return parent_args
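
    # Illustrative wiring (assumed; the surrounding training script is not part
    # of this file): the argument group is attached to an existing argparse
    # parser and the parsed namespace is handed to __init__ below, which also
    # expects e.g. args.model_path.
    #
    #   parser = argparse.ArgumentParser()
    #   parser = CBartLightning.add_module_specific_args(parser)
    #   args = parser.parse_args()
    #   module = CBartLightning(args)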

    def __init__(
        self,
        args,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(args)
        self.model = BartForTextInfill.from_pretrained(args.model_path, num_labels=self.hparams.num_labels,
                                                       encoder_loss_type=self.hparams.encoder_loss_type,
                                                       loss_weight=self.hparams.loss_weight,)
        self.model.label_weights = torch.tensor(
            self.hparams.label_weights, dtype=torch.half)

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        return outputs

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss = outputs["loss"]
        return {"loss": val_loss}

    def setup(self, stage=None) -> None:
        if stage != "fit":
            return
        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

        # Calculate total steps
        tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus)
        ab_size = self.trainer.accumulate_grad_batches * float(self.trainer.max_epochs)
        self.total_steps = (len(train_loader.dataset) // tb_size) // ab_size
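        # Worked example of the arithmetic above (illustrative numbers): with
        # 10,000 training samples, train_batchsize=8 on 2 GPUs -> tb_size=16;
        # accumulate_grad_batches=4 over 3 epochs -> ab_size=12.0; hence
        # total_steps = (10000 // 16) // 12.0 = 52.0.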

    def configure_optimizers(self):
        return transformer_utils.configure_optimizers(self)
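

if __name__ == "__main__":
    # Minimal smoke-test sketch (added for illustration; it is not part of the
    # original training pipeline and the tiny hyperparameters below are
    # assumptions). A randomly initialised BartForTextInfill is built from a
    # config instead of a checkpoint so that it runs without downloads; the
    # extra config fields are the attributes this module reads from BartConfig
    # (see the "revise BartConfig" note in BartForTextInfill.__init__).
    tiny_config = BartConfig(
        vocab_size=1000, d_model=64,
        encoder_layers=1, decoder_layers=1,
        encoder_attention_heads=2, decoder_attention_heads=2,
        encoder_ffn_dim=128, decoder_ffn_dim=128,
    )
    tiny_config.encoder_loss_type = 0  # 0 = classification head, 1 = regression head
    tiny_config.num_labels = 3
    tiny_config.classif_dropout = 0.1
    tiny_config.loss_weight = 1.0

    infill_model = BartForTextInfill(tiny_config)
    dummy_input_ids = torch.randint(0, tiny_config.vocab_size, (2, 8))
    _, dummy_encoder_logits = infill_model.get_encoder_logits(dummy_input_ids)
    print(dummy_encoder_logits.shape)  # expected: torch.Size([2, 8, 3])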