# -*- coding: utf-8 -*-
# @Time : 2022/4/21 5:30 PM
# @Author : JianingWang
# @File : fusion_siamese.py
from typing import Optional

import torch
import numpy as np
import torch.nn as nn
from dataclasses import dataclass
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import SequenceClassifierOutput
from loss.focal_loss import FocalLoss
# from roformer import RoFormerPreTrainedModel, RoFormerModel


class BertPooler(nn.Module):
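    """Pools the first token ([CLS]) through a dense layer and a configurable activation."""
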
    def __init__(self, hidden_size, hidden_act):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        # self.activation = nn.Tanh()
        self.activation = ACT2FN[hidden_act]
        # self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, features):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # x = self.dropout(x)
        x = self.dense(x)
        x = self.activation(x)
        return x


class BertForFusionSiamese(BertPreTrainedModel):
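    """BERT sentence-pair classifier with a fused "siamese" head.

    The pooled [CLS] vector is concatenated with the mean-pooled representations
    of the two input segments located by ``segment_spans``, and the resulting
    3*hidden_size feature is fed to the classifier.
    """
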
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.hidden_size = config.hidden_size
        self.hidden_act = config.hidden_act
        self.bert_poor = BertPooler(self.hidden_size, self.hidden_act)
        self.dense_1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.dense_2 = nn.Linear(self.hidden_size, self.hidden_size)
        if hasattr(config, "cls_dropout_rate"):
            cls_dropout_rate = config.cls_dropout_rate
        else:
            cls_dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(cls_dropout_rate)
        # The classifier consumes [CLS] plus two mean-pooled segment vectors,
        # so forward() expects `segment_spans` to be provided.
        self.classifier = nn.Linear(3 * self.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        pseudo_label=None,
        segment_spans=None,
        pseuso_proba=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        logits, outputs = None, None
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        inputs = {k: v for k, v in inputs.items() if v is not None}
        outputs = self.bert(**inputs)
        if "sequence_output" in outputs:
            sequence_output = outputs.sequence_output  # [bz, seq_len, dim]
        else:
            sequence_output = outputs[0]  # [bz, seq_len, dim]
        cls_output = self.bert_poor(sequence_output)  # [bz, dim]
        if segment_spans is not None:
            # If the input contains two segments, mean-pool each of them separately.
            seg1_embeddings, seg2_embeddings = list(), list()
            for ei, sentence_embeddings in enumerate(sequence_output):
                # sentence_embeddings: [seq_len, dim]
                seg1_start, seg1_end, seg2_start, seg2_end = segment_spans[ei]
                seg1_embeddings.append(torch.mean(sentence_embeddings[seg1_start: seg1_end], 0))  # [dim]
                seg2_embeddings.append(torch.mean(sentence_embeddings[seg2_start: seg2_end], 0))  # [dim]
            seg1_embeddings, seg2_embeddings = torch.stack(seg1_embeddings), torch.stack(seg2_embeddings)  # [bz, dim]
            # Both segments share dense_1 (siamese projection).
            seg1_embeddings = self.bert_poor.activation(self.dense_1(seg1_embeddings))
            seg2_embeddings = self.bert_poor.activation(self.dense_1(seg2_embeddings))
            cls_output = torch.cat([cls_output, seg1_embeddings, seg2_embeddings], dim=-1)  # [bz, 3*dim]
            # cls_output = cls_output + seg1_embeddings + seg2_embeddings  # [bz, dim]
        pooler_output = self.dropout(cls_output)
        # pooler_output = self.LayerNorm(pooler_output)
        logits = self.classifier(pooler_output)
        loss = None
        if labels is not None:
            # loss_fct = FocalLoss()
            loss_fct = CrossEntropyLoss()
            # Pseudo labels: split the batch by the pseudo_label flag and weight
            # the gold-labelled and pseudo-labelled losses 0.9 / 0.1.
            if pseudo_label is not None:
                train_logits, pseudo_logits = logits[pseudo_label > 0.9], logits[pseudo_label < 0.1]
                train_labels, pseudo_labels = labels[pseudo_label > 0.9], labels[pseudo_label < 0.1]
                train_loss = loss_fct(train_logits.view(-1, self.num_labels),
                                      train_labels.view(-1)) if train_labels.nelement() else 0
                pseudo_loss = loss_fct(pseudo_logits.view(-1, self.num_labels),
                                       pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
                loss = 0.9 * train_loss + 0.1 * pseudo_loss
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class BertForWSC(BertPreTrainedModel):
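    """BERT span-pair classifier (e.g. for WSC-style coreference tasks).

    The two candidate spans located by ``span`` are mean-pooled over their
    tokens, concatenated into a 2*hidden_size feature and fed to the classifier.
    """
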
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.hidden_size = config.hidden_size
        self.hidden_act = config.hidden_act
        self.bert_poor = BertPooler(self.hidden_size, self.hidden_act)
        self.dense_1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.dense_2 = nn.Linear(self.hidden_size, self.hidden_size)
        if hasattr(config, "cls_dropout_rate"):
            cls_dropout_rate = config.cls_dropout_rate
        else:
            cls_dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(cls_dropout_rate)
        self.classifier = nn.Linear(2 * self.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        pseudo_label=None,
        span=None,
        pseuso_proba=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        logits, outputs = None, None
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        inputs = {k: v for k, v in inputs.items() if v is not None}
        outputs = self.bert(**inputs)
        if "sequence_output" in outputs:
            sequence_output = outputs.sequence_output  # [bz, seq_len, dim]
        else:
            sequence_output = outputs[0]  # [bz, seq_len, dim]
        # cls_output = self.bert_poor(sequence_output)  # [bz, dim]
        # If the input contains two spans, mean-pool each of them separately.
        seg1_embeddings, seg2_embeddings = list(), list()
        for ei, sentence_embeddings in enumerate(sequence_output):
            # sentence_embeddings: [seq_len, dim]
            seg1_start, seg1_end, seg2_start, seg2_end = span[ei]
            seg1_embeddings.append(torch.mean(sentence_embeddings[seg1_start + 1: seg1_end], 0))  # [dim]
            seg2_embeddings.append(torch.mean(sentence_embeddings[seg2_start + 1: seg2_end], 0))  # [dim]
        seg1_embeddings, seg2_embeddings = torch.stack(seg1_embeddings), torch.stack(seg2_embeddings)  # [bz, dim]
        # seg1_embeddings = self.bert_poor.activation(self.dense_1(seg1_embeddings))
        # seg2_embeddings = self.bert_poor.activation(self.dense_1(seg2_embeddings))
        cls_output = torch.cat([seg1_embeddings, seg2_embeddings], dim=-1)  # [bz, 2*dim]
        # cls_output = cls_output + seg1_embeddings + seg2_embeddings  # [bz, dim]
        pooler_output = self.dropout(cls_output)
        # pooler_output = self.LayerNorm(pooler_output)
        logits = self.classifier(pooler_output)
        loss = None
        if labels is not None:
            # loss_fct = FocalLoss()
            loss_fct = CrossEntropyLoss()
            # Pseudo labels: split the batch by the pseudo_label flag and weight
            # the gold-labelled and pseudo-labelled losses 0.9 / 0.1.
            if pseudo_label is not None:
                train_logits, pseudo_logits = logits[pseudo_label > 0.9], logits[pseudo_label < 0.1]
                train_labels, pseudo_labels = labels[pseudo_label > 0.9], labels[pseudo_label < 0.1]
                train_loss = loss_fct(train_logits.view(-1, self.num_labels),
                                      train_labels.view(-1)) if train_labels.nelement() else 0
                pseudo_loss = loss_fct(pseudo_logits.view(-1, self.num_labels),
                                       pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
                loss = 0.9 * train_loss + 0.1 * pseudo_loss
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
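

# A minimal usage sketch (not part of the training pipeline): it shows how the
# fusion head above is driven by `segment_spans`. The checkpoint name
# "bert-base-chinese" and the example sentences/spans are illustrative
# assumptions, not values used elsewhere in this repository.
if __name__ == "__main__":
    from transformers import BertConfig, BertTokenizerFast

    model_name = "bert-base-chinese"  # assumed checkpoint; any BERT checkpoint works
    tokenizer = BertTokenizerFast.from_pretrained(model_name)
    config = BertConfig.from_pretrained(model_name, num_labels=2)
    model = BertForFusionSiamese.from_pretrained(model_name, config=config)
    model.eval()

    enc = tokenizer("first example sentence", "second example sentence", return_tensors="pt")
    seq_len = enc["input_ids"].shape[1]
    # One (seg1_start, seg1_end, seg2_start, seg2_end) tuple per example; here the
    # sequence is simply split around its midpoint for illustration.
    segment_spans = torch.tensor([[1, seq_len // 2, seq_len // 2, seq_len - 1]])

    with torch.no_grad():
        out = model(**enc, segment_spans=segment_spans)
    print(out.logits.shape)  # torch.Size([1, 2])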