Added normalization of SMILES
models/smi_ted/smi_ted_light/load.py
CHANGED
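This commit adds an RDKit-based normalize_smiles helper and canonicalizes every input SMILES before tokenization, in both the embedding-extraction path and the batched encoding path. Inputs that RDKit cannot parse are recorded in null_idx, dropped before inference, and re-inserted afterwards as NaN placeholders (sized from the model config's max_len and n_embd), so output rows stay aligned with the original inputs.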
@@ -19,6 +19,13 @@ from huggingface_hub import hf_hub_download
 # Data
 import numpy as np
 import pandas as pd
+import numpy as np
+
+# Chemistry
+from rdkit import Chem
+from rdkit.Chem import PandasTools
+from rdkit.Chem import Descriptors
+PandasTools.RenderImagesInAllDataFrames(True)
 
 # Standard library
 from functools import partial
@@ -30,6 +37,17 @@ from tqdm import tqdm
 tqdm.pandas()
 
 
+# function to canonicalize SMILES
+def normalize_smiles(smi, canonical=True, isomeric=False):
+    try:
+        normalized = Chem.MolToSmiles(
+            Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
+        )
+    except:
+        normalized = None
+    return normalized
+
+
 class MolTranBertTokenizer(BertTokenizer):
     def __init__(self, vocab_file: str = '',
                        do_lower_case=False,
@@ -477,9 +495,17 @@ class Smi_ted(nn.Module):
         if self.is_cuda_available:
             self.encoder.cuda()
             self.decoder.cuda()
+
+        # handle single str or a list of str
+        smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
+
+        # SMILES normalization
+        smiles = smiles.apply(normalize_smiles)
+        null_idx = smiles[smiles.isnull()].index.to_list()  # keep track of SMILES that cannot normalize
+        smiles = smiles.dropna()
 
         # tokenizer
-        idx, mask = self.tokenize(smiles)
+        idx, mask = self.tokenize(smiles.to_list())
 
         ###########
         # Encoder #
@@ -515,6 +541,30 @@ class Smi_ted(nn.Module):
         # reconstruct tokens
         pred_ids = self.decoder.lang_model(pred_cte.view(-1, self.max_len, self.n_embd))
         pred_ids = torch.argmax(pred_ids, axis=-1)
+
+        # replacing null SMILES with NaN values
+        for idx in null_idx:
+            true_ids = true_ids.tolist()
+            pred_ids = pred_ids.tolist()
+            true_cte = true_cte.tolist()
+            pred_cte = pred_cte.tolist()
+            true_set = true_set.tolist()
+            pred_set = pred_set.tolist()
+
+            true_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+            pred_ids.insert(idx, np.array([np.nan]*self.config['max_len']))
+            true_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+            pred_cte.insert(idx, np.array([np.nan] * (self.config['max_len']*self.config['n_embd'])))
+            true_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+            pred_set.insert(idx, np.array([np.nan]*self.config['n_embd']))
+
+        if len(null_idx) > 0:
+            true_ids = torch.tensor(true_ids)
+            pred_ids = torch.tensor(pred_ids)
+            true_cte = torch.tensor(true_cte)
+            pred_cte = torch.tensor(pred_cte)
+            true_set = torch.tensor(true_set)
+            pred_set = torch.tensor(pred_set)
 
         return ((true_ids, pred_ids), # tokens
                 (true_cte, pred_cte), # token embeddings
@@ -548,9 +598,14 @@ class Smi_ted(nn.Module):
 
         # handle single str or a list of str
         smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
-        n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
-
+
+        # SMILES normalization
+        smiles = smiles.apply(normalize_smiles)
+        null_idx = smiles[smiles.isnull()].index.to_list()  # keep track of SMILES that cannot normalize
+        smiles = smiles.dropna()
+
         # process in batches
+        n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
         embeddings = [
             self.extract_embeddings(list(batch))[2].cpu().detach().numpy()
             for batch in tqdm(np.array_split(smiles, n_split))
@@ -562,8 +617,13 @@ class Smi_ted(nn.Module):
         torch.cuda.empty_cache()
         gc.collect()
 
+        # replacing null SMILES with NaN values
+        for idx in null_idx:
+            flat_list.insert(idx, np.array([np.nan]*self.config['n_embd']))
+        flat_list = np.asarray(flat_list)
+
         if return_torch:
-            return torch.tensor(np.array(flat_list))
+            return torch.tensor(flat_list)
         return pd.DataFrame(flat_list)
 
     def decode(self, smiles_embeddings):
@@ -607,6 +667,7 @@ def load_smi_ted(folder="./smi_ted_light",
                  ):
     tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
     model = Smi_ted(tokenizer)
+
    repo_id = "ibm/materials.smi-ted"
    filename = "smi-ted-Light_40.pt"
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
@@ -614,6 +675,4 @@ def load_smi_ted(folder="./smi_ted_light",
     model.eval()
     print('Vocab size:', len(tokenizer.vocab))
     print(f'[INFERENCE MODE - {str(model)}]')
-    return model
-
-
+    return model
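Two short, hedged sketches of the new behavior follow. First, the normalize_smiles helper in isolation; the logic is copied from the hunk above, and the example SMILES are illustrative:

from rdkit import Chem

def normalize_smiles(smi, canonical=True, isomeric=False):
    # identical logic to the helper added in this commit
    try:
        normalized = Chem.MolToSmiles(
            Chem.MolFromSmiles(smi), canonical=canonical, isomericSmiles=isomeric
        )
    except:
        normalized = None
    return normalized

print(normalize_smiles('C1=CC=CC=C1'))   # 'c1ccccc1' -- kekulized benzene, canonical aromatic form
print(normalize_smiles('OCC'))           # 'CCO' -- same molecule, canonical atom ordering
print(normalize_smiles('not_a_smiles'))  # None -- MolFromSmiles fails, so the except path returns None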
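Second, an end-to-end sketch of the NaN-placeholder behavior. The method name encode, the import path, and the local folder layout are assumptions based on the surrounding code and the upstream smi-ted release; they are not shown in this diff:

# a sketch, assuming this file is importable as `load` and the
# smi_ted_light folder (with its vocab file) is present locally
import numpy as np
from load import load_smi_ted

model = load_smi_ted(folder='./smi_ted_light')

# one valid SMILES and one string RDKit cannot parse
emb = model.encode(['CCO', 'not_a_smiles'])

# encode() drops the unparseable entry before inference, then re-inserts
# a row of n_embd NaN values at its original index, so emb keeps one row
# per input, aligned with the input order
assert emb.shape[0] == 2
assert np.isnan(emb.iloc[1]).all()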