""" FROM https://github.com/hasan-rakibul/UPLME/tree/main """ import torch from torch import Tensor import lightning as L from transformers import ( AutoModel, ) import logging logger = logging.getLogger(__name__) class CrossEncoderProbModel(torch.nn.Module): def __init__(self, plm_name: str): super().__init__() self.model = AutoModel.from_pretrained(plm_name) if plm_name.startswith("roberta"): # only applicable for roberta self.pooling = "roberta-pooler" else: self.pooling = "cls" self.out_proj_m = torch.nn.Sequential( torch.nn.LayerNorm(self.model.config.hidden_size), torch.nn.Dropout(0.25), torch.nn.Linear(self.model.config.hidden_size, 1) ) self.out_proj_v = torch.nn.Sequential( torch.nn.LayerNorm(self.model.config.hidden_size), torch.nn.Dropout(0.25), torch.nn.Linear(self.model.config.hidden_size, 1), torch.nn.Softplus() ) def forward(self, input_ids, attention_mask): output = self.model( input_ids=input_ids, attention_mask=attention_mask ) if self.pooling == "mean": sentence_representation = ( (output.last_hidden_state * attention_mask.unsqueeze(-1)).sum(-2) / attention_mask.sum(dim=-1).unsqueeze(-1) ) elif self.pooling == "cls": sentence_representation = output.last_hidden_state[:, 0, :] elif self.pooling == "roberta-pooler": sentence_representation = output.pooler_output # (batch_size, hidden_dim) mean = self.out_proj_m(sentence_representation) var = self.out_proj_v(sentence_representation) var = torch.clamp(var, min=1e-8, max=1000) # following Seitzer-NeurIPS2022 return mean.squeeze(), var.squeeze(), sentence_representation, output.last_hidden_state class LitPairedTextModel(L.LightningModule): def __init__( self, plm_names: list[str], lr: float, log_dir: str, save_uc_metrics: bool, error_decay_factor: float, approach: str, sep_token_id: int, # required for alignment loss lambdas: list[float] = [], # initlisaed to compatible with old saved checkpoints num_passes: int = 4 ): super().__init__() self.save_hyperparameters() self.approach = approach self.model = CrossEncoderProbModel(plm_name=plm_names[0]) self.lr = lr self.log_dir = log_dir self.save_uc_metrics = save_uc_metrics self.error_decay_factor = error_decay_factor self.lambdas = lambdas self.sep_token_id = sep_token_id self.num_passes = num_passes self.penalty_type = "exp-decay" self.validation_outputs = [] self.test_outputs = [] def forward(self, batch: dict) -> tuple[Tensor, Tensor, Tensor]: self._enable_dropout_at_inference() means, varss, hidden_states = [], [], [] for _ in range(self.num_passes): if self.approach == "cross-prob": mean, var, _, hidden_state = self.model( input_ids=batch['input_ids'], attention_mask=batch['attention_mask'] ) elif self.approach == "cross-basic": mean, hidden_state = self.model(batch) var = torch.zeros_like(mean) means.append(mean) varss.append(var) hidden_states.append(hidden_state) mean = torch.stack(means, dim=0).mean(dim=0) var = torch.stack(varss, dim=0).mean(dim=0) hidden_state = torch.stack(hidden_states, dim=0).mean(dim=0) return mean, var, hidden_state def _enable_dropout_at_inference(self): for m in self.model.modules(): if isinstance(m, torch.nn.Dropout): m.train()