import gradio as gr
from PIL import Image
from pathlib import Path

from vq_gan_3d import load_VQGAN

# Numerical computing
import numpy as np

# PyTorch
import torch
import torch.nn.functional as F

# Utilities to calculate grids from SMILES and visualization
from utils import get_grid_from_smiles, plot_voxel_grid, change_grid_size

ckpt_path = Path("./vq_gan_3d/weights/3DGrid-VQGAN_43.pt")
folder = str(ckpt_path.parent)
ckpt_file = ckpt_path.name

# Run on GPU when available; fall back to CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vqgan = load_VQGAN(folder=folder, ckpt_filename=ckpt_file).to(device).eval()


def comparison(SMILES):
    density_grids = get_grid_from_smiles([SMILES])

    # 1) Prepare density grids → list of ready-to-use tensors
    processed_tensors = []
    for item in density_grids:
        rho = item["rho"]  # raw NumPy array from cube generation
        smi = item["smiles"]
        name = item["name"]

        tensor = torch.from_numpy(rho).float()  # convert grid to float32 tensor
        tensor = torch.log1p(tensor)            # apply log(ρ + 1) normalization

        # enforce consistent 128×128×128 input size for the VQGAN
        if tensor.shape != torch.Size([128, 128, 128]):
            tensor = tensor.unsqueeze(0).unsqueeze(0)  # add batch & channel dims
            tensor = F.interpolate(
                tensor, size=(128, 128, 128), mode="trilinear", align_corners=False
            )[0, 0]  # remove extra dims after resizing
            print(f"[info] {smi} was interpolated to 128³")

        # store metadata alongside the processed tensor
        processed_tensors.append({
            "name": name,
            "smiles": smi,
            "tensor": tensor
        })

        # log shape and min/max to verify normalization and sizing
        print(
            f"{smi}: shape={tuple(tensor.shape)}, "
            f"min={tensor.min():.4f}, max={tensor.max():.4f}"
        )

    # 2) Encode → Decode (inference with the VQGAN)
    reconstructions = []
    for item in processed_tensors:
        smi = item["smiles"]  # original SMILES string
        name = item["name"]   # unique grid name
        vol = item["tensor"]  # preprocessed [128³] tensor

        # add batch & channel dims and move to the selected device
        x = vol.unsqueeze(0).unsqueeze(0).to(device)  # shape [1, 1, 128, 128, 128]

        with torch.no_grad():  # disable gradient computation for faster inference
            indices = vqgan.encode(x)      # map input volume to discrete latent codes
            recon = vqgan.decode(indices)  # reconstruct volume from latent codes

        # convert reconstructed tensor and original tensor to NumPy arrays
        recon_np = recon.cpu().numpy()[0, 0]
        orig_np = vol.cpu().numpy()

        # compute mean squared error between original and reconstruction
        mse = np.mean((orig_np - recon_np) ** 2)
        print(f"{smi} → reconstruction done | MSE={mse:.6f}")

        # collect results for later visualization
        reconstructions.append({
            "smiles": smi,
            "name": name,
            "original": orig_np,
            "reconstructed": recon_np
        })

    # 3) Visualize both grids, downsampled to 48³ for faster plotting
    original_grid_plot = plot_voxel_grid(
        change_grid_size(
            torch.from_numpy(reconstructions[0]["original"]).unsqueeze(0).unsqueeze(0),
            size=(48, 48, 48)
        ),
        title=f"Original 3D Grid Plot from {SMILES}"
    )
    rec_grid_plot = plot_voxel_grid(
        change_grid_size(
            torch.from_numpy(reconstructions[0]["reconstructed"]).unsqueeze(0).unsqueeze(0),
            size=(48, 48, 48)
        ),
        title=f"Reconstructed 3D Grid Plot from {SMILES}"
    )

    # save grids and plots so they can be downloaded from the UI
    np.save("original_grid.npy", reconstructions[0]["original"])
    np.save("reconstructed_grid.npy", reconstructions[0]["reconstructed"])
    original_grid_plot.savefig("original_grid_plot.png", format="png")
    rec_grid_plot.savefig("reconstructed_grid_plot.png", format="png")

    original_grid_plot = Image.open("original_grid_plot.png")
    rec_grid_plot = Image.open("reconstructed_grid_plot.png")

    return [original_grid_plot, rec_grid_plot], mse, "original_grid.npy", "reconstructed_grid.npy"
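
# For reference, a minimal sketch of what the `change_grid_size` helper imported
# from `utils` is assumed to do: trilinear resampling of a [1, 1, D, H, W] volume,
# mirroring the F.interpolate call used inside `comparison` above. The name
# `_change_grid_size_sketch` is hypothetical and unused; the real implementation
# in `utils` may differ.
def _change_grid_size_sketch(vol: torch.Tensor, size=(48, 48, 48)) -> torch.Tensor:
    return F.interpolate(vol, size=size, mode="trilinear", align_corners=False)
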
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 3DGrid-VQGAN SMILES to 3D Grid Reconstruction

Paste a SMILES string in the box below, or pick one of the pre-selected examples.
The demo converts the molecule into a 3D density grid, reconstructs it with
3DGrid-VQGAN (encode → decode), and displays the original and reconstructed grids
side by side together with the reconstruction MSE.

- Both grids can be downloaded as NumPy `.npy` files.
- Processing time may vary due to Hugging Face environment limits.

_This is just a demo environment; for heavy-duty usage, please visit:_
https://github.com/IBM/materials/tree/main/models/smi_ted
_to download the model and run your own experiments._
"""
    )
    gr.Interface(
        fn=comparison,
        inputs=[
            gr.Dropdown(
                choices=["CCCO", "CC", "CCO"],
                label="Provide a SMILES or pre-select one",
                allow_custom_value=True
            )
        ],
        outputs=[
            gr.Gallery(label="3D Grid Reconstruction Comparison", columns=2),
            gr.Number(label="MSE"),
            gr.File(label="Original 3D Grid numpy file"),
            gr.File(label="Reconstructed 3D Grid numpy file")
        ]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")
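
# Quick offline check (sketch): once the demo has been run at least once, the
# saved grids can be inspected without the UI. The file names match those written
# by `comparison` above, and the recomputed MSE should match the value shown in
# the app. Kept commented out so it never runs alongside the server.
#
#   import numpy as np
#   orig = np.load("original_grid.npy")        # log1p-scaled density, shape (128, 128, 128)
#   recon = np.load("reconstructed_grid.npy")
#   print("MSE:", np.mean((orig - recon) ** 2))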