Update app.py
Browse files
app.py
CHANGED
|
@@ -55,8 +55,8 @@ def inference(title, abstract, threshold=0.95):
|
|
| 55 |
attention_mask = encoding["attention_mask"].to(device)
|
| 56 |
|
| 57 |
with torch.no_grad():
|
| 58 |
-
res_probs = class_model(input_ids, attention_mask)
|
| 59 |
-
|
| 60 |
probs = res_probs.squeeze(0) # (8,)
|
| 61 |
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
|
| 62 |
|
|
|
|
| 55 |
attention_mask = encoding["attention_mask"].to(device)
|
| 56 |
|
| 57 |
with torch.no_grad():
|
| 58 |
+
res_probs = torch.exp(class_model(input_ids, attention_mask))
|
| 59 |
+
|
| 60 |
probs = res_probs.squeeze(0) # (8,)
|
| 61 |
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
|
| 62 |
|