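"""Gradio demo for 3DGrid-VQGAN: reconstruct a molecule's 3D density grid.

Pipeline: SMILES → density grid (get_grid_from_smiles) → log1p normalization →
resize to 128³ → VQGAN encode/decode → reconstruction MSE plus side-by-side
voxel plots and downloadable .npy grids.
"""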
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Local modules: model loader, plus utilities to compute density grids
# from SMILES and to visualize/resize them
from vq_gan_3d import load_VQGAN
from utils import get_grid_from_smiles, plot_voxel_grid, change_grid_size
# Load the pretrained 3DGrid-VQGAN checkpoint and switch to inference mode
ckpt_path = Path("./vq_gan_3d/weights/3DGrid-VQGAN_43.pt")
vqgan = load_VQGAN(folder=str(ckpt_path.parent), ckpt_filename=ckpt_path.name).eval()
def comparison(SMILES):
    # 1) Prepare density grids → list of ready-to-use tensors
    density_grids = get_grid_from_smiles([SMILES])
    processed_tensors = []
    for item in density_grids:
        rho = item["rho"]  # raw NumPy array from cube generation
        smi = item["smiles"]
        name = item["name"]
        tensor = torch.from_numpy(rho).float()  # convert grid to float32 tensor
        tensor = torch.log1p(tensor)            # apply log(ρ + 1) normalization
        # enforce a consistent 128×128×128 input size for the VQGAN
        if tensor.shape != torch.Size([128, 128, 128]):
            tensor = tensor.unsqueeze(0).unsqueeze(0)  # add batch & channel dims
            tensor = F.interpolate(
                tensor,
                size=(128, 128, 128),
                mode="trilinear",
                align_corners=False,
            )[0, 0]  # drop the extra dims after resizing
            print(f"[info] {smi} was interpolated to 128³")
        # store metadata alongside the processed tensor
        processed_tensors.append({
            "name": name,
            "smiles": smi,
            "tensor": tensor,
        })
        # log shape and min/max to verify normalization and sizing
        print(
            f"{smi}: shape={tuple(tensor.shape)}, "
            f"min={tensor.min():.4f}, max={tensor.max():.4f}"
        )
    # 2) Encode → Decode (inference with the VQGAN)
    reconstructions = []
    for item in processed_tensors:
        smi = item["smiles"]   # original SMILES string
        name = item["name"]    # unique grid name
        vol = item["tensor"]   # preprocessed [128³] tensor
        # add batch & channel dims → shape [1, 1, 128, 128, 128]
        x = vol.unsqueeze(0).unsqueeze(0)
        with torch.no_grad():  # disable gradient tracking for faster inference
            indices = vqgan.encode(x)      # map the input volume to discrete latent codes
            recon = vqgan.decode(indices)  # reconstruct the volume from the latent codes
        # convert the reconstructed and original tensors to NumPy arrays
        recon_np = recon.cpu().numpy()[0, 0]
        orig_np = vol.cpu().numpy()
        # compute the mean squared error between original and reconstruction
        mse = np.mean((orig_np - recon_np) ** 2)
        print(f"{smi} → reconstruction done | MSE={mse:.6f}")
        # collect results for later visualization
        reconstructions.append({
            "smiles": smi,
            "name": name,
            "original": orig_np,
            "reconstructed": recon_np,
        })
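    # NOTE: both grids are in log1p space here; np.expm1 would invert the
    # log(ρ + 1) normalization if raw density values were needed downstream.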
    # 3) Visualize: downsample to 48³ before plotting for faster rendering
    original_grid_plot = plot_voxel_grid(
        change_grid_size(
            torch.from_numpy(reconstructions[0]["original"]).unsqueeze(0).unsqueeze(0),
            size=(48, 48, 48),
        ),
        title=f"Original 3D Grid Plot from {SMILES}",
    )
    rec_grid_plot = plot_voxel_grid(
        change_grid_size(
            torch.from_numpy(reconstructions[0]["reconstructed"]).unsqueeze(0).unsqueeze(0),
            size=(48, 48, 48),
        ),
        title=f"Reconstructed 3D Grid Plot from {SMILES}",
    )
    # persist grids and plots so Gradio can serve them as downloadable files
    np.save("original_grid.npy", reconstructions[0]["original"])
    np.save("reconstructed_grid.npy", reconstructions[0]["reconstructed"])
    original_grid_plot.savefig("original_grid_plot.png", format="png")
    rec_grid_plot.savefig("reconstructed_grid_plot.png", format="png")
    original_grid_plot = Image.open("original_grid_plot.png")
    rec_grid_plot = Image.open("reconstructed_grid_plot.png")
    # `mse` comes from the single loop iteration above (one SMILES per call)
    return [original_grid_plot, rec_grid_plot], mse, "original_grid.npy", "reconstructed_grid.npy"
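# Programmatic use (a minimal sketch; `comparison` returns the gallery images,
# the reconstruction MSE, and the two saved .npy paths, in that order):
#   gallery, mse, orig_path, rec_path = comparison("CCO")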
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 3DGrid-VQGAN SMILES to 3D Grid Reconstruction

Paste a SMILES string in the box, or pick one of the presets. The app computes
the molecule's 3D density grid, reconstructs it with 3DGrid-VQGAN, and shows
both grids side by side along with the reconstruction MSE. The original and
reconstructed grids can be downloaded as `.npy` files.

_This is just a demo environment; for heavy-duty usage, please visit_
https://github.com/IBM/materials/tree/main/models/smi_ted
_to download the model and run your own experiments._
"""
    )
    gr.Interface(
        fn=comparison,
        inputs=[
            gr.Dropdown(
                choices=["CCCO", "CC", "CCO"],
                label="Provide a SMILES or pre-select one",
                allow_custom_value=True,
            )
        ],
        outputs=[
            gr.Gallery(label="3D Grid Reconstruction Comparison", columns=2),
            gr.Number(label="MSE"),
            gr.File(label="Original 3D Grid numpy file"),
            gr.File(label="Reconstructed 3D Grid numpy file"),
        ],
    )
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")