Update app.py
app.py CHANGED
@@ -140,7 +140,8 @@ def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
         no_cuda=False,
         seed=8893,
         fp16=True,
-        report_to='wandb'
+        #report_to='wandb'
+        report_to=None
     )
 
     # Initialize Trainer
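This hunk comments out the hard-coded 'wandb' reporter and passes None instead. One caveat: in several transformers releases, report_to=None is read as "use the default integrations" (i.e. every installed one) rather than "report to nothing"; the documented way to disable tracking is the string "none". A minimal sketch of that variant, reusing the argument values visible in the hunk (output_dir is hypothetical, since it sits outside the hunk):

from transformers import TrainingArguments

# Sketch: disable all experiment-tracking integrations explicitly.
# The string "none" turns reporting off; the Python value None is
# treated by several transformers versions as "use the defaults".
training_args = TrainingArguments(
    output_dir="./output",  # hypothetical; the real value is outside the hunk
    no_cuda=False,
    seed=8893,
    fp16=True,
    report_to="none",
)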
@@ -160,6 +161,8 @@ def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
     trainer.save_model(save_path)
     tokenizer.save_pretrained(save_path)
 
+    return save_path
+
 # Load the data from pickle files (replace with your local paths)
 with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
     train_sequences = pickle.load(f)
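Returning save_path makes the training helper composable: the caller added in the next hunk can capture the checkpoint location instead of reaching for a global. A hedged sketch of consuming the new return value with standard transformers loading APIs (a token-classification head is assumed here, matching the per-residue predictions below):

from transformers import AutoModelForTokenClassification, AutoTokenizer

# Sketch: chain training and reloading via the new return value.
# train_function_no_sweeps and the dataset variables live in app.py.
saved_path = train_function_no_sweeps(base_model_path, train_dataset, test_dataset)
loaded_model = AutoModelForTokenClassification.from_pretrained(saved_path)
tokenizer = AutoTokenizer.from_pretrained(saved_path)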
@@ -217,6 +220,9 @@ inputs = tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_l
 with torch.no_grad():
     logits = loaded_model(**inputs).logits
 
+# train
+saved_path = train_function_no_sweeps(base_model_path,train_dataset, test_dataset)
+
 # Get predictions
 tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
 predictions = torch.argmax(logits, dim=2)
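This call starts a full fine-tuning run at module import, after logits have already been computed from the previously loaded model, so the fresh checkpoint does not feed the predictions decoded below. If retraining on every Space restart is not intended, a guard along these lines may help (CHECKPOINT_DIR is a hypothetical path, not from the diff):

import os

CHECKPOINT_DIR = "./saved_model"  # hypothetical location of a finished run

# Sketch: reuse an existing checkpoint instead of retraining on each restart.
if os.path.isdir(CHECKPOINT_DIR) and os.listdir(CHECKPOINT_DIR):
    saved_path = CHECKPOINT_DIR
else:
    saved_path = train_function_no_sweeps(base_model_path, train_dataset, test_dataset)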
@@ -233,7 +239,7 @@ for token, prediction in zip(tokens, predictions[0].numpy()):
     print((token, id2label[prediction]))
 
 # debug result
-dubug_result = predictions #class_weights
+dubug_result = saved_path #predictions #class_weights
 
 demo = gr.Blocks(title="DEMO FOR ESM2Bind")
 
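The debug variable (the dubug spelling is in the source) now carries the checkpoint path instead of the prediction tensor. One sketch of surfacing it inside the gr.Blocks app, with a hypothetical Markdown readout (only demo and its title come from the diff):

import gradio as gr

demo = gr.Blocks(title="DEMO FOR ESM2Bind")
with demo:
    # Hypothetical debug readout: show where the fine-tuned model was saved.
    gr.Markdown(f"Checkpoint saved at: {dubug_result}")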