Spaces:
Running
Running
Update harim_plus.py
Browse files- harim_plus.py +3 -3
harim_plus.py
CHANGED
|
@@ -150,7 +150,7 @@ class Harimplus_Scorer:
|
|
| 150 |
idx=0
|
| 151 |
minibatches = []
|
| 152 |
while True:
|
| 153 |
-
start =
|
| 154 |
end = idx+bsz
|
| 155 |
if start >= len(exs):
|
| 156 |
break
|
|
@@ -232,10 +232,10 @@ class Harimplus_Scorer:
|
|
| 232 |
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
| 233 |
return_dict=True).logits
|
| 234 |
sent_lengths = tgt_mask.sum(-1)
|
| 235 |
-
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
|
| 236 |
ll = ll_tok.sum(-1) / sent_lengths
|
| 237 |
|
| 238 |
-
harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, tgt_mask)
|
| 239 |
harim = harim_tok.sum(-1) / sent_lengths
|
| 240 |
|
| 241 |
harim_plus_normalized = (ll + self._lambda * harim) # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
|
|
|
|
| 150 |
idx=0
|
| 151 |
minibatches = []
|
| 152 |
while True:
|
| 153 |
+
start = id
|
| 154 |
end = idx+bsz
|
| 155 |
if start >= len(exs):
|
| 156 |
break
|
|
|
|
| 232 |
labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
|
| 233 |
return_dict=True).logits
|
| 234 |
sent_lengths = tgt_mask.sum(-1)
|
| 235 |
+
ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, 1)#tgt_mask)
|
| 236 |
ll = ll_tok.sum(-1) / sent_lengths
|
| 237 |
|
| 238 |
+
harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, 1)#tgt_mask)
|
| 239 |
harim = harim_tok.sum(-1) / sent_lengths
|
| 240 |
|
| 241 |
harim_plus_normalized = (ll + self._lambda * harim) # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
|