Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -295,6 +295,10 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
|
|
| 295 |
accelerator = Accelerator()
|
| 296 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
'''
|
| 299 |
# inference
|
| 300 |
# Path to the saved LoRA model
|
|
@@ -323,14 +327,13 @@ with torch.no_grad():
|
|
| 323 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
| 324 |
predictions = torch.argmax(logits, dim=2)
|
| 325 |
|
| 326 |
-
|
| 327 |
# Define labels
|
| 328 |
id2label = {
|
| 329 |
0: "No binding site",
|
| 330 |
1: "Binding site"
|
| 331 |
}
|
| 332 |
|
| 333 |
-
'''
|
| 334 |
# Print the predicted labels for each token
|
| 335 |
for token, prediction in zip(tokens, predictions[0].numpy()):
|
| 336 |
if token not in ['<pad>', '<cls>', '<eos>']:
|
|
|
|
| 295 |
accelerator = Accelerator()
|
| 296 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
| 297 |
|
| 298 |
+
# Define labels and model
|
| 299 |
+
id2label = {0: "No binding site", 1: "Binding site"}
|
| 300 |
+
label2id = {v: k for k, v in id2label.items()}
|
| 301 |
+
|
| 302 |
'''
|
| 303 |
# inference
|
| 304 |
# Path to the saved LoRA model
|
|
|
|
| 327 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
| 328 |
predictions = torch.argmax(logits, dim=2)
|
| 329 |
|
| 330 |
+
|
| 331 |
# Define labels
|
| 332 |
id2label = {
|
| 333 |
0: "No binding site",
|
| 334 |
1: "Binding site"
|
| 335 |
}
|
| 336 |
|
|
|
|
| 337 |
# Print the predicted labels for each token
|
| 338 |
for token, prediction in zip(tokens, predictions[0].numpy()):
|
| 339 |
if token not in ['<pad>', '<cls>', '<eos>']:
|