Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # @Time : 2022/3/15 21:26 | |
| # @Author : ruihan.wjn | |
| # @File : pk-plm.py | |
| """ | |
| This code is implemented for the paper "Knowledge Prompting in Pre-trained Language Models for Natural Language Understanding". | |
| """ | |
| from time import time | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from torch.nn import CrossEntropyLoss | |
| from collections import OrderedDict | |
| from transformers.models.bert import BertPreTrainedModel, BertModel | |
| from transformers.models.roberta import RobertaModel, RobertaPreTrainedModel, RobertaTokenizer, RobertaForMaskedLM | |
| from transformers.models.deberta import DebertaModel, DebertaPreTrainedModel, DebertaTokenizer, DebertaForMaskedLM | |
| from transformers.models.bert.modeling_bert import BertOnlyMLMHead, BertPreTrainingHeads | |
| from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaLMHead | |
| from transformers.models.deberta.modeling_deberta import DebertaModel, DebertaLMPredictionHead | |
| """ | |
| kg enhanced corpus structure example: | |
| { | |
| "token_ids": [20, 46098, 3277, 680, 10, 4066, 278, 9, 11129, 4063, 877, 579, 8, 8750, 14720, 8, 22498, 548, | |
| 19231, 46098, 3277, 6, 25, 157, 25, 130, 3753, 46098, 3277, 4, 3684, 19809, 10960, 9, 5, 30731, 2788, 914, 5, | |
| 1675, 8151, 35], "entity_pos": [[8, 11], [13, 15], [26, 27]], | |
| "entity_qid": ["Q17582", "Q231978", "Q427013"], | |
| "relation_pos": null, | |
| "relation_pid": null | |
| } | |
| """ | |
| from enum import Enum | |
class SiameseDistanceMetric(Enum):
    """Namespace of distance functions usable by the contrastive loss.

    Every attribute is a plain callable ``(x, y) -> distances``. Because the
    attributes are functions, ``Enum`` does not turn them into enum members;
    they are looked up on the class and called directly, e.g.
    ``SiameseDistanceMetric.EUCLIDEAN(x, y)``.
    """

    @staticmethod
    def EUCLIDEAN(x, y):
        """Return the pairwise L2 distance between ``x`` and ``y``."""
        return F.pairwise_distance(x, y, p=2)

    @staticmethod
    def MANHATTAN(x, y):
        """Return the pairwise L1 distance between ``x`` and ``y``."""
        return F.pairwise_distance(x, y, p=1)

    @staticmethod
    def COSINE_DISTANCE(x, y):
        """Return ``1 - cosine_similarity(x, y)``."""
        return 1 - F.cosine_similarity(x, y)
class ContrastiveLoss(nn.Module):
    """Pairwise contrastive loss over two batches of embeddings.

    For a pair with ``label == 1`` the squared distance is penalised, pulling
    the embeddings together. For a pair with ``label == 0`` the squared
    shortfall below ``margin`` is penalised, pushing the embeddings apart
    until they are at least ``margin`` away from each other.

    Further information:
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    :param distance_metric: Callable ``(emb1, emb2) -> distances``. The class
        SiameseDistanceMetric provides ready-made choices.
    :param margin: Distance that negative pairs (label == 0) should reach.
    :param size_average: If True return the mean of the per-pair losses,
        otherwise their sum.

    Example::

        loss_fn = ContrastiveLoss(margin=0.5)
        loss = loss_fn(embeddings_a, embeddings_b, labels)
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average: bool = True):
        super().__init__()
        self.size_average = size_average
        self.margin = margin
        self.distance_metric = distance_metric

    def forward(self, sent_embs1, sent_embs2, labels: torch.Tensor):
        """Return the reduced contrastive loss for the given embedding pairs.

        :param sent_embs1: First embedding of every pair.
        :param sent_embs2: Second embedding of every pair.
        :param labels: 1 for pairs to pull together, 0 for pairs to push apart.
        """
        dist = self.distance_metric(sent_embs1, sent_embs2)
        # Positive pairs pay for any distance at all.
        pull_term = labels.float() * dist.pow(2)
        # Negative pairs only pay while they are closer than the margin.
        push_term = (1 - labels).float() * F.relu(self.margin - dist).pow(2)
        per_pair = 0.5 * (pull_term + push_term)
        if self.size_average:
            return per_pair.mean()
        return per_pair.sum()
class NSPHead(nn.Module):
    """Two-way linear classification head applied to a pooled representation."""

    def __init__(self, config):
        """Build the head; ``config.hidden_size`` is the input feature size."""
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return the two class scores for every row of ``pooled_output``."""
        return self.seq_relationship(pooled_output)
class RoBertaKPPLMForProcessedWikiKGPLM(RobertaForMaskedLM):
    """RoBERTa masked-LM with entity and relation contrastive objectives.

    On top of the inherited masked-language-modeling head, rows of the batch
    whose ``task_id`` equals 2 contribute an entity contrastive loss and rows
    whose ``task_id`` equals 3 contribute a relation contrastive loss. The
    returned ``loss`` is the sum of every objective that was active.
    """

    # ``task_id`` values that route a row to the corresponding objective.
    _ENTITY_TASK_ID = 2
    _RELATION_TASK_ID = 3

    def __init__(self, config):
        """Create the extra heads on top of ``RobertaForMaskedLM``."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # Older configs have no ``classifier_dropout`` attribute at all.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        # Knowledge-noise detection head. It is not used by forward(), but is
        # kept so that the set of parameters (state-dict keys) is unchanged.
        self.detector = NSPHead(config)
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def _candidate_contrastive_loss(self, sequence_output, input_ids, mask_id, candidates, projection):
        """Contrastive loss between masked-span queries and their candidates.

        All arguments are already restricted to the rows of one task. The
        last candidate of every row is treated as the positive one, all
        others as negatives.

        :param sequence_output: Encoder hidden states, one sequence per row.
        :param input_ids: Token ids matching ``sequence_output``.
        :param mask_id: Per-row token id marking the masked positions.
        :param candidates: Candidate token ids; assumed to be
            ``[rows, candidate_num, len]`` as in the original shape comments.
        :param projection: Linear layer applied to queries and candidates.
        """
        batch_size, candidate_num = candidates.shape[0], candidates.shape[1]

        # Query = mean hidden state over the masked positions of each row.
        # NOTE(review): a row without any mask token yields a NaN query.
        queries = torch.stack([
            hidden[token_ids == mask].mean(0)
            for hidden, token_ids, mask in zip(sequence_output, input_ids, mask_id)
        ])
        queries = projection(queries)
        hidden_dim = queries.shape[-1]
        # Repeat each query once per candidate so the pairs line up row-wise.
        queries = queries.unsqueeze(1).repeat(1, candidate_num, 1).view(-1, hidden_dim)

        # Candidates are represented by their mean (non-contextual) embedding.
        flat_candidates = candidates.reshape(-1, candidates.shape[-1])
        candidate_embeddings = self.roberta.embeddings(input_ids=flat_candidates)
        candidate_embeddings = projection(candidate_embeddings.mean(1))

        # Built on the query's device instead of a hard-coded ``.cuda()``.
        pair_labels = torch.zeros(candidate_num, dtype=torch.float, device=queries.device)
        pair_labels[-1] = 1.0
        pair_labels = pair_labels.repeat(batch_size)
        return self.contrastive_loss_fn(queries, candidate_embeddings, pair_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        entity_candidate=None,
        relation_candidate=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and compute every objective that has inputs.

        :param labels: Masked-LM target ids; the MLM loss is skipped if None.
        :param entity_candidate: Candidate token ids for the entity objective.
        :param relation_candidate: Candidate token ids for the relation
            objective.
        :param noise_detect_label: Accepted for interface compatibility; unused.
        :param task_id: Per-row task selector (2 = entity, 3 = relation).
        :param mask_id: Per-row id of the mask token; required whenever an
            entity or relation row is present.
        :return: ``OrderedDict`` with ``loss`` (sum of the active losses, or
            None if none was active), ``mlm_loss`` (shape ``[1]`` or None)
            and ``logits`` (the argmax token ids of the MLM head, not raw
            scores).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        active_losses = []
        masked_lm_loss = None
        if labels is not None:
            # CrossEntropyLoss ignores target index -100 by default.
            masked_lm_loss = CrossEntropyLoss()(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )
            active_losses.append(masked_lm_loss)

        if task_id is not None:
            objectives = (
                (entity_candidate, self._ENTITY_TASK_ID, self.entity_mlp),
                (relation_candidate, self._RELATION_TASK_ID, self.relation_mlp),
            )
            for candidates, task, projection in objectives:
                if candidates is None:
                    continue
                # Select the rows of this task once instead of per iteration.
                selected = task_id == task
                candidates = candidates[selected]
                if len(candidates) == 0:
                    continue
                active_losses.append(
                    self._candidate_contrastive_loss(
                        sequence_output[selected],
                        input_ids[selected],
                        mask_id[selected],
                        candidates,
                        projection,
                    )
                )

        total_loss = torch.stack(active_losses).sum(-1) if active_losses else None
        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
        ])
class DeBertaKPPLMForProcessedWikiKGPLM(DebertaForMaskedLM):
    """DeBERTa masked-LM with entity and relation contrastive objectives.

    On top of the inherited masked-language-modeling head, rows of the batch
    whose ``task_id`` equals 2 contribute an entity contrastive loss and rows
    whose ``task_id`` equals 3 contribute a relation contrastive loss. The
    returned ``loss`` is the sum of every objective that was active.
    """

    # ``task_id`` values that route a row to the corresponding objective.
    _ENTITY_TASK_ID = 2
    _RELATION_TASK_ID = 3

    def __init__(self, config):
        """Create the extra heads on top of ``DebertaForMaskedLM``."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # DeBERTa configs may have no ``classifier_dropout`` attribute.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        # Knowledge-noise detection head. It is not used by forward(), but is
        # kept so that the set of parameters (state-dict keys) is unchanged.
        self.detector = NSPHead(config)
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def _candidate_contrastive_loss(self, sequence_output, input_ids, mask_id, candidates, projection):
        """Contrastive loss between masked-span queries and their candidates.

        All arguments are already restricted to the rows of one task. The
        last candidate of every row is treated as the positive one, all
        others as negatives.

        :param sequence_output: Encoder hidden states, one sequence per row.
        :param input_ids: Token ids matching ``sequence_output``.
        :param mask_id: Per-row token id marking the masked positions.
        :param candidates: Candidate token ids; assumed to be
            ``[rows, candidate_num, len]`` as in the original shape comments.
        :param projection: Linear layer applied to queries and candidates.
        """
        batch_size, candidate_num = candidates.shape[0], candidates.shape[1]

        # Query = mean hidden state over the masked positions of each row.
        # NOTE(review): a row without any mask token yields a NaN query.
        queries = torch.stack([
            hidden[token_ids == mask].mean(0)
            for hidden, token_ids, mask in zip(sequence_output, input_ids, mask_id)
        ])
        queries = projection(queries)
        hidden_dim = queries.shape[-1]
        # Repeat each query once per candidate so the pairs line up row-wise.
        queries = queries.unsqueeze(1).repeat(1, candidate_num, 1).view(-1, hidden_dim)

        # Candidates are represented by their mean (non-contextual) embedding.
        flat_candidates = candidates.reshape(-1, candidates.shape[-1])
        candidate_embeddings = self.deberta.embeddings(input_ids=flat_candidates)
        candidate_embeddings = projection(candidate_embeddings.mean(1))

        # Built on the query's device instead of a hard-coded ``.cuda()``.
        pair_labels = torch.zeros(candidate_num, dtype=torch.float, device=queries.device)
        pair_labels[-1] = 1.0
        pair_labels = pair_labels.repeat(batch_size)
        return self.contrastive_loss_fn(queries, candidate_embeddings, pair_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        entity_candidate=None,
        relation_candidate=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and compute every objective that has inputs.

        ``head_mask``, ``encoder_hidden_states``, ``encoder_attention_mask``
        and ``noise_detect_label`` are accepted for interface compatibility
        but are not used.

        :param labels: Masked-LM target ids; the MLM loss is skipped if None.
        :param entity_candidate: Candidate token ids for the entity objective.
        :param relation_candidate: Candidate token ids for the relation
            objective.
        :param task_id: Per-row task selector (2 = entity, 3 = relation).
        :param mask_id: Per-row id of the mask token; required whenever an
            entity or relation row is present.
        :return: ``OrderedDict`` with ``loss`` (sum of the active losses, or
            None if none was active), ``mlm_loss`` (shape ``[1]`` or None)
            and ``logits`` (the argmax token ids of the MLM head, not raw
            scores).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.deberta(
            input_ids,
            # NOTE(review): the caller's attention_mask is deliberately not
            # forwarded (the original code passed None here and left the real
            # argument commented out), so padding positions are presumably
            # attended to. Confirm this is intended before changing it.
            attention_mask=None,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        active_losses = []
        masked_lm_loss = None
        if labels is not None:
            # CrossEntropyLoss ignores target index -100 by default.
            masked_lm_loss = CrossEntropyLoss()(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )
            active_losses.append(masked_lm_loss)

        if task_id is not None:
            objectives = (
                (entity_candidate, self._ENTITY_TASK_ID, self.entity_mlp),
                (relation_candidate, self._RELATION_TASK_ID, self.relation_mlp),
            )
            for candidates, task, projection in objectives:
                if candidates is None:
                    continue
                # Select the rows of this task once instead of per iteration.
                selected = task_id == task
                candidates = candidates[selected]
                if len(candidates) == 0:
                    continue
                active_losses.append(
                    self._candidate_contrastive_loss(
                        sequence_output[selected],
                        input_ids[selected],
                        mask_id[selected],
                        candidates,
                        projection,
                    )
                )

        total_loss = torch.stack(active_losses).sum(-1) if active_losses else None
        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
        ])
class RoBertaForWikiKGPLM(RobertaPreTrainedModel):
    """RoBERTa encoder with MLM, noise-detection and entity contrastive heads.

    The returned ``loss`` is the sum of every objective whose labels were
    supplied. (The earlier revision overwrote the total instead of adding to
    it, which dropped the noise-detection loss whenever the entity objective
    was also active and returned no loss at all for MLM-only batches.)
    """

    def __init__(self, config):
        """Build the encoder and the task heads."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.roberta = RobertaModel(config)
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.lm_head = RobertaLMHead(config)  # masked language modeling head
        self.detector = NSPHead(config)  # knowledge noise detection head
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()
        # Loaded from the checkpoint location; not used by forward().
        self.tokenizer = RobertaTokenizer.from_pretrained(config.name_or_path)

    def _entity_contrastive_loss(self, sequence_output, input_ids, mask_id, entity_label, entity_negative):
        """Contrastive loss between masked-entity queries and candidates.

        Candidates of each row are ordered negatives first, positive last.

        :param sequence_output: Encoder hidden states, one sequence per row.
        :param input_ids: Token ids matching ``sequence_output``.
        :param mask_id: Per-row token id marking the masked positions.
        :param entity_label: Positive entity token ids; assumed ``[bz, len]``
            as in the original shape comments.
        :param entity_negative: Negative entity token ids; assumed
            ``[bz, negative_num, len]``.
        """
        batch_size = input_ids.shape[0]
        candidate_num = entity_negative.shape[1] + 1

        # Query = mean hidden state over the masked positions of each row.
        # NOTE(review): a row without any mask token yields a NaN query.
        queries = torch.stack([
            hidden[token_ids == mask].mean(0)
            for hidden, token_ids, mask in zip(sequence_output, input_ids, mask_id)
        ])
        queries = self.entity_mlp(queries)
        hidden_dim = queries.shape[-1]
        queries = queries.unsqueeze(1).repeat(1, candidate_num, 1).view(-1, hidden_dim)

        # Entities are represented by their mean (non-contextual) embedding.
        positive = self.roberta.embeddings(input_ids=entity_label).mean(1)
        positive = self.entity_mlp(positive).unsqueeze(1)  # [bz, 1, dim]
        flat_negative = entity_negative.reshape(-1, entity_negative.shape[-1])
        negative = self.roberta.embeddings(input_ids=flat_negative).mean(1)
        negative = self.entity_mlp(negative).view(batch_size, -1, hidden_dim)
        candidates = torch.cat([negative, positive], 1).reshape(-1, hidden_dim)

        # Built on the query's device instead of a hard-coded ``.cuda()``.
        pair_labels = torch.zeros(candidate_num, dtype=torch.float, device=queries.device)
        pair_labels[-1] = 1.0
        pair_labels = pair_labels.repeat(batch_size)
        return self.contrastive_loss_fn(queries, candidates, pair_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        mlm_labels=None,
        entity_label=None,
        entity_negative=None,
        relation_label=None,
        relation_negative=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and compute every objective that has labels.

        ``relation_label``, ``relation_negative`` and ``task_id`` are accepted
        for interface compatibility but are not used.

        :param mlm_labels: Masked-LM target ids; MLM loss is skipped if None.
        :param entity_label: Positive entity token ids.
        :param entity_negative: Negative entity token ids; the entity loss is
            computed only when both entity inputs are given.
        :param noise_detect_label: Binary targets for the detection head.
        :param mask_id: Per-row id of the mask token; required for the entity
            objective.
        :return: ``OrderedDict`` with ``loss``, ``mlm_loss``,
            ``noise_detect_loss``, ``entity_loss`` (each None when inactive;
            the individual losses have shape ``[1]``), ``logits`` (argmax
            token ids of the MLM head) and ``noise_detect_logits`` (argmax
            class of the detection head).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.lm_head(sequence_output)
        noise_detect_scores = self.detector(pooled_output)

        masked_lm_loss, noise_detect_loss, entity_loss = None, None, None
        if mlm_labels is not None:
            # CrossEntropyLoss ignores target index -100 by default.
            masked_lm_loss = CrossEntropyLoss()(
                prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1)
            )
        if noise_detect_label is not None:
            noise_detect_loss = CrossEntropyLoss()(
                noise_detect_scores.view(-1, 2), noise_detect_label.view(-1)
            )
        if entity_label is not None and entity_negative is not None:
            entity_loss = self._entity_contrastive_loss(
                sequence_output, input_ids, mask_id, entity_label, entity_negative
            )

        active_losses = [
            loss for loss in (masked_lm_loss, noise_detect_loss, entity_loss) if loss is not None
        ]
        total_loss = torch.stack(active_losses).sum(-1) if active_losses else None

        def as_vector(loss):
            # Individual losses are reported with shape [1], or None if inactive.
            return loss.unsqueeze(0) if loss is not None else None

        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", as_vector(masked_lm_loss)),
            ("noise_detect_loss", as_vector(noise_detect_loss)),
            ("entity_loss", as_vector(entity_loss)),
            ("logits", prediction_scores.argmax(2)),
            ("noise_detect_logits", noise_detect_scores.argmax(-1)),
        ])
class BertForWikiKGPLM(BertPreTrainedModel):
    """BERT encoder with MLM, noise-detection and entity contrastive heads.

    The MLM and noise-detection scores both come from ``self.cls`` (a
    ``BertPreTrainingHeads``). The returned ``loss`` is the sum of every
    objective whose labels were supplied, matching the accumulation used by
    the knowledge-prompting models in this file.
    """

    def __init__(self, config):
        """Build the encoder and the task heads."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        # forward() calls self.cls(sequence_output, pooled_output) and expects
        # (prediction_scores, seq_relationship_score) back, which is the
        # contract of BertPreTrainingHeads. The abstract BertPreTrainedModel
        # used previously has no forward() and failed on the first call.
        self.cls = BertPreTrainingHeads(config)
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def _entity_contrastive_loss(self, sequence_output, input_ids, mask_id, entity_label, entity_negative):
        """Contrastive loss between masked-entity queries and candidates.

        Candidates of each row are ordered negatives first, positive last.

        :param sequence_output: Encoder hidden states, one sequence per row.
        :param input_ids: Token ids matching ``sequence_output``.
        :param mask_id: Per-row token id marking the masked positions.
        :param entity_label: Positive entity token ids; assumed ``[bz, len]``
            as in the original shape comments.
        :param entity_negative: Negative entity token ids; assumed
            ``[bz, negative_num, len]``.
        """
        batch_size = input_ids.shape[0]
        candidate_num = entity_negative.shape[1] + 1

        # Query = mean hidden state over the masked positions of each row.
        # NOTE(review): a row without any mask token yields a NaN query.
        queries = torch.stack([
            hidden[token_ids == mask].mean(0)
            for hidden, token_ids, mask in zip(sequence_output, input_ids, mask_id)
        ])
        queries = self.entity_mlp(queries)
        hidden_dim = queries.shape[-1]
        # unsqueeze before repeat so each query is tiled per candidate, then
        # flatten so queries, candidates and labels all have bz * candidates rows.
        queries = queries.unsqueeze(1).repeat(1, candidate_num, 1).view(-1, hidden_dim)

        # Entities are represented by their mean (non-contextual) embedding.
        positive = self.bert.embeddings(input_ids=entity_label).mean(1)
        positive = self.entity_mlp(positive).unsqueeze(1)  # [bz, 1, dim]
        flat_negative = entity_negative.reshape(-1, entity_negative.shape[-1])
        negative = self.bert.embeddings(input_ids=flat_negative).mean(1)
        negative = self.entity_mlp(negative).view(batch_size, -1, hidden_dim)
        candidates = torch.cat([negative, positive], 1).reshape(-1, hidden_dim)

        # Built on the query's device instead of a hard-coded ``.cuda()``.
        pair_labels = torch.zeros(candidate_num, dtype=torch.float, device=queries.device)
        pair_labels[-1] = 1.0
        pair_labels = pair_labels.repeat(batch_size)
        return self.contrastive_loss_fn(queries, candidates, pair_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        mlm_labels=None,
        entity_label=None,
        entity_negative=None,
        relation_label=None,
        relation_negative=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and compute every objective that has labels.

        ``relation_label``, ``relation_negative`` and ``task_id`` are accepted
        for interface compatibility but are not used.

        :param attention_mask: Forwarded to the encoder. The earlier revision
            replaced it with None next to leftover debug prints.
        :param mlm_labels: Masked-LM target ids; MLM loss is skipped if None.
        :param entity_label: Positive entity token ids.
        :param entity_negative: Negative entity token ids; the entity loss is
            computed only when both entity inputs are given.
        :param noise_detect_label: Binary targets for the detection head.
        :param mask_id: Per-row id of the mask token; required for the entity
            objective.
        :return: ``OrderedDict`` with ``loss``, ``mlm_loss``,
            ``noise_detect_loss``, ``entity_loss`` (each None when inactive;
            the individual losses have shape ``[1]``), ``logits`` (argmax
            token ids of the MLM head) and ``noise_detect_logits`` (argmax
            class of the detection head).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        masked_lm_loss, noise_detect_loss, entity_loss = None, None, None
        if mlm_labels is not None:
            # CrossEntropyLoss ignores target index -100 by default.
            masked_lm_loss = CrossEntropyLoss()(
                prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1)
            )
        if noise_detect_label is not None:
            noise_detect_loss = CrossEntropyLoss()(
                seq_relationship_score.view(-1, 2), noise_detect_label.view(-1)
            )
        if entity_label is not None and entity_negative is not None:
            entity_loss = self._entity_contrastive_loss(
                sequence_output, input_ids, mask_id, entity_label, entity_negative
            )

        active_losses = [
            loss for loss in (masked_lm_loss, noise_detect_loss, entity_loss) if loss is not None
        ]
        total_loss = torch.stack(active_losses).sum(-1) if active_losses else None

        def as_vector(loss):
            # Individual losses are reported with shape [1], or None if inactive.
            return loss.unsqueeze(0) if loss is not None else None

        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", as_vector(masked_lm_loss)),
            ("noise_detect_loss", as_vector(noise_detect_loss)),
            ("entity_loss", as_vector(entity_loss)),
            ("logits", prediction_scores.argmax(2)),
            # The relationship scores are 2-D, so the class axis is the last one.
            ("noise_detect_logits", seq_relationship_score.argmax(-1)),
        ])