Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| from pathlib import Path | |
| from vq_gan_3d import load_VQGAN | |
| # Numerical computing | |
| import numpy as np | |
| # PyTorch | |
| import torch | |
| import torch.nn.functional as F | |
| # Utilities to calculate grids from SMILES and visualization | |
| from utils import get_grid_from_smiles, plot_voxel_grid, change_grid_size | |
# Directory containing this script. The checkpoint is resolved relative to it
# instead of the current working directory, so the app still starts when it is
# launched from elsewhere. Falls back to the CWD when __file__ is unavailable
# (e.g. when the code is exec'd interactively).
_APP_DIR = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()
ckpt_path = _APP_DIR / "vq_gan_3d" / "weights" / "3DGrid-VQGAN_43.pt"
folder = str(ckpt_path.parent)
ckpt_file = ckpt_path.name
# Load the pretrained 3DGrid-VQGAN once at import time and switch it to
# evaluation mode (presumably a torch.nn.Module, given the .eval() call).
vqgan = load_VQGAN(folder=folder, ckpt_filename=ckpt_file).eval()
def _prepare_grid(rho, smi):
    """Turn a raw density array into a log-normalized 128x128x128 float32 tensor.

    Applies ``log1p`` and, when the grid is not already 128³, resizes it with
    trilinear interpolation so the VQGAN always receives the same input size.
    """
    tensor = torch.log1p(torch.from_numpy(rho).float())  # log(rho + 1) normalization
    if tensor.shape != torch.Size([128, 128, 128]):
        # F.interpolate expects [batch, channel, D, H, W]; add those dims, then drop them.
        tensor = F.interpolate(
            tensor.unsqueeze(0).unsqueeze(0),
            size=(128, 128, 128),
            mode="trilinear",
            align_corners=False,
        )[0, 0]
        print(f"[info] {smi} was interpolated to 128³")
    return tensor


def _model_device():
    """Return the device holding the VQGAN weights, or CPU if it cannot be determined."""
    # NOTE(review): assumes ``vqgan`` is a torch.nn.Module; objects without
    # parameters() (or with no parameters) fall back to CPU.
    try:
        return next(vqgan.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cpu")


def _reconstruct(vol, device):
    """Encode then decode one 3D volume with the VQGAN; return the result as NumPy."""
    x = vol.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, D, H, W] on the model's device
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        indices = vqgan.encode(x)  # volume -> discrete latent codes
        recon = vqgan.decode(indices)  # latent codes -> volume
    return recon.cpu().numpy()[0, 0]


def _plot_to_image(grid, title):
    """Plot ``grid`` (downsampled to 48³) and return the plot as an in-memory PIL image."""
    import io

    fig = plot_voxel_grid(
        change_grid_size(
            torch.from_numpy(grid).unsqueeze(0).unsqueeze(0),
            size=(48, 48, 48),
        ),
        title=title,
    )
    # Render to memory instead of a shared PNG on disk, so concurrent requests
    # cannot overwrite each other's images.
    # NOTE(review): assumes a matplotlib-style ``savefig`` that accepts file
    # objects. The figure is not closed here; if plot_voxel_grid registers it
    # with pyplot, closing it would avoid accumulating figures across requests.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # Image.open is lazy; decode the pixels now
    return image


def comparison(SMILES):
    """Reconstruct a molecule's 3D density grid with the VQGAN and compare it to the input.

    The SMILES string is turned into a density grid by ``get_grid_from_smiles``.
    The grid is log-normalized with ``log1p``, resized to 128³ when needed, and
    encoded and decoded by the module-level ``vqgan``. The reconstruction is
    then compared with the preprocessed input.

    Args:
        SMILES: SMILES string selected or typed in the UI dropdown (may be None
            when nothing was chosen).

    Returns:
        ``([original_image, reconstructed_image], mse, original_npy_path,
        reconstructed_npy_path)``, matching the Gradio outputs:
          - the two images are PIL renderings of the voxel plots;
          - ``mse`` is the mean squared error between the preprocessed input
            grid and its reconstruction, as a float;
          - the paths point to ``.npy`` files in a fresh per-request temporary
            directory.

    Raises:
        gr.Error: if the input is empty or no density grid could be produced.
    """
    import tempfile

    smiles = "" if SMILES is None else str(SMILES).strip()
    if not smiles:
        raise gr.Error("Please provide a SMILES string.")

    device = _model_device()
    reconstructions = []
    for item in get_grid_from_smiles([smiles]):
        smi = item["smiles"]
        vol = _prepare_grid(item["rho"], smi)
        # log shape and min/max to verify normalization and sizing
        print(
            f"{smi}: shape={tuple(vol.shape)}, "
            f"min={vol.min():.4f}, max={vol.max():.4f}"
        )
        recon_np = _reconstruct(vol, device)
        orig_np = vol.cpu().numpy()
        mse = float(np.mean((orig_np - recon_np) ** 2))
        print(f"{smi} → reconstruction done | MSE={mse:.6f}")
        reconstructions.append({
            "smiles": smi,
            "name": item["name"],
            "original": orig_np,
            "reconstructed": recon_np,
            "mse": mse,
        })

    if not reconstructions:
        raise gr.Error(f"Could not build a 3D grid for SMILES {smiles!r}.")
    # Plots, MSE and files all describe the same (first) molecule.
    first = reconstructions[0]

    original_image = _plot_to_image(first["original"], f"Original 3D Grid Plot from {smiles}")
    rec_image = _plot_to_image(first["reconstructed"], f"Reconstructed 3D Grid Plot from {smiles}")

    # A fresh directory per request keeps simultaneous users from overwriting
    # each other's downloads; the basenames stay the same for the user.
    # NOTE(review): these directories are not cleaned up here.
    out_dir = Path(tempfile.mkdtemp(prefix="3dgrid_vqgan_"))
    original_path = out_dir / "original_grid.npy"
    reconstructed_path = out_dir / "reconstructed_grid.npy"
    np.save(original_path, first["original"])
    np.save(reconstructed_path, first["reconstructed"])

    return (
        [original_image, rec_image],
        first["mse"],
        str(original_path),
        str(reconstructed_path),
    )
# Page description shown above the interface. It describes only what this app
# does: single-SMILES reconstruction, an MSE score and two .npy downloads.
DESCRIPTION = """
# 3DGrid-VQGAN SMILES to 3D Grid Reconstruction

Pick one of the example SMILES from the dropdown or type your own. The app builds
a 3D density grid for the molecule, encodes it with 3DGrid-VQGAN, decodes it back,
and shows the original and reconstructed grids side by side (downsampled to 48³
for display).

- **MSE** is the mean squared error between the log-normalized 128³ input grid
  and its reconstruction.
- Both grids can be downloaded as NumPy `.npy` files.

_This is just a demo environment; for heavy-duty usage, please visit:_
https://github.com/IBM/materials
to download the model and run your own experiments.
"""

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.Interface(
        fn=comparison,
        inputs=[
            gr.Dropdown(
                choices=["CCCO", "CC", "CCO"],
                label="Provide a SMILES or pre-select one",
                allow_custom_value=True,
            )
        ],
        outputs=[
            gr.Gallery(label="3D Grid Reconstruction Comparison", columns=2),
            gr.Number(label="MSE"),
            gr.File(label="Original 3D Grid numpy file"),
            gr.File(label="Reconstructed 3D Grid numpy file"),
        ],
    )

if __name__ == "__main__":
    # Listen on all interfaces so the server is reachable from outside its container/host.
    demo.launch(server_name="0.0.0.0")