"""
FROM https://github.com/hasan-rakibul/UPLME/tree/main

Cross-encoder regression model with probabilistic (mean + variance) heads and a
Lightning wrapper that averages several MC-dropout forward passes.
"""
|
import logging

import torch
from torch import Tensor
import lightning as L
from transformers import AutoModel

logger = logging.getLogger(__name__)


class CrossEncoderProbModel(torch.nn.Module):
    """Cross-encoder that predicts a mean and a variance for each text pair."""

    def __init__(self, plm_name: str):
        super().__init__()
        self.model = AutoModel.from_pretrained(plm_name)

        # RoBERTa checkpoints expose a pooler head; other PLMs fall back to
        # the [CLS] token embedding.
        if plm_name.startswith("roberta"):
            self.pooling = "roberta-pooler"
        else:
            self.pooling = "cls"

        # Regression head for the predicted mean.
        self.out_proj_m = torch.nn.Sequential(
            torch.nn.LayerNorm(self.model.config.hidden_size),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(self.model.config.hidden_size, 1)
        )

        # Regression head for the predicted variance; Softplus keeps it non-negative.
        self.out_proj_v = torch.nn.Sequential(
            torch.nn.LayerNorm(self.model.config.hidden_size),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(self.model.config.hidden_size, 1),
            torch.nn.Softplus()
        )

    def forward(self, input_ids, attention_mask):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        if self.pooling == "mean":
            # Mask-aware mean over the token embeddings.
            sentence_representation = (
                (output.last_hidden_state * attention_mask.unsqueeze(-1)).sum(-2) /
                attention_mask.sum(dim=-1).unsqueeze(-1)
            )
        elif self.pooling == "cls":
            sentence_representation = output.last_hidden_state[:, 0, :]
        elif self.pooling == "roberta-pooler":
            sentence_representation = output.pooler_output
        else:
            raise ValueError(f"Unknown pooling strategy: {self.pooling}")

        mean = self.out_proj_m(sentence_representation)
        var = self.out_proj_v(sentence_representation)
        # Clamp the variance away from zero for numerical stability in the loss.
        var = torch.clamp(var, min=1e-8, max=1000)

        # squeeze(-1) rather than squeeze() so a batch of size 1 keeps its batch dimension.
        return mean.squeeze(-1), var.squeeze(-1), sentence_representation, output.last_hidden_state
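# Usage sketch (illustrative, not from the original repo; the checkpoint name
# and tokenizer call are assumptions). One stochastic forward pass returns a
# predicted mean and variance per text pair:
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("roberta-base")
#     model = CrossEncoderProbModel("roberta-base")
#     enc = tokenizer("first text", "second text", return_tensors="pt")
#     mean, var, rep, hidden = model(enc["input_ids"], enc["attention_mask"])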
|
|
class LitPairedTextModel(L.LightningModule):
    def __init__(
        self,
        plm_names: list[str],
        lr: float,
        log_dir: str,
        save_uc_metrics: bool,
        error_decay_factor: float,
        approach: str,
        sep_token_id: int,
        lambdas: list[float] | None = None,
        num_passes: int = 4
    ):
        super().__init__()
        self.save_hyperparameters()

        self.approach = approach
        self.model = CrossEncoderProbModel(plm_name=plm_names[0])

        self.lr = lr
        self.log_dir = log_dir
        self.save_uc_metrics = save_uc_metrics

        self.error_decay_factor = error_decay_factor

        # Avoid the mutable-default-argument pitfall: default to an empty list here.
        self.lambdas = lambdas if lambdas is not None else []
        self.sep_token_id = sep_token_id
        self.num_passes = num_passes

        self.penalty_type = "exp-decay"

        self.validation_outputs = []
        self.test_outputs = []

    def forward(self, batch: dict) -> tuple[Tensor, Tensor, Tensor]:
        # Keep dropout active so repeated passes yield MC-dropout samples.
        self._enable_dropout_at_inference()
        means, varss, hidden_states = [], [], []

        for _ in range(self.num_passes):
            if self.approach == "cross-prob":
                mean, var, _, hidden_state = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask']
                )
            elif self.approach == "cross-basic":
                # This branch assumes a non-probabilistic cross-encoder that
                # takes the whole batch and returns (mean, hidden_state);
                # CrossEncoderProbModel does not match this signature, so a
                # different model is expected for the "cross-basic" approach.
                mean, hidden_state = self.model(batch)
                var = torch.zeros_like(mean)

            means.append(mean)
            varss.append(var)
            hidden_states.append(hidden_state)

        # Average the stochastic passes.
        mean = torch.stack(means, dim=0).mean(dim=0)
        var = torch.stack(varss, dim=0).mean(dim=0)
        hidden_state = torch.stack(hidden_states, dim=0).mean(dim=0)

        return mean, var, hidden_state
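    # Illustrative note (not part of the original repo): under MC dropout, the
    # spread of the per-pass means estimates epistemic uncertainty, while the
    # averaged predicted variance estimates aleatoric noise. A variant of
    # forward() could expose both, e.g.:
    #
    #     stacked_means = torch.stack(means, dim=0)
    #     epistemic = stacked_means.var(dim=0)               # disagreement across passes
    #     aleatoric = torch.stack(varss, dim=0).mean(dim=0)  # predicted noise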
|
    def _enable_dropout_at_inference(self):
        # Monte Carlo dropout: keep dropout layers in train mode even at
        # evaluation time so each forward pass samples a different sub-network.
        for m in self.model.modules():
            if isinstance(m, torch.nn.Dropout):
                m.train()
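# Usage sketch (illustrative; the PLM name, hyperparameters and batch layout
# are assumptions, not values from the original repo):
#
#     lit_model = LitPairedTextModel(
#         plm_names=["roberta-base"], lr=2e-5, log_dir="logs/",
#         save_uc_metrics=False, error_decay_factor=0.1,
#         approach="cross-prob", sep_token_id=2,
#     )
#     batch = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
#     mean, var, hidden = lit_model(batch)  # averaged over num_passes MC-dropout samples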
|
|
|