move model run
Browse files
app.py
CHANGED
|
@@ -139,7 +139,14 @@ def handle_name(name=None, pdb_input=None, model_version="ESM3"):
|
|
| 139 |
pdb_name = str(random.randint(0, 100000))
|
| 140 |
return f'{pdb_name}-Dyna1{"" if model_version == "ESM3" else "-ESM2"}'
|
| 141 |
|
| 142 |
-
@spaces.GPU(duration=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=False, model_version="ESM3", name=None, oauth_token: Optional[str] = None):
|
| 144 |
try:
|
| 145 |
# Validate ESM2 requires sequence
|
|
@@ -151,7 +158,6 @@ def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=Fa
|
|
| 151 |
|
| 152 |
base_name = handle_name(name, pdb_input, model_version)
|
| 153 |
|
| 154 |
-
|
| 155 |
if model_version == "ESM3":
|
| 156 |
model = ESM_model(method='esm3').to(DEVICE)
|
| 157 |
model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1.pt', map_location=DEVICE), strict=False)
|
|
@@ -187,11 +193,9 @@ def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=Fa
|
|
| 187 |
|
| 188 |
if not (sequence or (pdb_input and model_version == "ESM3")):
|
| 189 |
raise ValueError('Please provide a sequence' + (' or structure input' if model_version == "ESM3" else ''))
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
else:
|
| 194 |
-
logits = model(seq_input, sequence_id)
|
| 195 |
probabilities = utils.prob_adjusted(logits).cpu().detach().numpy()
|
| 196 |
|
| 197 |
seq_to_use = sequence if sequence else pdb_seq if pdb_input else sequence
|
|
|
|
| 139 |
pdb_name = str(random.randint(0, 100000))
|
| 140 |
return f'{pdb_name}-Dyna1{"" if model_version == "ESM3" else "-ESM2"}'
|
| 141 |
|
| 142 |
+
@spaces.GPU(duration=300)
|
| 143 |
+
def run_model(model_version='ESM2', seq_input=None, struct_input=None, sequence_id=None):
|
| 144 |
+
if model_version == "ESM3":
|
| 145 |
+
logits = model((seq_input, struct_input), sequence_id)
|
| 146 |
+
else:
|
| 147 |
+
logits = model(seq_input, sequence_id)
|
| 148 |
+
return logits
|
| 149 |
+
|
| 150 |
def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=False, model_version="ESM3", name=None, oauth_token: Optional[str] = None):
|
| 151 |
try:
|
| 152 |
# Validate ESM2 requires sequence
|
|
|
|
| 158 |
|
| 159 |
base_name = handle_name(name, pdb_input, model_version)
|
| 160 |
|
|
|
|
| 161 |
if model_version == "ESM3":
|
| 162 |
model = ESM_model(method='esm3').to(DEVICE)
|
| 163 |
model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1.pt', map_location=DEVICE), strict=False)
|
|
|
|
| 193 |
|
| 194 |
if not (sequence or (pdb_input and model_version == "ESM3")):
|
| 195 |
raise ValueError('Please provide a sequence' + (' or structure input' if model_version == "ESM3" else ''))
|
| 196 |
+
|
| 197 |
+
logits = run_model(model_version, seq_input, struct_input, sequence_id)
|
| 198 |
+
|
|
|
|
|
|
|
| 199 |
probabilities = utils.prob_adjusted(logits).cpu().detach().numpy()
|
| 200 |
|
| 201 |
seq_to_use = sequence if sequence else pdb_seq if pdb_input else sequence
|