rhasan committed on
Commit 6b3d060 · 1 parent: d3b6e1f

infer working locally - first try at HF
Files changed (4)
  1. app.py +54 -0
  2. requirements.txt +6 -0
  3. src/infer.py +38 -0
  4. src/paired_texts_modelling.py +131 -0
app.py ADDED
@@ -0,0 +1,54 @@
import os
import time

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

from src.infer import load_model, predict

os.environ.setdefault("HF_HOME", "/data/.huggingface")

_loaded = False
_ckpt_path = None

def _warmup():
    # load the checkpoint once; subsequent calls are no-ops
    global _loaded, _ckpt_path
    if _loaded:
        return
    t0 = time.time()
    _ckpt_path = hf_hub_download(
        repo_id="rhasan/empathy",
        filename="UPLME_NewsEmp_tuned-lambdas.ckpt",
        repo_type="model",
        local_dir="/data/uplme_ckpt"
    )
    load_model(_ckpt_path)
    _loaded = True
    return f"Model loaded in {time.time() - t0:.1f} seconds."

def predict_with_ci(essay: str, article: str) -> tuple[float, str]:
    _warmup()
    mean, var = predict(essay, article)
    # scores are originally in [1, 7]; rescale to [0, 100]
    mean = (mean - 1) / 6 * 100
    # the std must be rescaled by the same linear factor (100 / 6)
    std = np.sqrt(var) / 6 * 100
    ci_low = max(0.0, mean - 1.96 * std)
    ci_upp = min(100.0, mean + 1.96 * std)
    return mean, f"[{ci_low:.2f}, {ci_upp:.2f}]"

with gr.Blocks(title="Empathy Prediction") as demo:
    gr.Markdown("# Empathy Prediction with Uncertainty Estimation")
    with gr.Row():
        with gr.Column():
            essay_input = gr.Textbox(label="Essay", lines=10, placeholder="Enter the essay text here...")
            article_input = gr.Textbox(label="Article", lines=10, placeholder="Enter the article text here...")
            button = gr.Button("Predict")
        with gr.Column():
            output_mean = gr.Number(label="Predicted Empathy Mean", precision=4)
            ci = gr.Textbox(label="95% CI")

    button.click(fn=predict_with_ci, inputs=[essay_input, article_input], outputs=[output_mean, ci])

if __name__ == "__main__":
    demo.launch()
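
The interval follows from a normal approximation: mean ± 1.96·std covers roughly 95% of the predictive distribution. For a quick smoke test outside the UI, a sketch like the following should work (the example texts are placeholders, and the first call triggers the checkpoint download):

# hypothetical smoke test; the texts are placeholders
from app import predict_with_ci

mean, ci = predict_with_ci(
    "Reading about the families who lost their homes left me heartbroken.",
    "A devastating flood displaced hundreds of residents over the weekend."
)
print(mean, ci)  # a [0, 100] score and a "[low, high]" interval string
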
requirements.txt ADDED
@@ -0,0 +1,6 @@
torch
transformers
gradio
lightning
numpy
huggingface_hub
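
No versions are pinned; on a fresh environment these can be installed with pip install -r requirements.txt before launching the demo with python app.py.
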
src/infer.py ADDED
@@ -0,0 +1,38 @@
"""
FROM https://github.com/hasan-rakibul/UPLME/tree/main
"""

import torch
from transformers import AutoTokenizer

from src.paired_texts_modelling import LitPairedTextModel

_device = None
_model = None
_tokeniser = None

def load_model(ckpt_path: str):
    global _model, _tokeniser, _device
    plm_name = "roberta-base"
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _model = LitPairedTextModel.load_from_checkpoint(ckpt_path).to(_device).eval()
    _tokeniser = AutoTokenizer.from_pretrained(
        plm_name,
        use_fast=True,
        # the first word is tokenised differently without a prefix space,
        # but adding one might decrease performance, so False (09/24)
        add_prefix_space=False
    )

@torch.inference_mode()
def predict(essay: str, article: str) -> tuple[float, float]:
    max_length = 512
    toks = _tokeniser(
        essay,
        article,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    ).to(_device)
    mean, var, _ = _model(toks)
    return mean.item(), var.item()
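
load_model caches the model and tokeniser in module globals, so predict can be called repeatedly afterwards. A minimal direct-use sketch, assuming the checkpoint has already been downloaded (the path below is hypothetical):

from src.infer import load_model, predict

load_model("/data/uplme_ckpt/UPLME_NewsEmp_tuned-lambdas.ckpt")  # hypothetical local path
mean, var = predict("essay text ...", "article text ...")
print(mean, var)  # raw outputs on the original [1, 7] score scale
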
src/paired_texts_modelling.py ADDED
@@ -0,0 +1,131 @@
"""
FROM https://github.com/hasan-rakibul/UPLME/tree/main
"""

import logging

import torch
from torch import Tensor
import lightning as L
from transformers import AutoModel

logger = logging.getLogger(__name__)

class CrossEncoderProbModel(torch.nn.Module):
    def __init__(self, plm_name: str):
        super().__init__()
        self.model = AutoModel.from_pretrained(plm_name)

        if plm_name.startswith("roberta"):
            # only applicable to RoBERTa
            self.pooling = "roberta-pooler"
        else:
            self.pooling = "cls"

        self.out_proj_m = torch.nn.Sequential(
            torch.nn.LayerNorm(self.model.config.hidden_size),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(self.model.config.hidden_size, 1)
        )

        self.out_proj_v = torch.nn.Sequential(
            torch.nn.LayerNorm(self.model.config.hidden_size),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(self.model.config.hidden_size, 1),
            torch.nn.Softplus()
        )

    def forward(self, input_ids, attention_mask):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        if self.pooling == "mean":
            sentence_representation = (
                (output.last_hidden_state * attention_mask.unsqueeze(-1)).sum(-2) /
                attention_mask.sum(dim=-1).unsqueeze(-1)
            )
        elif self.pooling == "cls":
            sentence_representation = output.last_hidden_state[:, 0, :]
        elif self.pooling == "roberta-pooler":
            sentence_representation = output.pooler_output  # (batch_size, hidden_dim)

        mean = self.out_proj_m(sentence_representation)
        var = self.out_proj_v(sentence_representation)
        var = torch.clamp(var, min=1e-8, max=1000)  # following Seitzer-NeurIPS2022

        return mean.squeeze(), var.squeeze(), sentence_representation, output.last_hidden_state

class LitPairedTextModel(L.LightningModule):
    def __init__(
        self,
        plm_names: list[str],
        lr: float,
        log_dir: str,
        save_uc_metrics: bool,
        error_decay_factor: float,
        approach: str,
        sep_token_id: int,  # required for alignment loss
        lambdas: list[float] = [],  # initialised this way for compatibility with old saved checkpoints
        num_passes: int = 4
    ):
        super().__init__()
        self.save_hyperparameters()

        self.approach = approach
        if self.approach == "cross-basic":
            # CrossEncoderBasicModel is defined in the upstream repository (not included in this commit)
            self.model = CrossEncoderBasicModel(plm_name=plm_names[0])
        elif self.approach == "cross-prob":
            self.model = CrossEncoderProbModel(plm_name=plm_names[0])
        else:
            raise ValueError(f"Invalid approach: {self.approach}")

        self.lr = lr
        self.log_dir = log_dir
        self.save_uc_metrics = save_uc_metrics

        self.error_decay_factor = error_decay_factor

        self.lambdas = lambdas
        self.sep_token_id = sep_token_id
        self.num_passes = num_passes

        self.penalty_type = "exp-decay"

        self.validation_outputs = []
        self.test_outputs = []

    def forward(self, batch: dict) -> tuple[Tensor, Tensor, Tensor]:
        means, varss, hidden_states = [], [], []

        for _ in range(self.num_passes):
            if self.approach == "cross-prob":
                mean, var, _, hidden_state = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask']
                )
            elif self.approach == "cross-basic":
                mean, hidden_state = self.model(batch)
                var = torch.zeros_like(mean)

            means.append(mean)
            varss.append(var)
            hidden_states.append(hidden_state)

        mean = torch.stack(means, dim=0).mean(dim=0)
        var = torch.stack(varss, dim=0).mean(dim=0)
        hidden_state = torch.stack(hidden_states, dim=0).mean(dim=0)

        return mean, var, hidden_state

    def _enable_dropout_at_inference(self):
        # keep dropout layers stochastic at eval time, enabling MC-dropout-style sampling
        for m in self.model.modules():
            if isinstance(m, torch.nn.Dropout):
                m.train()
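
With _enable_dropout_at_inference active, the num_passes forward passes are stochastic MC-dropout samples; forward averages the per-pass means and the per-pass (aleatoric) variances. A common extension, not implemented in this file, also adds the spread of the sampled means as an epistemic term. A minimal sketch with stand-in tensors:

import torch

# stand-ins for the per-pass tensors collected in forward (num_passes x batch_size)
means = torch.randn(4, 8)  # sampled predictive means
varss = torch.rand(4, 8)   # per-pass predicted (aleatoric) variances

aleatoric = varss.mean(dim=0)      # what forward returns as `var`
epistemic = means.var(dim=0)       # spread of the sampled means across passes
total_var = aleatoric + epistemic  # standard MC-dropout predictive variance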