Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -73,6 +73,14 @@ def compute_loss(model, inputs):
|
|
| 73 |
loss = loss_fct(active_logits, active_labels)
|
| 74 |
return loss
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
# fine-tuning function
|
| 77 |
def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
|
| 78 |
|
|
|
|
| 73 |
loss = loss_fct(active_logits, active_labels)
|
| 74 |
return loss
|
| 75 |
|
| 76 |
+
# Define Custom Trainer Class
|
| 77 |
+
# Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
|
| 78 |
+
class WeightedTrainer(Trainer):
|
| 79 |
+
def compute_loss(self, model, inputs, return_outputs=False):
|
| 80 |
+
outputs = model(**inputs)
|
| 81 |
+
loss = compute_loss(model, inputs)
|
| 82 |
+
return (loss, outputs) if return_outputs else loss
|
| 83 |
+
|
| 84 |
# fine-tuning function
|
| 85 |
def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
|
| 86 |
|