Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -118,9 +118,9 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
|
|
| 118 |
base_model = get_peft_model(base_model, peft_config)
|
| 119 |
|
| 120 |
# Use the accelerator
|
| 121 |
-
base_model =
|
| 122 |
-
train_dataset =
|
| 123 |
-
test_dataset =
|
| 124 |
|
| 125 |
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
| 126 |
|
|
@@ -205,7 +205,7 @@ test_labels = truncate_labels(test_labels, max_sequence_length)
|
|
| 205 |
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
|
| 206 |
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
|
| 207 |
|
| 208 |
-
|
| 209 |
# Compute Class Weights
|
| 210 |
classes = [0, 1]
|
| 211 |
flat_train_labels = [label for sublist in train_labels for label in sublist]
|
|
@@ -213,6 +213,7 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
|
|
| 213 |
accelerator = Accelerator()
|
| 214 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
| 215 |
|
|
|
|
| 216 |
# inference
|
| 217 |
# Path to the saved LoRA model
|
| 218 |
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
|
|
|
|
| 118 |
base_model = get_peft_model(base_model, peft_config)
|
| 119 |
|
| 120 |
# Use the accelerator
|
| 121 |
+
base_model = accelerator.prepare(base_model)
|
| 122 |
+
train_dataset = accelerator.prepare(train_dataset)
|
| 123 |
+
test_dataset = accelerator.prepare(test_dataset)
|
| 124 |
|
| 125 |
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
| 126 |
|
|
|
|
| 205 |
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
|
| 206 |
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
|
| 207 |
|
| 208 |
+
|
| 209 |
# Compute Class Weights
|
| 210 |
classes = [0, 1]
|
| 211 |
flat_train_labels = [label for sublist in train_labels for label in sublist]
|
|
|
|
| 213 |
accelerator = Accelerator()
|
| 214 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
| 215 |
|
| 216 |
+
'''
|
| 217 |
# inference
|
| 218 |
# Path to the saved LoRA model
|
| 219 |
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
|