"""
FROM https://github.com/hasan-rakibul/UPLME/tree/main
"""

import torch
from torch import Tensor
import lightning as L
from transformers import AutoModel
import logging

logger = logging.getLogger(__name__)
    
class CrossEncoderProbModel(torch.nn.Module):
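    """Cross-encoder over a packed text pair that predicts a mean and a variance."""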
    def __init__(self, plm_name: str):
        super().__init__()
        self.model = AutoModel.from_pretrained(plm_name)

        if plm_name.startswith("roberta"):
            # only applicable for roberta
            self.pooling = "roberta-pooler"
        else:
            self.pooling = "cls"

        self.out_proj_m = torch.nn.Sequential(
            torch.nn.LayerNorm(self.model.config.hidden_size),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(self.model.config.hidden_size, 1)
        )

        self.out_proj_v = torch.nn.Sequential(
            torch.nn.LayerNorm(self.model.config.hidden_size),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(self.model.config.hidden_size, 1),
            torch.nn.Softplus()
        )

    def forward(self, input_ids, attention_mask):
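        # Encode the packed pair, pool it into a single sentence vector, then map
        # it to a predicted mean and a strictly positive variance (Softplus + clamp).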
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        if self.pooling == "mean":
            # mean-pool over non-padding tokens (kept for completeness; the current
            # __init__ only selects "cls" or "roberta-pooler")
            sentence_representation = (
                (output.last_hidden_state * attention_mask.unsqueeze(-1)).sum(-2) /
                attention_mask.sum(dim=-1).unsqueeze(-1)
            )
        elif self.pooling == "cls":
            sentence_representation = output.last_hidden_state[:, 0, :]
        elif self.pooling == "roberta-pooler":
            sentence_representation = output.pooler_output # (batch_size, hidden_dim)

        mean = self.out_proj_m(sentence_representation)
        var = self.out_proj_v(sentence_representation)
        var = torch.clamp(var, min=1e-8, max=1000) # following Seitzer-NeurIPS2022

        return mean.squeeze(), var.squeeze(), sentence_representation, output.last_hidden_state
    
class LitPairedTextModel(L.LightningModule):
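    """Lightning wrapper that averages mean/variance predictions over several
    stochastic (dropout-enabled) forward passes."""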
    def __init__(
        self,
        plm_names: list[str],
        lr: float,
        log_dir: str,
        save_uc_metrics: bool,
        error_decay_factor: float,
        approach: str,
        sep_token_id: int, # required for alignment loss
        lambdas: list[float] = [], # initialised to an empty list to stay compatible with old saved checkpoints
        num_passes: int = 4
    ):
        super().__init__()
        self.save_hyperparameters()

        self.approach = approach
        self.model = CrossEncoderProbModel(plm_name=plm_names[0])

        self.lr = lr
        self.log_dir = log_dir
        self.save_uc_metrics = save_uc_metrics

        self.error_decay_factor = error_decay_factor
        
        self.lambdas = lambdas
        self.sep_token_id = sep_token_id
        self.num_passes = num_passes
        
        self.penalty_type = "exp-decay"

        self.validation_outputs = []
        self.test_outputs = []
    
    def forward(self, batch: dict) -> tuple[Tensor, Tensor, Tensor]:
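        # MC-dropout style inference: keep dropout active and run num_passes
        # stochastic forward passes, then average the per-pass outputs.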
        self._enable_dropout_at_inference()
        means, varss, hidden_states = [], [], []

        for _ in range(self.num_passes):
            if self.approach == "cross-prob":
                mean, var, _, hidden_state = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask']
                )
            elif self.approach == "cross-basic":
                # expects a plain (non-probabilistic) cross-encoder that takes the
                # whole batch and returns (mean, hidden_state); that model variant
                # is not defined in this file
                mean, hidden_state = self.model(batch)
                var = torch.zeros_like(mean)
            
            means.append(mean)
            varss.append(var)
            hidden_states.append(hidden_state)
        
        mean = torch.stack(means, dim=0).mean(dim=0)
        var = torch.stack(varss, dim=0).mean(dim=0)
        hidden_state = torch.stack(hidden_states, dim=0).mean(dim=0)

        return mean, var, hidden_state
    
    def _enable_dropout_at_inference(self):
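        # Switch only the Dropout modules to train mode so stochastic masks are
        # still sampled while the surrounding module is in eval mode.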
        for m in self.model.modules():
            if isinstance(m, torch.nn.Dropout):
                m.train()
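

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original UPLME code). It runs a single
# probabilistic forward pass of CrossEncoderProbModel on a toy text pair.
# The "roberta-base" checkpoint and the example sentences are assumptions made
# purely for illustration; training and evaluation live elsewhere in the repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    plm_name = "roberta-base"  # assumed checkpoint; any HF encoder should work
    tokenizer = AutoTokenizer.from_pretrained(plm_name)
    model = CrossEncoderProbModel(plm_name=plm_name)
    model.eval()

    # Cross-encoder input: both texts are packed into a single sequence.
    batch = tokenizer(
        ["I just lost my job and I don't know what to do."],
        ["That sounds really stressful; I'm sorry you're going through this."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        mean, var, _, _ = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    # With a batch of one, mean and var are 0-dim tensors after .squeeze().
    print(f"predicted mean: {mean.item():.4f}, predicted variance: {var.item():.4f}")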