Spaces:
Runtime error
Runtime error
| import math | |
| from fairseq.criterions import register_criterion | |
| from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion | |
| from fairseq import metrics, utils | |
| from collections import deque | |
| import torch | |
| import torch.nn as nn | |
# NOTE(review): registry name inferred from the class name (the decorator line
# was missing from this copy of the file); confirm it matches training configs.
@register_criterion("label_smoothed_cross_entropy_with_contrastive")
class LabelSmoothedCrossEntropyCriterionWithContrastive(
    LabelSmoothedCrossEntropyCriterion
):
    """Label-smoothed cross entropy plus a sentence-level contrastive loss.

    The contrastive term compares the mean-pooled encoder representation of
    each source sentence with that of its target sentence (run through the same
    encoder). The other sentences in the batch serve as negatives, and the
    scores are temperature-scaled cosine similarities.
    """

    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
        contrastive_lambda=0.0,
        temperature=1.0,
    ):
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        # Weight of the contrastive term in the total training loss.
        self.contrastive_lambda = contrastive_lambda
        # Divisor applied to the similarity scores before the softmax.
        self.temperature = temperature

    @staticmethod
    def add_args(parser):
        """Add the parent's arguments plus the contrastive-loss options."""
        LabelSmoothedCrossEntropyCriterion.add_args(parser)
        parser.add_argument("--contrastive-lambda", type=float,
                            default=0.0,
                            help="The contrastive loss weight")
        parser.add_argument("--temperature", type=float,
                            default=1.0,
                            help="Temperature for the contrastive softmax")

    def swap_sample(self, sample):
        """Build a "reversed" sample whose encoder input is the target side.

        The returned dict contains:
          * ``net_input.src_tokens``: the original ``target`` tokens.
          * ``net_input.src_lengths``: their non-pad counts.
          * ``target``: the original ``src_tokens``.
          * ``net_input.prev_output_tokens``: the original source shifted right
            by one, with each row of ``prev_output_tokens[:, :1]`` as the
            leading token.

        Only the ``src_tokens``/``src_lengths`` fields are read by
        ``get_contrastive_loss`` in this class.
        """
        target = sample["target"]
        prev_output_tokens = sample["net_input"]["prev_output_tokens"]
        # Prepend the first decoder-input token of each row to the source, so
        # that slicing yields a (shifted decoder input, target) pair.
        # NOTE(review): if sources are left-padded, this token lands before the
        # padding; that only affects the decoder-side fields below.
        src_tokens = torch.cat(
            (prev_output_tokens[:, :1], sample["net_input"]["src_tokens"]), dim=-1
        )
        return {
            "net_input": {
                "src_tokens": target.contiguous(),
                "src_lengths": (target != self.padding_idx).int().sum(dim=1),
                "prev_output_tokens": src_tokens[:, :-1].contiguous(),
            },
            "nsentences": sample["nsentences"],
            "ntokens": utils.item(
                (src_tokens[:, 1:] != self.padding_idx).int().sum().data
            ),
            "target": src_tokens[:, 1:].contiguous(),
            "id": sample["id"],
        }

    def forward(self, model, sample, reduce=True):
        """Compute the combined loss for ``sample``.

        Returns:
            A tuple ``(loss, sample_size, logging_output)``. ``loss`` is the
            label-smoothed cross entropy plus
            ``contrastive_loss * contrastive_lambda * ntokens / nsentences``.
        """
        net_input = sample["net_input"]
        net_output = model(**net_input)
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        # NOTE(review): this re-encodes the source even though model(**net_input)
        # already ran the encoder. It is kept because net_output is not shown to
        # expose the encoder states.
        encoder_out = model.encoder.forward(
            net_input["src_tokens"], net_input["src_lengths"]
        ).encoder_out
        reverse_sample = self.swap_sample(sample)
        reversed_encoder_out = model.encoder.forward(
            reverse_sample["net_input"]["src_tokens"],
            reverse_sample["net_input"]["src_lengths"],
        ).encoder_out
        contrastive_loss = self.get_contrastive_loss(
            encoder_out,
            reversed_encoder_out,
            sample,
            reverse_sample,
        )

        nsentences = sample["target"].size(0)
        ntokens = sample["ntokens"]
        sample_size = nsentences if self.sentence_avg else ntokens
        # The contrastive loss is a sum over sentences; multiplying by
        # ntokens / nsentences (average target tokens per sentence) rescales it
        # before adding it to the token-level cross entropy.
        all_loss = loss + contrastive_loss * self.contrastive_lambda * ntokens / nsentences

        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        # get_contrastive_loss may be overridden to return a plain int (e.g. 0).
        if isinstance(contrastive_loss, int):
            logging_output["contrastive_loss"] = 0
        else:
            logging_output["contrastive_loss"] = utils.item(contrastive_loss.data)
        return all_loss, sample_size, logging_output

    def similarity_function(self):
        """Return the similarity module: cosine similarity over the last dim."""
        return nn.CosineSimilarity(dim=-1)

    def get_contrastive_loss(self, encoder_out1, encoder_out2, sample1, sample2):
        """Contrastive loss between the sentence embeddings of two samples.

        Args:
            encoder_out1: encoder states for ``sample1``. They are transposed to
                (batch, seq_len, hidden) below, so they are expected as
                (seq_len, batch, hidden).
            encoder_out2: encoder states for ``sample2``, same layout.
            sample1: sample whose ``net_input.src_tokens`` masks ``encoder_out1``.
            sample2: sample whose ``net_input.src_tokens`` masks ``encoder_out2``.

        Returns:
            Scalar tensor ``-sum_j log softmax_i(sim(a_j, c_i) / T)[j]``, where
            ``a`` are the anchor (sample1) embeddings and ``c`` the contrast
            (sample2) embeddings. It is summed, not averaged, over the batch.
        """

        def _sentence_embedding(encoder_out, sample):
            # Mean-pool encoder states over the non-padding source positions.
            encoder_output = encoder_out.transpose(0, 1)  # [batch, seq_len, hidden]
            mask = sample["net_input"]["src_tokens"] != self.padding_idx  # [batch, seq_len]
            # clamp avoids 0/0 -> NaN for a fully padded row; no-op otherwise.
            lengths = mask.float().sum(dim=1).clamp(min=1).unsqueeze(-1)
            return (encoder_output * mask.unsqueeze(-1)).sum(dim=1) / lengths  # [batch, hidden]

        anchor_feature = _sentence_embedding(encoder_out1, sample1)
        contrast_feature = _sentence_embedding(encoder_out2, sample2)
        batch_size, feature_dim = contrast_feature.shape

        similarity_function = self.similarity_function()
        # anchor_dot_contrast[i, j] = sim(anchor[j], contrast[i]), shape [batch, batch].
        anchor_dot_contrast = similarity_function(
            anchor_feature.expand((batch_size, batch_size, feature_dim)),
            torch.transpose(
                contrast_feature.expand((batch_size, batch_size, feature_dim)), 0, 1
            ),
        )
        # For each anchor j (column), normalise over the candidates i (dim 0).
        # The matching pair sits on the diagonal.
        log_probs = torch.log_softmax(anchor_dot_contrast / self.temperature, dim=0)
        return -log_probs.diag().sum()

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs and log the per-sentence contrastive loss in bits."""
        super().reduce_metrics(logging_outputs)
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        contrastive_loss = utils.item(
            sum(log.get("contrastive_loss", 0) for log in logging_outputs)
        )
        # Skip logging rather than dividing by zero when no sentences were seen.
        if nsentences > 0:
            metrics.log_scalar(
                "contrastive_loss",
                contrastive_loss / nsentences / math.log(2),
                nsentences,
                round=3,
            )