rhasan commited on
Commit
c031815
·
1 Parent(s): 6b3d060

fixed module load

Browse files
Files changed (3) hide show
  1. app.py +8 -7
  2. src/infer.py +1 -1
  3. src/paired_texts_modelling.py +2 -9
app.py CHANGED
@@ -17,7 +17,7 @@ def _warmup():
17
  return
18
  t0 = time.time()
19
  _ckpt_path = hf_hub_download(
20
- repo_id="rhasan/empathy",
21
  filename="UPLME_NewsEmp_tuned-lambdas.ckpt",
22
  repo_type="model",
23
  local_dir="/data/uplme_ckpt"
@@ -25,7 +25,7 @@ def _warmup():
25
  load_model(_ckpt_path)
26
  return f"Model loaded in {time.time() - t0:.1f} seconds."
27
 
28
- def predict_with_ci(essay: str, article: str) -> dict:
29
  _warmup()
30
  mean, var = predict(essay, article)
31
  # scores were originally in [1, 7]
@@ -35,7 +35,7 @@ def predict_with_ci(essay: str, article: str) -> dict:
35
  std = np.sqrt(var)
36
  ci_low = max(0.0, mean - 1.96 * std)
37
  ci_upp = min(100.0, mean + 1.96 * std)
38
- return {"mean": mean, "ci": (ci_low, ci_upp)}
39
 
40
  with gr.Blocks(title="Empathy Prediction") as demo:
41
  gr.Markdown("# Empathy Prediction with Uncertainty Estimation")
@@ -45,10 +45,11 @@ with gr.Blocks(title="Empathy Prediction") as demo:
45
  article_input = gr.Textbox(label="Article", lines=10, placeholder="Enter the article text here...")
46
  button = gr.Button("Predict")
47
  with gr.Column():
48
- output_mean = gr.Number(label="Predicted Empathy Mean", precision=4)
49
- ci = gr.Number(label="95\% CI", precision=4)
50
-
51
- button.click(fn=predict_with_ci, inputs=[essay_input, article_input], outputs=[output_mean, ci])
 
52
 
53
  if __name__ == "__main__":
54
  demo.launch()
 
17
  return
18
  t0 = time.time()
19
  _ckpt_path = hf_hub_download(
20
+ repo_id="rhasan/UPLME",
21
  filename="UPLME_NewsEmp_tuned-lambdas.ckpt",
22
  repo_type="model",
23
  local_dir="/data/uplme_ckpt"
 
25
  load_model(_ckpt_path)
26
  return f"Model loaded in {time.time() - t0:.1f} seconds."
27
 
28
+ def predict_with_ci(essay: str, article: str) -> tuple[float, float, float]:
29
  _warmup()
30
  mean, var = predict(essay, article)
31
  # scores were originally in [1, 7]
 
35
  std = np.sqrt(var)
36
  ci_low = max(0.0, mean - 1.96 * std)
37
  ci_upp = min(100.0, mean + 1.96 * std)
38
+ return mean, ci_low, ci_upp
39
 
40
  with gr.Blocks(title="Empathy Prediction") as demo:
41
  gr.Markdown("# Empathy Prediction with Uncertainty Estimation")
 
45
  article_input = gr.Textbox(label="Article", lines=10, placeholder="Enter the article text here...")
46
  button = gr.Button("Predict")
47
  with gr.Column():
48
+ output_mean = gr.Number(label="Predicted Empathy Mean", precision=2)
49
+ ci_low = gr.Number(label="95% CI Lower Bound", precision=2)
50
+ ci_upp = gr.Number(label="95% CI Upper Bound", precision=2)
51
+
52
+ button.click(fn=predict_with_ci, inputs=[essay_input, article_input], outputs=[output_mean, ci_low, ci_upp])
53
 
54
  if __name__ == "__main__":
55
  demo.launch()
src/infer.py CHANGED
@@ -5,7 +5,7 @@ FROM https://github.com/hasan-rakibul/UPLME/tree/main
5
  import torch
6
  from transformers import AutoTokenizer
7
 
8
- from paired_texts_modelling import LitPairedTextModel
9
 
10
  _device = None
11
  _model = None
 
5
  import torch
6
  from transformers import AutoTokenizer
7
 
8
+ from src.paired_texts_modelling import LitPairedTextModel
9
 
10
  _device = None
11
  _model = None
src/paired_texts_modelling.py CHANGED
@@ -10,9 +10,6 @@ from transformers import (
10
  )
11
  import logging
12
 
13
- import lightning as L
14
-
15
-
16
  logger = logging.getLogger(__name__)
17
 
18
  class CrossEncoderProbModel(torch.nn.Module):
@@ -78,12 +75,7 @@ class LitPairedTextModel(L.LightningModule):
78
  self.save_hyperparameters()
79
 
80
  self.approach = approach
81
- if self.approach == "cross-basic":
82
- self.model = CrossEncoderBasicModel(plm_name=plm_names[0])
83
- elif self.approach == "cross-prob":
84
- self.model = CrossEncoderProbModel(plm_name=plm_names[0])
85
- else:
86
- raise ValueError(f"Invalid approach: {self.approach}")
87
 
88
  self.lr = lr
89
  self.log_dir = log_dir
@@ -101,6 +93,7 @@ class LitPairedTextModel(L.LightningModule):
101
  self.test_outputs = []
102
 
103
  def forward(self, batch: dict) -> tuple[Tensor, Tensor, Tensor]:
 
104
  means, varss, hidden_states = [], [], []
105
 
106
  for _ in range(self.num_passes):
 
10
  )
11
  import logging
12
 
 
 
 
13
  logger = logging.getLogger(__name__)
14
 
15
  class CrossEncoderProbModel(torch.nn.Module):
 
75
  self.save_hyperparameters()
76
 
77
  self.approach = approach
78
+ self.model = CrossEncoderProbModel(plm_name=plm_names[0])
 
 
 
 
 
79
 
80
  self.lr = lr
81
  self.log_dir = log_dir
 
93
  self.test_outputs = []
94
 
95
  def forward(self, batch: dict) -> tuple[Tensor, Tensor, Tensor]:
96
+ self._enable_dropout_at_inference()
97
  means, varss, hidden_states = [], [], []
98
 
99
  for _ in range(self.num_passes):