"""
FROM https://github.com/hasan-rakibul/UPLME/tree/main
"""
import torch
from torch import Tensor
import lightning as L
from transformers import (
AutoModel,
)
import logging
logger = logging.getLogger(__name__)


class CrossEncoderProbModel(torch.nn.Module):
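    """Cross-encoder that scores a paired text and predicts both a mean and a variance.

    A pretrained transformer encodes the concatenated pair; the pooled representation
    feeds two heads: a linear head for the predicted mean and a Softplus head for the
    predicted (non-negative) variance, i.e. a heteroscedastic regression head.
    """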
def __init__(self, plm_name: str):
super().__init__()
self.model = AutoModel.from_pretrained(plm_name)
if plm_name.startswith("roberta"):
# only applicable for roberta
self.pooling = "roberta-pooler"
else:
self.pooling = "cls"
self.out_proj_m = torch.nn.Sequential(
torch.nn.LayerNorm(self.model.config.hidden_size),
torch.nn.Dropout(0.25),
torch.nn.Linear(self.model.config.hidden_size, 1)
)
self.out_proj_v = torch.nn.Sequential(
torch.nn.LayerNorm(self.model.config.hidden_size),
torch.nn.Dropout(0.25),
torch.nn.Linear(self.model.config.hidden_size, 1),
torch.nn.Softplus()
)

    def forward(self, input_ids, attention_mask):
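        """Encode the paired input, pool it, and return (mean, variance, pooled representation, last hidden state)."""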
output = self.model(
input_ids=input_ids,
attention_mask=attention_mask
)
if self.pooling == "mean":
sentence_representation = (
(output.last_hidden_state * attention_mask.unsqueeze(-1)).sum(-2) /
attention_mask.sum(dim=-1).unsqueeze(-1)
)
elif self.pooling == "cls":
sentence_representation = output.last_hidden_state[:, 0, :]
elif self.pooling == "roberta-pooler":
sentence_representation = output.pooler_output # (batch_size, hidden_dim)
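        # Project the pooled representation to the predicted mean and variance.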
mean = self.out_proj_m(sentence_representation)
var = self.out_proj_v(sentence_representation)
        var = torch.clamp(var, min=1e-8, max=1000)  # clamp variance for numerical stability, following Seitzer-NeurIPS2022
return mean.squeeze(), var.squeeze(), sentence_representation, output.last_hidden_state
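
# Minimal usage sketch (illustrative only; "roberta-base" and AutoTokenizer are assumptions, not part of this file):
#   tokenizer = AutoTokenizer.from_pretrained("roberta-base")
#   model = CrossEncoderProbModel("roberta-base")
#   enc = tokenizer("text A", "text B", return_tensors="pt")
#   mean, var, pooled, hidden = model(enc["input_ids"], enc["attention_mask"])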


class LitPairedTextModel(L.LightningModule):
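    """LightningModule wrapping a cross-encoder for paired-text regression.

    Inference uses Monte Carlo dropout: dropout stays active and `num_passes`
    stochastic forward passes are averaged to obtain the predicted mean,
    variance and hidden state.
    """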
def __init__(
self,
plm_names: list[str],
lr: float,
log_dir: str,
save_uc_metrics: bool,
error_decay_factor: float,
approach: str,
sep_token_id: int, # required for alignment loss
        lambdas: list[float] = [],  # initialised as an empty list to stay compatible with old saved checkpoints
num_passes: int = 4
):
super().__init__()
self.save_hyperparameters()
self.approach = approach
self.model = CrossEncoderProbModel(plm_name=plm_names[0])
self.lr = lr
self.log_dir = log_dir
self.save_uc_metrics = save_uc_metrics
self.error_decay_factor = error_decay_factor
self.lambdas = lambdas
self.sep_token_id = sep_token_id
self.num_passes = num_passes
self.penalty_type = "exp-decay"
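        # Buffers for collecting per-batch outputs during validation and test.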
self.validation_outputs = []
self.test_outputs = []

    def forward(self, batch: dict) -> tuple[Tensor, Tensor, Tensor]:
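        """MC-dropout forward pass: average `num_passes` stochastic passes of mean, variance and hidden state."""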
self._enable_dropout_at_inference()
means, varss, hidden_states = [], [], []
for _ in range(self.num_passes):
if self.approach == "cross-prob":
mean, var, _, hidden_state = self.model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
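            # "cross-basic" expects a deterministic cross-encoder that takes the whole batch and
            # returns only (mean, hidden_state); that variant is presumably defined elsewhere in the
            # original repository, since the probabilistic model instantiated above returns four values.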
elif self.approach == "cross-basic":
mean, hidden_state = self.model(batch)
var = torch.zeros_like(mean)
means.append(mean)
varss.append(var)
hidden_states.append(hidden_state)
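        # Average the predictions across the stochastic passes.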
mean = torch.stack(means, dim=0).mean(dim=0)
var = torch.stack(varss, dim=0).mean(dim=0)
hidden_state = torch.stack(hidden_states, dim=0).mean(dim=0)
return mean, var, hidden_state

    def _enable_dropout_at_inference(self):
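        """Keep Dropout layers in train mode so inference-time passes stay stochastic (MC dropout)."""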
for m in self.model.modules():
if isinstance(m, torch.nn.Dropout):
m.train()