File size: 5,612 Bytes
0d70aa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import io
import tempfile
from pathlib import Path

import gradio as gr
from PIL import Image

# Numerical computing
import numpy as np

# PyTorch
import torch
import torch.nn.functional as F

from vq_gan_3d import load_VQGAN

# Utilities to calculate grids from SMILES and visualization
from utils import get_grid_from_smiles, plot_voxel_grid, change_grid_size

# Resolve the checkpoint next to this file instead of against the current
# working directory, so the app also starts when launched from another
# directory (when run as a script, the local `vq_gan_3d` / `utils` imports
# already resolve relative to this file's directory).
_APP_DIR = Path(__file__).resolve().parent
ckpt_path = _APP_DIR / "vq_gan_3d" / "weights" / "3DGrid-VQGAN_43.pt"
folder = str(ckpt_path.parent)
ckpt_file = ckpt_path.name
# Load the model once at import time and switch it to evaluation mode for inference.
vqgan = load_VQGAN(folder=folder, ckpt_filename=ckpt_file).eval()


# Input size expected by the VQGAN (depth, height, width).
_VQGAN_GRID_SIZE = (128, 128, 128)
# Size the grids are passed to `change_grid_size` with before plotting (display only).
_PLOT_GRID_SIZE = (48, 48, 48)


def _preprocess_density(rho, smiles):
    """Turn a raw density grid into the normalized 128³ tensor fed to the VQGAN.

    Args:
        rho: NumPy array with the raw density values for one molecule
            (assumed 3-D — TODO confirm against ``get_grid_from_smiles``).
        smiles: SMILES string, used only in log messages.

    Returns:
        float32 tensor holding log(rho + 1), resized to 128×128×128 if needed.
    """
    tensor = torch.log1p(torch.from_numpy(rho).float())  # log(ρ + 1) normalization

    if tensor.shape != torch.Size(_VQGAN_GRID_SIZE):
        # Trilinear interpolation needs a 5-D (N, C, D, H, W) input, so add the
        # batch and channel dims and strip them again after resizing.
        tensor = F.interpolate(
            tensor.unsqueeze(0).unsqueeze(0),
            size=_VQGAN_GRID_SIZE,
            mode="trilinear",
            align_corners=False,
        )[0, 0]
        print(f"[info] {smiles} was interpolated to 128³")

    # Log shape and value range to verify normalization and sizing.
    print(
        f"{smiles}: shape={tuple(tensor.shape)}, "
        f"min={tensor.min():.4f}, max={tensor.max():.4f}"
    )
    return tensor


def _reconstruct(volume):
    """Encode and decode one volume with the VQGAN.

    Args:
        volume: preprocessed tensor returned by ``_preprocess_density``.

    Returns:
        ``(original, reconstruction, mse)``: the input and its reconstruction as
        NumPy arrays, and their mean squared error as a Python float.
    """
    # Run on whatever device holds the model weights. Assumes ``vqgan`` is a
    # torch ``nn.Module``; falls back to CPU if it exposes no parameters.
    try:
        device = next(vqgan.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    x = volume.unsqueeze(0).unsqueeze(0).to(device)  # add batch & channel dims
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        indices = vqgan.encode(x)      # volume -> discrete latent codes
        recon = vqgan.decode(indices)  # latent codes -> volume

    # Drop batch and channel dims (assumes decode returns (1, 1, D, H, W)).
    recon_np = recon.cpu().numpy()[0, 0]
    orig_np = volume.cpu().numpy()
    mse = float(np.mean((orig_np - recon_np) ** 2))
    return orig_np, recon_np, mse


def _render_plot(grid, title):
    """Plot a 3-D grid with ``plot_voxel_grid`` and return the figure as a PIL image.

    The grid goes through ``change_grid_size`` with size 48³ first. The PNG is
    rendered into memory rather than through a fixed file in the working directory.
    """
    fig = plot_voxel_grid(
        change_grid_size(
            torch.from_numpy(grid).unsqueeze(0).unsqueeze(0),
            size=_PLOT_GRID_SIZE,
        ),
        title=title,
    )
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    # NOTE(review): if plot_voxel_grid creates its figures through pyplot, they are
    # never closed here and will accumulate in a long-running server — confirm in utils.
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # decode now so the returned image no longer depends on the buffer
    return image


def comparison(SMILES):
    """Reconstruct a molecule's density grid with the VQGAN and compare it to the input.

    Args:
        SMILES: SMILES string selected or typed in the UI (may be ``None`` or empty).

    Returns:
        A 4-tuple ``(images, mse, original_npy, reconstructed_npy)``:
        ``images`` is a list of two PIL images (original, reconstruction);
        ``mse`` is the mean squared error between the normalized input grid and
        its reconstruction; the last two are paths to ``.npy`` files holding
        those two grids.

    Raises:
        gr.Error: if no SMILES is given or no density grid could be built from it.
    """
    if not SMILES or not SMILES.strip():
        raise gr.Error("Please select or enter a SMILES string.")
    smiles = SMILES.strip()

    density_grids = get_grid_from_smiles([smiles])
    # Only one SMILES is submitted, so only the first grid is used. This keeps the
    # plots, the MSE and the saved files all referring to the same molecule.
    first = next(iter(density_grids), None)
    if first is None:
        raise gr.Error(f"Could not build a 3D density grid for SMILES '{smiles}'.")

    smi = first["smiles"]
    volume = _preprocess_density(first["rho"], smi)
    orig_np, recon_np, mse = _reconstruct(volume)
    print(f"{smi} → reconstruction done | MSE={mse:.6f}")

    images = [
        _render_plot(orig_np, f"Original 3D Grid Plot from {smiles}"),
        _render_plot(recon_np, f"Reconstructed 3D Grid Plot from {smiles}"),
    ]

    # Write the downloads into a fresh per-request directory. Fixed names in the
    # working directory could be overwritten by another request before Gradio
    # serves them. The directory is not removed here because Gradio reads the
    # files after this function returns.
    out_dir = Path(tempfile.mkdtemp(prefix="3dgrid_vqgan_"))
    original_path = out_dir / "original_grid.npy"
    reconstructed_path = out_dir / "reconstructed_grid.npy"
    np.save(original_path, orig_np)
    np.save(reconstructed_path, recon_np)

    return images, mse, str(original_path), str(reconstructed_path)


# UI: a single-SMILES form whose outputs match the 4-tuple returned by `comparison`.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 3DGrid-VQGAN: SMILES to 3D Grid Reconstruction
        Pick a SMILES string from the dropdown or type your own. Its 3D density grid is
        computed, encoded and decoded by the 3DGrid-VQGAN, and the original and
        reconstructed grids are plotted side by side together with their mean squared error (MSE).
        - Both grids (128×128×128, log(ρ + 1)-normalized) can be downloaded as NumPy `.npy` files.
        - The plots use a reduced 48×48×48 grid for display only.

        _This is just a demo environment; for heavy-duty usage, please visit:_  
        https://github.com/IBM/materials  
        to download the model and run your own experiments.
        """
    )
    gr.Interface(
        fn=comparison,
        inputs=[
            gr.Dropdown(
                choices=["CCCO", "CC", "CCO"],
                label="Provide a SMILES or pre-select one",
                allow_custom_value=True,
            )
        ],
        outputs=[
            gr.Gallery(label="3D Grid Reconstruction Comparison", columns=2),
            gr.Number(label="MSE"),
            gr.File(label="Original 3D Grid numpy file"),
            gr.File(label="Reconstructed 3D Grid numpy file"),
        ],
    )


if __name__ == "__main__":
    # Bind on all interfaces, presumably so the app is reachable from outside a
    # container (e.g. a Hugging Face Space).
    demo.launch(server_name="0.0.0.0")