Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -26,6 +26,8 @@ from transformers import (
|
|
| 26 |
Trainer
|
| 27 |
)
|
| 28 |
|
|
|
|
|
|
|
| 29 |
from datasets import Dataset
|
| 30 |
from accelerate import Accelerator
|
| 31 |
# Imports specific to the custom peft lora model
|
|
@@ -105,10 +107,39 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
|
|
| 105 |
accelerator = Accelerator()
|
| 106 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
with demo:
|
| 112 |
-
gr.Markdown("# DEMO FOR
|
| 113 |
gr.Textbox(dubug_result)
|
| 114 |
demo.launch()
|
|
|
|
| 26 |
Trainer
|
| 27 |
)
|
| 28 |
|
| 29 |
+
from peft import PeftModel
|
| 30 |
+
|
| 31 |
from datasets import Dataset
|
| 32 |
from accelerate import Accelerator
|
| 33 |
# Imports specific to the custom peft lora model
|
|
|
|
| 107 |
accelerator = Accelerator()
|
| 108 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
| 109 |
|
| 110 |
+
# inference
|
| 111 |
+
# Path to the saved LoRA model
|
| 112 |
+
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
|
| 113 |
+
# ESM2 base model
|
| 114 |
+
base_model_path = "facebook/esm2_t12_35M_UR50D"
|
| 115 |
+
|
| 116 |
+
# Load the model
|
| 117 |
+
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
|
| 118 |
+
loaded_model = PeftModel.from_pretrained(base_model, model_path)
|
| 119 |
+
|
| 120 |
+
# Ensure the model is in evaluation mode
|
| 121 |
+
loaded_model.eval()
|
| 122 |
+
|
| 123 |
+
# Protein sequence for inference
|
| 124 |
+
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
|
| 125 |
+
|
| 126 |
+
# Tokenize the sequence
|
| 127 |
+
inputs = tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
|
| 128 |
+
|
| 129 |
+
# Run the model
|
| 130 |
+
with torch.no_grad():
|
| 131 |
+
logits = loaded_model(**inputs).logits
|
| 132 |
+
|
| 133 |
+
# Get predictions
|
| 134 |
+
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
| 135 |
+
predictions = torch.argmax(logits, dim=2)
|
| 136 |
+
|
| 137 |
+
# debug result
|
| 138 |
+
dubug_result = predictions #class_weights
|
| 139 |
+
|
| 140 |
+
demo = gr.Blocks(title="DEMO FOR ESM2Bind")
|
| 141 |
|
| 142 |
with demo:
|
| 143 |
+
gr.Markdown("# DEMO FOR ESM2Bind")
|
| 144 |
gr.Textbox(dubug_result)
|
| 145 |
demo.launch()
|