Spaces:
Running
Running
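bugfix: fill label positions where tgt_mask is 0 with -100 before passing labels to the model (harim_plus.py, harim_scorer.py)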
Browse files
- __pycache__/harim_scorer.cpython-39.pyc +0 -0
- harim_plus.py +3 -2
- harim_scorer.py +3 -2
__pycache__/harim_scorer.cpython-39.pyc
CHANGED
|
Binary files a/__pycache__/harim_scorer.cpython-39.pyc and b/__pycache__/harim_scorer.cpython-39.pyc differ
|
|
|
harim_plus.py
CHANGED
|
@@ -207,18 +207,19 @@ class Harimplus_Scorer:
|
|
| 207 |
emp_in = emp_in.to(self._device)
|
| 208 |
tgt_in = tgt_in.to(self._device)
|
| 209 |
tgt_mask = tgt_mask.to(self._device)
|
|
|
|
| 210 |
|
| 211 |
with torch.no_grad():
|
| 212 |
# token_type_ids attribute causes error
|
| 213 |
s2s_logits = self._encdec_model.forward(
|
| 214 |
input_ids = src_in.input_ids,
|
| 215 |
attention_mask = src_in.attention_mask,
|
| 216 |
-
labels = tgt_in.input_ids,
|
| 217 |
return_dict=True).logits
|
| 218 |
lm_logits = self._encdec_model.forward(
|
| 219 |
input_ids = emp_in.input_ids,
|
| 220 |
attention_mask = emp_in.attention_mask,
|
| 221 |
-
labels = tgt_in.input_ids,
|
| 222 |
return_dict=True).logits
|
| 223 |
sent_lengths = tgt_mask.sum(-1)
|
| 224 |
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|
|
|
|
| 207 |
emp_in = emp_in.to(self._device)
|
| 208 |
tgt_in = tgt_in.to(self._device)
|
| 209 |
tgt_mask = tgt_mask.to(self._device)
|
| 210 |
+
fill_ignore_mask = ~(tgt_mask.bool())
|
| 211 |
|
| 212 |
with torch.no_grad():
|
| 213 |
# token_type_ids attribute causes error
|
| 214 |
s2s_logits = self._encdec_model.forward(
|
| 215 |
input_ids = src_in.input_ids,
|
| 216 |
attention_mask = src_in.attention_mask,
|
| 217 |
+
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
| 218 |
return_dict=True).logits
|
| 219 |
lm_logits = self._encdec_model.forward(
|
| 220 |
input_ids = emp_in.input_ids,
|
| 221 |
attention_mask = emp_in.attention_mask,
|
| 222 |
+
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
| 223 |
return_dict=True).logits
|
| 224 |
sent_lengths = tgt_mask.sum(-1)
|
| 225 |
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|
harim_scorer.py
CHANGED
|
@@ -141,18 +141,19 @@ class Harimplus_Scorer:
|
|
| 141 |
emp_in = emp_in.to(self._device)
|
| 142 |
tgt_in = tgt_in.to(self._device)
|
| 143 |
tgt_mask = tgt_mask.to(self._device)
|
|
|
|
| 144 |
|
| 145 |
with torch.no_grad():
|
| 146 |
# token_type_ids attribute causes error
|
| 147 |
s2s_logits = self._encdec_model.forward(
|
| 148 |
input_ids = src_in.input_ids,
|
| 149 |
attention_mask = src_in.attention_mask,
|
| 150 |
-
labels = tgt_in.input_ids,
|
| 151 |
return_dict=True).logits
|
| 152 |
lm_logits = self._encdec_model.forward(
|
| 153 |
input_ids = emp_in.input_ids,
|
| 154 |
attention_mask = emp_in.attention_mask,
|
| 155 |
-
labels = tgt_in.input_ids,
|
| 156 |
return_dict=True).logits
|
| 157 |
sent_lengths = tgt_mask.sum(-1)
|
| 158 |
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|
|
|
|
| 141 |
emp_in = emp_in.to(self._device)
|
| 142 |
tgt_in = tgt_in.to(self._device)
|
| 143 |
tgt_mask = tgt_mask.to(self._device)
|
| 144 |
+
fill_ignore_mask = ~(tgt_mask.bool())
|
| 145 |
|
| 146 |
with torch.no_grad():
|
| 147 |
# token_type_ids attribute causes error
|
| 148 |
s2s_logits = self._encdec_model.forward(
|
| 149 |
input_ids = src_in.input_ids,
|
| 150 |
attention_mask = src_in.attention_mask,
|
| 151 |
+
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
| 152 |
return_dict=True).logits
|
| 153 |
lm_logits = self._encdec_model.forward(
|
| 154 |
input_ids = emp_in.input_ids,
|
| 155 |
attention_mask = emp_in.attention_mask,
|
| 156 |
+
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
| 157 |
return_dict=True).logits
|
| 158 |
sent_lengths = tgt_mask.sum(-1)
|
| 159 |
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|