Spaces:
Running
Running
victor-shirasuna
commited on
Commit
·
0011da8
1
Parent(s):
bd5808a
Fix missing utils.py
Browse files
utils.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Standard library
|
| 2 |
+
from heapq import nsmallest
|
| 3 |
+
|
| 4 |
+
# Numerical computing & plotting
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
|
| 8 |
+
# PySCF core (quantum chemistry)
|
| 9 |
+
from pyscf import lib
|
| 10 |
+
from pyscf import gto, scf, dft
|
| 11 |
+
from pyscf.dft import numint, gen_grid
|
| 12 |
+
from pyscf.tools import cubegen
|
| 13 |
+
from pyscf.geomopt.berny_solver import optimize
|
| 14 |
+
from pyscf.semiempirical import mindo3
|
| 15 |
+
|
| 16 |
+
# PyTorch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
# RDKit (molecule parsing & conformer generation)
|
| 20 |
+
from rdkit import Chem
|
| 21 |
+
from rdkit.Chem import rdDistGeom, AllChem
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def change_grid_size(tensor, size):
|
| 25 |
+
new_grid = F.interpolate(
|
| 26 |
+
tensor,
|
| 27 |
+
size=size,
|
| 28 |
+
mode="trilinear",
|
| 29 |
+
align_corners=False
|
| 30 |
+
)
|
| 31 |
+
return new_grid
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_grid_from_smiles(data_smi_l):
|
| 35 |
+
density_grids = []
|
| 36 |
+
for smi_it in data_smi_l:
|
| 37 |
+
fin_tmp_l = []
|
| 38 |
+
|
| 39 |
+
mol_it = Chem.MolFromSmiles(smi_it, sanitize=True)
|
| 40 |
+
|
| 41 |
+
# normalize to canonical SMILES for bookkeeping
|
| 42 |
+
if mol_it != None:
|
| 43 |
+
try:
|
| 44 |
+
can_smi_it = Chem.MolToSmiles(mol_it, kekuleSmiles=True)
|
| 45 |
+
except:
|
| 46 |
+
can_smi_it = Chem.MolToSmiles(mol_it, kekuleSmiles=False)
|
| 47 |
+
|
| 48 |
+
print('\n molecule ', smi_it)
|
| 49 |
+
|
| 50 |
+
# Embedding with Molecular Force Field
|
| 51 |
+
# embed 50 conformations, optimize with rdkit: N_MMFF = 50
|
| 52 |
+
# select 1 most stable MMFF conformations, optimize with pyscf N_PYSCF = 1
|
| 53 |
+
|
| 54 |
+
N_MMFF = 50
|
| 55 |
+
N_PYSCF = 1
|
| 56 |
+
|
| 57 |
+
confmol = Chem.AddHs(Chem.Mol(mol_it))
|
| 58 |
+
param = rdDistGeom.ETKDGv2()
|
| 59 |
+
param.pruneRmsThresh = 0.1
|
| 60 |
+
cids = rdDistGeom.EmbedMultipleConfs(confmol, N_MMFF, param)
|
| 61 |
+
|
| 62 |
+
if len(cids) == 0:
|
| 63 |
+
continue
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
res = AllChem.MMFFOptimizeMoleculeConfs(confmol)
|
| 67 |
+
energies = {c: res[c][1] for c in range(len(res))}
|
| 68 |
+
opt_mols = {}
|
| 69 |
+
top_energies = {}
|
| 70 |
+
top_cids = nsmallest(N_PYSCF, energies, key=energies.get)
|
| 71 |
+
except Exception as error:
|
| 72 |
+
print('Something went wrong, MMFFOptimize')
|
| 73 |
+
|
| 74 |
+
# PySCF optimization and cube generation
|
| 75 |
+
for cid in top_cids:
|
| 76 |
+
print('\n ----> Conformer ', str(cid), '\n')
|
| 77 |
+
molstr = Chem.MolToXYZBlock(confmol, confId=cid)
|
| 78 |
+
mol = gto.M(atom='; '.join(molstr.split('\n')[2:]))
|
| 79 |
+
mf = scf.RHF(mol)
|
| 80 |
+
|
| 81 |
+
mol_eq = optimize(mf, maxsteps=200)
|
| 82 |
+
opt_mols[cid] = mol_eq
|
| 83 |
+
mol_eq_f = scf.RHF(mol_eq).run()
|
| 84 |
+
top_energies[cid] = mol_eq_f.e_tot
|
| 85 |
+
|
| 86 |
+
box0 = max(mol_eq_f.mol.atom_coords()[:, 0]) - min(mol_eq_f.mol.atom_coords()[:, 0])
|
| 87 |
+
box1 = max(mol_eq_f.mol.atom_coords()[:, 1]) - min(mol_eq_f.mol.atom_coords()[:, 1])
|
| 88 |
+
box2 = max(mol_eq_f.mol.atom_coords()[:, 2]) - min(mol_eq_f.mol.atom_coords()[:, 2])
|
| 89 |
+
n0 = 6 * (int(box0) + 2)
|
| 90 |
+
n1 = 6 * (int(box1) + 2)
|
| 91 |
+
n2 = 6 * (int(box2) + 2)
|
| 92 |
+
el_cube = cubegen.density(
|
| 93 |
+
mol_eq_f.mol,
|
| 94 |
+
f"SMILES_{data_smi_l.index(smi_it)}_{cid}.cube",
|
| 95 |
+
mol_eq_f.make_rdm1(),
|
| 96 |
+
nx=n0,
|
| 97 |
+
ny=n1,
|
| 98 |
+
nz=n2,
|
| 99 |
+
)
|
| 100 |
+
rho = el_cube
|
| 101 |
+
density_grids.append(
|
| 102 |
+
{
|
| 103 |
+
"smiles": smi_it,
|
| 104 |
+
"name": f"SMILES_{data_smi_l.index(smi_it)}_{cid}",
|
| 105 |
+
"rho": rho
|
| 106 |
+
}
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return density_grids
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def plot_voxel_grid(tensor, thresholds=[0.5, 0.25, 0.125, 0.0125], title='Voxel Grid Plot'):
|
| 113 |
+
"""
|
| 114 |
+
Plots a 3D voxel grid from a tensor and shows it inline.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
tensor (torch.Tensor): input shape [1,1,D,H,W]
|
| 118 |
+
thresholds (list of float): visibility cutoffs
|
| 119 |
+
title (str): plot title
|
| 120 |
+
"""
|
| 121 |
+
# Convert to NumPy and squeeze out batch/channel dims
|
| 122 |
+
data_np = tensor.detach().squeeze().cpu().numpy()
|
| 123 |
+
|
| 124 |
+
# Build normalized grid coordinates
|
| 125 |
+
x, y, z = np.indices(np.array(data_np.shape) + 1) / (max(data_np.shape) + 1)
|
| 126 |
+
|
| 127 |
+
# Predefine colors & alpha
|
| 128 |
+
alpha = 0.3
|
| 129 |
+
colors = [
|
| 130 |
+
[1.00, 0.00, 0.00, alpha],
|
| 131 |
+
[0.75, 0.00, 0.25, alpha],
|
| 132 |
+
[0.50, 0.00, 0.50, alpha],
|
| 133 |
+
[0.25, 0.00, 0.75, alpha],
|
| 134 |
+
[0.00, 0.00, 1.00, alpha],
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
fig = plt.figure(figsize=(6,6))
|
| 138 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 139 |
+
ax.set_box_aspect([1,1,1])
|
| 140 |
+
ax.tick_params(left=False, right=False, labelleft=False,
|
| 141 |
+
labelbottom=False, bottom=False)
|
| 142 |
+
ax.grid(False)
|
| 143 |
+
if title:
|
| 144 |
+
ax.set_title(title)
|
| 145 |
+
# Plot one layer per threshold
|
| 146 |
+
for i, thr in enumerate(thresholds):
|
| 147 |
+
mask = np.clip(data_np - thr, 0, 1)
|
| 148 |
+
ax.voxels(x, y, z, mask, facecolors=colors[i % len(colors)], linewidth=0.5, alpha=alpha)
|
| 149 |
+
|
| 150 |
+
plt.tight_layout()
|
| 151 |
+
return fig
|