| import torch | |
| from transformers import PreTrainedModel | |
class BiEncoderModelRegression(torch.nn.Module):
    """Bi-encoder that scores a text pair by the cosine similarity of their embeddings.

    Both texts are encoded independently by the same ``base_model``. The hidden
    state at sequence position 0 of each output (conventionally the ``[CLS]``
    token) is used as that text's embedding. The cosine similarity of the two
    embeddings is returned as ``logits``. When ``labels`` are given, a
    regression loss is computed as well.

    Args:
        base_model: Encoder called as ``base_model(input_ids, attention_mask=...)``.
            Its output must expose ``last_hidden_state`` indexable as
            ``[batch, seq, hidden]``.
        config: Stored as ``self.config``; not otherwise used by this class
            (NOTE(review): presumably read by a trainer / save utility — confirm).
        loss_fn: One of ``"mse"``, ``"mae"`` or ``"cosine_embedding"``.
        cosine_label_threshold: Only used for ``"cosine_embedding"``. Labels
            strictly greater than this value become target ``+1``; all others
            become ``-1``. Defaults to 0.5, the value previously hard-coded.
    """

    _SUPPORTED_LOSSES = ("mse", "mae", "cosine_embedding")

    def __init__(self, base_model, config=None, loss_fn="mse", cosine_label_threshold=0.5):
        super().__init__()
        self.base_model = base_model
        # dim=1: compare (batch, hidden) embeddings row-wise -> (batch,) scores.
        self.cos = torch.nn.CosineSimilarity(dim=1)
        self.loss_fn = loss_fn
        self.config = config
        self.cosine_label_threshold = cosine_label_threshold

    def _embed(self, input_ids, attention_mask):
        """Encode one batch of texts and return the position-0 hidden state (batch, hidden)."""
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

    def _compute_loss(self, embedding1, embedding2, cos_sim, labels):
        """Return the configured loss for one batch.

        Raises:
            ValueError: if ``self.loss_fn`` is not a supported loss name.
        """
        # Flatten so (batch, 1) labels line up with cos_sim's (batch,) shape.
        # Without this, MSELoss/L1Loss silently broadcast to (batch, batch).
        # Casting to cos_sim's dtype avoids dtype-mismatch errors for
        # integer or float64 labels.
        # NOTE(review): labels are presumably similarity scores in [0, 1] — confirm.
        labels = labels.reshape(-1).to(cos_sim.dtype)
        if self.loss_fn == "mse":
            return torch.nn.MSELoss()(cos_sim, labels)
        if self.loss_fn == "mae":
            return torch.nn.L1Loss()(cos_sim, labels)
        if self.loss_fn == "cosine_embedding":
            # CosineEmbeddingLoss needs +1 / -1 targets, so binarise the labels.
            targets = 2 * (labels > self.cosine_label_threshold).to(cos_sim.dtype) - 1
            return torch.nn.CosineEmbeddingLoss()(embedding1, embedding2, targets)
        raise ValueError(
            f"Unsupported loss_fn {self.loss_fn!r}; expected one of {self._SUPPORTED_LOSSES}"
        )

    def forward(
        self,
        input_ids_text1,
        attention_mask_text1,
        input_ids_text2,
        attention_mask_text2,
        labels=None,
    ):
        """Score a batch of text pairs.

        Returns:
            dict with ``"logits"``, the per-pair cosine similarity of shape
            (batch,), and ``"loss"``, a scalar tensor when ``labels`` is given
            and ``None`` otherwise.

        Raises:
            ValueError: if ``labels`` is given and ``loss_fn`` is unsupported.
        """
        embedding1 = self._embed(input_ids_text1, attention_mask_text1)
        embedding2 = self._embed(input_ids_text2, attention_mask_text2)
        cos_sim = self.cos(embedding1, embedding2)
        loss = None
        if labels is not None:
            loss = self._compute_loss(embedding1, embedding2, cos_sim, labels)
        return {"loss": loss, "logits": cos_sim}