victor-shirasuna commited on
Commit
0011da8
·
1 Parent(s): bd5808a

Fix missing utils.py

Browse files
Files changed (1) hide show
  1. utils.py +151 -0
utils.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard library
2
+ from heapq import nsmallest
3
+
4
+ # Numerical computing & plotting
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
+ # PySCF core (quantum chemistry)
9
+ from pyscf import lib
10
+ from pyscf import gto, scf, dft
11
+ from pyscf.dft import numint, gen_grid
12
+ from pyscf.tools import cubegen
13
+ from pyscf.geomopt.berny_solver import optimize
14
+ from pyscf.semiempirical import mindo3
15
+
16
+ # PyTorch
17
+ import torch.nn.functional as F
18
+
19
+ # RDKit (molecule parsing & conformer generation)
20
+ from rdkit import Chem
21
+ from rdkit.Chem import rdDistGeom, AllChem
22
+
23
+
24
+ def change_grid_size(tensor, size):
25
+ new_grid = F.interpolate(
26
+ tensor,
27
+ size=size,
28
+ mode="trilinear",
29
+ align_corners=False
30
+ )
31
+ return new_grid
32
+
33
+
34
+ def get_grid_from_smiles(data_smi_l):
35
+ density_grids = []
36
+ for smi_it in data_smi_l:
37
+ fin_tmp_l = []
38
+
39
+ mol_it = Chem.MolFromSmiles(smi_it, sanitize=True)
40
+
41
+ # normalize to canonical SMILES for bookkeeping
42
+ if mol_it != None:
43
+ try:
44
+ can_smi_it = Chem.MolToSmiles(mol_it, kekuleSmiles=True)
45
+ except:
46
+ can_smi_it = Chem.MolToSmiles(mol_it, kekuleSmiles=False)
47
+
48
+ print('\n molecule ', smi_it)
49
+
50
+ # Embedding with Molecular Force Field
51
+ # embed 50 conformations, optimize with rdkit: N_MMFF = 50
52
+ # select 1 most stable MMFF conformations, optimize with pyscf N_PYSCF = 1
53
+
54
+ N_MMFF = 50
55
+ N_PYSCF = 1
56
+
57
+ confmol = Chem.AddHs(Chem.Mol(mol_it))
58
+ param = rdDistGeom.ETKDGv2()
59
+ param.pruneRmsThresh = 0.1
60
+ cids = rdDistGeom.EmbedMultipleConfs(confmol, N_MMFF, param)
61
+
62
+ if len(cids) == 0:
63
+ continue
64
+
65
+ try:
66
+ res = AllChem.MMFFOptimizeMoleculeConfs(confmol)
67
+ energies = {c: res[c][1] for c in range(len(res))}
68
+ opt_mols = {}
69
+ top_energies = {}
70
+ top_cids = nsmallest(N_PYSCF, energies, key=energies.get)
71
+ except Exception as error:
72
+ print('Something went wrong, MMFFOptimize')
73
+
74
+ # PySCF optimization and cube generation
75
+ for cid in top_cids:
76
+ print('\n ----> Conformer ', str(cid), '\n')
77
+ molstr = Chem.MolToXYZBlock(confmol, confId=cid)
78
+ mol = gto.M(atom='; '.join(molstr.split('\n')[2:]))
79
+ mf = scf.RHF(mol)
80
+
81
+ mol_eq = optimize(mf, maxsteps=200)
82
+ opt_mols[cid] = mol_eq
83
+ mol_eq_f = scf.RHF(mol_eq).run()
84
+ top_energies[cid] = mol_eq_f.e_tot
85
+
86
+ box0 = max(mol_eq_f.mol.atom_coords()[:, 0]) - min(mol_eq_f.mol.atom_coords()[:, 0])
87
+ box1 = max(mol_eq_f.mol.atom_coords()[:, 1]) - min(mol_eq_f.mol.atom_coords()[:, 1])
88
+ box2 = max(mol_eq_f.mol.atom_coords()[:, 2]) - min(mol_eq_f.mol.atom_coords()[:, 2])
89
+ n0 = 6 * (int(box0) + 2)
90
+ n1 = 6 * (int(box1) + 2)
91
+ n2 = 6 * (int(box2) + 2)
92
+ el_cube = cubegen.density(
93
+ mol_eq_f.mol,
94
+ f"SMILES_{data_smi_l.index(smi_it)}_{cid}.cube",
95
+ mol_eq_f.make_rdm1(),
96
+ nx=n0,
97
+ ny=n1,
98
+ nz=n2,
99
+ )
100
+ rho = el_cube
101
+ density_grids.append(
102
+ {
103
+ "smiles": smi_it,
104
+ "name": f"SMILES_{data_smi_l.index(smi_it)}_{cid}",
105
+ "rho": rho
106
+ }
107
+ )
108
+
109
+ return density_grids
110
+
111
+
112
+ def plot_voxel_grid(tensor, thresholds=[0.5, 0.25, 0.125, 0.0125], title='Voxel Grid Plot'):
113
+ """
114
+ Plots a 3D voxel grid from a tensor and shows it inline.
115
+
116
+ Args:
117
+ tensor (torch.Tensor): input shape [1,1,D,H,W]
118
+ thresholds (list of float): visibility cutoffs
119
+ title (str): plot title
120
+ """
121
+ # Convert to NumPy and squeeze out batch/channel dims
122
+ data_np = tensor.detach().squeeze().cpu().numpy()
123
+
124
+ # Build normalized grid coordinates
125
+ x, y, z = np.indices(np.array(data_np.shape) + 1) / (max(data_np.shape) + 1)
126
+
127
+ # Predefine colors & alpha
128
+ alpha = 0.3
129
+ colors = [
130
+ [1.00, 0.00, 0.00, alpha],
131
+ [0.75, 0.00, 0.25, alpha],
132
+ [0.50, 0.00, 0.50, alpha],
133
+ [0.25, 0.00, 0.75, alpha],
134
+ [0.00, 0.00, 1.00, alpha],
135
+ ]
136
+
137
+ fig = plt.figure(figsize=(6,6))
138
+ ax = fig.add_subplot(111, projection='3d')
139
+ ax.set_box_aspect([1,1,1])
140
+ ax.tick_params(left=False, right=False, labelleft=False,
141
+ labelbottom=False, bottom=False)
142
+ ax.grid(False)
143
+ if title:
144
+ ax.set_title(title)
145
+ # Plot one layer per threshold
146
+ for i, thr in enumerate(thresholds):
147
+ mask = np.clip(data_np - thr, 0, 1)
148
+ ax.voxels(x, y, z, mask, facecolors=colors[i % len(colors)], linewidth=0.5, alpha=alpha)
149
+
150
+ plt.tight_layout()
151
+ return fig