vshirasuna committed
Commit 0d70aa0 · 1 Parent(s): 775e041

Add application files

Dockerfile ADDED
@@ -0,0 +1,10 @@
1
+ FROM python:3.9.7
2
+
3
+ WORKDIR /app
4
+ COPY requirements.txt .
5
+ RUN pip install --no-cache-dir -r requirements.txt
6
+ # preload models
7
+
8
+ COPY . .
9
+
10
+ CMD ["python", "app.py"]
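+ # NOTE (assumption): the Gradio app serves on its default port 7860; no EXPOSE is declared here,
+ # so publish that port explicitly when running the image outside Hugging Face Spaces.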
app.py ADDED
@@ -0,0 +1,143 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from pathlib import Path
4
+ from vq_gan_3d import load_VQGAN
5
+
6
+ # Numerical computing
7
+ import numpy as np
8
+
9
+ # PyTorch
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ # Utilities to calculate grids from SMILES and visualization
14
+ from utils import get_grid_from_smiles, plot_voxel_grid, change_grid_size
15
+
16
+ ckpt_path = Path("./vq_gan_3d/weights/3DGrid-VQGAN_43.pt")
17
+ folder = str(ckpt_path.parent)
18
+ ckpt_file = ckpt_path.name
19
+ vqgan = load_VQGAN(folder=folder, ckpt_filename=ckpt_file).eval()
20
+
21
+
22
+ def comparison(SMILES):
23
+ density_grids = get_grid_from_smiles([SMILES])
24
+
25
+ # 1) Prepare density grids → list of ready-to-use tensors
26
+ processed_tensors = []
27
+ for item in density_grids:
28
+ rho = item["rho"] # raw NumPy array from cube generation
29
+ smi = item["smiles"]
30
+ name = item["name"]
31
+
32
+ tensor = torch.from_numpy(rho).float() # convert grid to float32 tensor
33
+ tensor = torch.log1p(tensor) # apply log(ρ + 1) normalization
34
+
35
+ # enforce consistent 128×128×128 input size for VQGAN
36
+ if tensor.shape != torch.Size([128, 128, 128]):
37
+ tensor = tensor.unsqueeze(0).unsqueeze(0) # add batch & channel dims
38
+ tensor = F.interpolate(
39
+ tensor,
40
+ size=(128, 128, 128),
41
+ mode="trilinear",
42
+ align_corners=False
43
+ )[0, 0] # remove extra dims after resizing
44
+ print(f"[info] {smi} was interpolated to 128³")
45
+
46
+ # store metadata alongside the processed tensor
47
+ processed_tensors.append({
48
+ "name": name,
49
+ "smiles": smi,
50
+ "tensor": tensor
51
+ })
52
+
53
+ # log shape and min/max to verify normalization and sizing
54
+ print(
55
+ f"{smi}: shape={tuple(tensor.shape)}, "
56
+ f"min={tensor.min():.4f}, max={tensor.max():.4f}"
57
+ )
58
+
59
+ # 2) Encode → Decode (inference with VQGAN)
60
+ reconstructions = []
61
+ for item in processed_tensors:
62
+ smi = item["smiles"] # original SMILES string
63
+ name = item["name"] # unique grid name
64
+ vol = item["tensor"] # preprocessed [128³] tensor
65
+
66
+ # add batch & channel dims and move to the selected device
67
+ x = vol.unsqueeze(0).unsqueeze(0) # shape [1,1,128,128,128]
68
+
69
+ with torch.no_grad(): # disable gradient computation for faster inference
70
+ indices = vqgan.encode(x) # map input volume to discrete latent codes
71
+ recon = vqgan.decode(indices) # reconstruct volume from latent codes
72
+
73
+ # convert reconstructed tensor and original tensor to NumPy arrays
74
+ recon_np = recon.cpu().numpy()[0, 0]
75
+ orig_np = vol.cpu().numpy()
76
+
77
+ # compute mean squared error between original and reconstruction
78
+ mse = np.mean((orig_np - recon_np) ** 2)
79
+ print(f"{smi} → reconstruction done | MSE={mse:.6f}")
80
+
81
+ # collect results for later visualization
82
+ reconstructions.append({
83
+ "smiles": smi,
84
+ "name": name,
85
+ "original": orig_np,
86
+ "reconstructed": recon_np
87
+ })
88
+
89
+ original_grid_plot = plot_voxel_grid(
90
+ change_grid_size(
91
+ torch.from_numpy(reconstructions[0]["original"]).unsqueeze(0).unsqueeze(0),
92
+ size=(48, 48, 48)
93
+ ),
94
+ title=f"Original 3D Grid Plot from {SMILES}"
95
+ )
96
+ rec_grid_plot = plot_voxel_grid(
97
+ change_grid_size(
98
+ torch.from_numpy(reconstructions[0]["reconstructed"]).unsqueeze(0).unsqueeze(0),
99
+ size=(48, 48, 48)
100
+ ),
101
+ title=f"Reconstructed 3D Grid Plot from {SMILES}"
102
+ )
103
+
104
+ np.save("original_grid.npy", reconstructions[0]["original"])
105
+ np.save("reconstructed_grid.npy", reconstructions[0]["reconstructed"])
106
+
107
+ original_grid_plot.savefig("original_grid_plot.png", format='png')
108
+ rec_grid_plot.savefig("reconstructed_grid_plot.png", format='png')
109
+ original_grid_plot = Image.open("original_grid_plot.png")
110
+ rec_grid_plot = Image.open("reconstructed_grid_plot.png")
111
+
112
+ return [original_grid_plot, rec_grid_plot], mse, "original_grid.npy", "reconstructed_grid.npy"
113
+
114
+
115
+ with gr.Blocks() as demo:
116
+ gr.Markdown(
117
+ """
118
+ # 3DGrid-VQGAN SMILES to 3D Grid Reconstruction
119
+ Paste a SMILES string in the box below, or pick one of the pre-selected examples.
+ The app converts the SMILES into a 3D density grid, encodes and decodes it with 3DGrid-VQGAN,
+ and shows the original and reconstructed grids side by side together with their MSE.
+ Both grids can also be downloaded as NumPy `.npy` files.
+ _This is just a demo environment; for heavy-duty usage, please visit:_
+ https://github.com/IBM/materials/tree/main/models/smi_ted
+ to download the model and run your own experiments.
+ """
127
+ )
128
+ gr.Interface(
129
+ fn=comparison,
130
+ inputs=[
131
+ gr.Dropdown(choices=["CCCO", "CC", "CCO"], label="Provide a SMILES or pre-select one", allow_custom_value=True)
132
+ ],
133
+ outputs=[
134
+ gr.Gallery(label="3D Grid Reconstruction Comparison", columns=2),
135
+ gr.Number(label="MSE"),
136
+ gr.File(label="Original 3D Grid numpy file"),
137
+ gr.File(label="Reconstructed 3D Grid numpy file")
138
+ ]
139
+ )
140
+
141
+
142
+ if __name__ == "__main__":
143
+ demo.launch(server_name="0.0.0.0")
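For reference, a minimal sketch of exercising the same round trip without the Gradio UI. It assumes the repository layout added in this commit (the root-level `utils` module that app.py imports, and the LFS checkpoint under `vq_gan_3d/weights/`); it is a sketch, not part of the committed app:

import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path

from vq_gan_3d import load_VQGAN
from utils import get_grid_from_smiles  # same helper app.py relies on

ckpt = Path("./vq_gan_3d/weights/3DGrid-VQGAN_43.pt")
vqgan = load_VQGAN(folder=str(ckpt.parent), ckpt_filename=ckpt.name).eval()

grid = get_grid_from_smiles(["CCO"])[0]["rho"]            # raw density grid
vol = torch.log1p(torch.from_numpy(grid).float())         # log(rho + 1) normalization, as in comparison()
if vol.shape != torch.Size([128, 128, 128]):
    vol = F.interpolate(vol[None, None], size=(128, 128, 128),
                        mode="trilinear", align_corners=False)[0, 0]

with torch.no_grad():
    codes = vqgan.encode(vol[None, None])                 # discrete latent code indices
    recon = vqgan.decode(codes)[0, 0]                     # reconstructed 128^3 volume

print("MSE:", np.mean((vol.numpy() - recon.numpy()) ** 2))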
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ torch>=2.1.0
2
+ numpy==1.26.4
3
+ pandas==1.4.0
4
+ gradio>=4.33.1
5
+ huggingface-hub
6
+ pydantic==2.10.6
7
+ rdkit>=2024.3.5
8
+ imageio==2.34.1
9
+ hydra-core==1.3.2
10
+ omegaconf==2.3.0
11
+ pdbpp==0.10.2
12
+ torchvision==0.18.0
13
+ tqdm==4.66.4
14
+ requests>=2.32.0
15
+ lpips==0.1.4
16
+ pyscf>=2.10.0
17
+ pyberny>=0.6.3
18
+ git+https://github.com/pyscf/semiempirical
vq_gan_3d/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from vq_gan_3d.model import VQGAN, load_VQGAN
2
+ from vq_gan_3d.model import Codebook
3
+ from vq_gan_3d.model import LPIPS
4
+ from vq_gan_3d.dataset import VQGANDataset
5
+ from vq_gan_3d.utils import get_single_device
vq_gan_3d/dataset/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # from dataset.breast_uka import BreastUKA
2
+ # from dataset.mrnet import MRNetDataset
3
+ # from dataset.brats import BRATSDataset
4
+ # from dataset.adni import ADNIDataset
5
+ # from dataset.duke import DUKEDataset
6
+ # from dataset.lidc import LIDCDataset
7
+ from vq_gan_3d.dataset.default import VQGANDataset
vq_gan_3d/dataset/default.py ADDED
@@ -0,0 +1,34 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ import torch.nn.functional as F
6
+ import torch.multiprocessing as mp
7
+
8
+
9
+ class VQGANDataset(Dataset):
10
+ def __init__(self, root_dir: str, file_paths: str, internal_resolution: int):
11
+ super().__init__()
12
+ self.root_dir = root_dir
13
+ self.file_paths = file_paths
14
+ self.internal_resolution = internal_resolution
15
+
16
+ def __len__(self):
17
+ return len(self.file_paths)
18
+
19
+ def __getitem__(self, idx: int):
20
+ filename = os.path.join(self.root_dir, self.file_paths[idx])
21
+ try:
22
+ numpy_file = np.load(filename)
23
+ torch_np = torch.from_numpy(numpy_file)
24
+ torch_np = torch_np.unsqueeze(0).unsqueeze(0).float()  # add batch & channel dims and convert to float32
25
+ interpolated_data = F.interpolate(input=torch_np, size=(self.internal_resolution, self.internal_resolution, self.internal_resolution), mode='trilinear')
26
+
27
+ # Apply log(x + 1) normalization (adding 1 avoids log(0)); the tanh value is computed but not used
+ interpolated_data_tanh = torch.tanh(interpolated_data)
+ interpolated_data_log = torch.log(interpolated_data + 1).squeeze(0)
30
+
31
+ return interpolated_data_log
32
+ except Exception as e:
33
+ print(f"Error loading file '{filename}': {e}")
34
+ return None
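A minimal sketch of how VQGANDataset might be consumed. Because __getitem__ returns None when a grid fails to load, a filtering collate_fn is needed; the directory and file names below are hypothetical placeholders:

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

from vq_gan_3d.dataset import VQGANDataset

def skip_none_collate(batch):
    # drop samples for which __getitem__ returned None
    batch = [b for b in batch if b is not None]
    return default_collate(batch) if batch else None

ds = VQGANDataset(root_dir="grids/", file_paths=["mol_0.npy", "mol_1.npy"],
                  internal_resolution=128)
loader = DataLoader(ds, batch_size=2, collate_fn=skip_none_collate)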
vq_gan_3d/metrics.py ADDED
@@ -0,0 +1,159 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import sys
6
+ import math
7
+ import lpips as lpips_metric
8
+ import piq
9
+ import numpy as np
10
+ from numpy import cov
11
+ from numpy import trace
12
+ from numpy import iscomplexobj
13
+ from numpy.random import random
14
+ from scipy.linalg import sqrtm
15
+ from skimage.metrics import structural_similarity as ssim
16
+ from skimage.metrics import peak_signal_noise_ratio
17
+ from tqdm import tqdm
18
+
19
+ import warnings
20
+ warnings.filterwarnings("ignore")
21
+
22
+
23
+ def blockPrint():
24
+ sys.stdout = open(os.devnull, 'w')
25
+
26
+
27
+ def enablePrint():
28
+ sys.stdout = sys.__stdout__
29
+
30
+
31
+ def normalize_tensor(outmap):
32
+ flattened_outmap = outmap.view(outmap.shape[0], -1, 1, 1) # Use 1's to preserve the number of dimensions for broadcasting later, as explained
33
+ outmap_min, _ = torch.min(flattened_outmap, dim=1, keepdim=True)
34
+ outmap_max, _ = torch.max(flattened_outmap, dim=1, keepdim=True)
35
+ outmap = (outmap - outmap_min) / (outmap_max - outmap_min)
36
+ return outmap
37
+
38
+
39
+ class ImageMetrics:
40
+
41
+ def __init__(self, grid_true, grid_pred, device='cpu'):
42
+ self.grid_true = grid_true.to(device) # [N, H, W]
43
+ self.grid_pred = grid_pred.to(device) # [N, H, W]
44
+ self.num_sequence = grid_true.shape[0]
45
+ self.loss_fn_vgg = lpips_metric.LPIPS(net='vgg', verbose=False).to(device) # closer to "traditional" perceptual loss, when used for optimization
46
+ self.device = device
47
+
48
+ def ssim(self):
49
+ """Structured Similarity Index Metric"""
50
+ a = normalize_tensor(self.grid_true.unsqueeze(0))
51
+ b = normalize_tensor(self.grid_pred.unsqueeze(0))
52
+ ssim = piq.ssim(a, b, data_range=1., reduction='none').squeeze().item()
53
+ return ssim
54
+
55
+ def mssim(self):
56
+ """Mean Structured Similarity Index Metric"""
57
+ mssim = 0
58
+ for idx in range(self.num_sequence):
59
+ max_value = max([self.grid_true[idx].max(), self.grid_pred[idx].max()])
60
+ min_value = min([self.grid_true[idx].min(), self.grid_pred[idx].min()])
61
+ data_range = abs(max_value - min_value)
62
+
63
+ a = self.grid_true[idx].detach().cpu().numpy()
64
+ b = self.grid_pred[idx].detach().cpu().numpy()
65
+
66
+ mssim += ssim(a, b, data_range=data_range.item())
67
+ return mssim / self.num_sequence
68
+
69
+ def multiscale_ssim(self):
70
+ """Multi-Scale SSIM"""
71
+ a = normalize_tensor(self.grid_true.unsqueeze(0))
72
+ b = normalize_tensor(self.grid_pred.unsqueeze(0))
73
+ ms_ssim_index = piq.multi_scale_ssim(a, b, data_range=1., kernel_size=7).item()
74
+ return ms_ssim_index
75
+
76
+ def psnr(self):
77
+ """Peak Signal-to-Noise Ratio"""
78
+ psnr = 0
79
+ for idx in range(self.num_sequence):
80
+ max_value = max([self.grid_true[idx].max(), self.grid_pred[idx].max()])
81
+ min_value = min([self.grid_true[idx].min(), self.grid_pred[idx].min()])
82
+ data_range = abs(max_value - min_value)
83
+
84
+ a = self.grid_true[idx].detach().cpu().numpy()
85
+ b = self.grid_pred[idx].detach().cpu().numpy()
86
+
87
+ psnr += peak_signal_noise_ratio(a, b, data_range=data_range.item())
88
+ return psnr / self.num_sequence
89
+
90
+ def _calculate_fid(self, img1, img2):
91
+ img1 = img1.detach().cpu().numpy()
92
+ img2 = img2.detach().cpu().numpy()
93
+
94
+ # calculate mean and covariance statistics
95
+ mu1, sigma1 = img1.mean(axis=0), cov(img1, rowvar=False)
96
+ mu2, sigma2 = img2.mean(axis=0), cov(img2, rowvar=False)
97
+ # calculate sum squared difference between means
98
+ ssdiff = np.sum((mu1 - mu2)**2.0)
99
+ # calculate sqrt of product between cov
100
+ covmean = sqrtm(sigma1.dot(sigma2))
101
+ # check and correct imaginary numbers from sqrt
102
+ if iscomplexobj(covmean):
103
+ covmean = covmean.real
104
+ # calculate score
105
+ fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
106
+ return fid
107
+
108
+ def fid(self):
109
+ """Frechet Inception Distance"""
110
+ fid = 0
111
+ for idx in range(self.num_sequence):
112
+ fid += self._calculate_fid(self.grid_true[idx], self.grid_pred[idx])
113
+ return fid / self.num_sequence
114
+
115
+ def _calculate_lpips(self, img1, img2):
116
+ img1 = (2 * img1 / img1.max() - 1) # normalize between -1 to 1
117
+ img2 = (2 * img2 / img2.max() - 1) # normalize between -1 to 1
118
+ perceptual_loss = self.loss_fn_vgg(img1, img2).squeeze()
119
+ return perceptual_loss.item()
120
+
121
+ def lpips(self):
122
+ """Learned Perceptual Image Patch Similarity"""
123
+ perceptual_loss = 0
124
+ for idx in range(self.num_sequence):
125
+ perceptual_loss += self._calculate_lpips(self.grid_true[idx], self.grid_pred[idx])
126
+ return perceptual_loss / self.num_sequence
127
+
128
+ def reconstruction(self):
129
+ return F.l1_loss(self.grid_true, self.grid_pred).item()
130
+
131
+ def IS(self):
132
+ """Inception Score"""
133
+ is_score = 0
134
+ for idx in range(self.num_sequence):
135
+ is_score += piq.IS(distance='l1')(self.grid_true[idx], self.grid_pred[idx])
136
+ return (is_score / self.num_sequence).item()
137
+
138
+ def kid(self):
139
+ """Kernel Inception Distance"""
140
+ kid_score = 0
141
+ for idx in range(self.num_sequence):
142
+ kid_score += piq.KID()(self.grid_true[idx], self.grid_pred[idx])
143
+ return (kid_score / self.num_sequence).item()
144
+
145
+ def get_metrics(self):
146
+ blockPrint()
147
+ metrics = dict(
148
+ SSIM=self.ssim(),
149
+ MSSIM=self.mssim(),
150
+ MS_SSIM=self.multiscale_ssim(),
151
+ PSNR=self.psnr(),
152
+ IS=self.IS(),
153
+ FID=self.fid(),
154
+ KID=self.kid(),
155
+ LPIPS=self.lpips(),
156
+ Reconstruction=self.reconstruction(),
157
+ )
158
+ enablePrint()
159
+ return metrics
vq_gan_3d/model/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from vq_gan_3d.model.vqgan import VQGAN, load_VQGAN
2
+ from vq_gan_3d.model.codebook import Codebook
3
+ from vq_gan_3d.model.lpips import LPIPS
vq_gan_3d/model/cache/vgg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
vq_gan_3d/model/codebook.py ADDED
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.distributed as dist
7
+
8
+ from vq_gan_3d.utils import shift_dim
9
+
10
+
11
+ class Codebook(nn.Module):
12
+ def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0):
13
+ super().__init__()
14
+ self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
15
+ self.register_buffer('N', torch.zeros(n_codes))
16
+ self.register_buffer('z_avg', self.embeddings.data.clone())
17
+
18
+ self.n_codes = n_codes
19
+ self.embedding_dim = embedding_dim
20
+ self._need_init = True
21
+ self.no_random_restart = no_random_restart
22
+ self.restart_thres = restart_thres
23
+
24
+ def _tile(self, x):
25
+ d, ew = x.shape
26
+ if d < self.n_codes:
27
+ n_repeats = (self.n_codes + d - 1) // d
28
+ std = 0.01 / np.sqrt(ew)
29
+ x = x.repeat(n_repeats, 1)
30
+ x = x + torch.randn_like(x) * std
31
+ return x
32
+
33
+ def _init_embeddings(self, z):
34
+ # z: [b, c, t, h, w]
35
+ self._need_init = False
36
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
37
+ y = self._tile(flat_inputs)
38
+
39
+ d = y.shape[0]
40
+ _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
41
+ if dist.is_initialized():
42
+ dist.broadcast(_k_rand, 0)
43
+ self.embeddings.data.copy_(_k_rand)
44
+ self.z_avg.data.copy_(_k_rand)
45
+ self.N.data.copy_(torch.ones(self.n_codes))
46
+
47
+ def forward(self, z):
48
+ # z: [b, c, t, h, w]
49
+ if self._need_init and self.training:
50
+ self._init_embeddings(z)
51
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c]
52
+ distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
53
+ - 2 * flat_inputs @ self.embeddings.t() \
54
+ + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c]
55
+
56
+ encoding_indices = torch.argmin(distances, dim=1)
57
+ encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(
58
+ flat_inputs) # [bthw, ncode]
59
+ encoding_indices = encoding_indices.view(
60
+ z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode]
61
+
62
+ embeddings = F.embedding(
63
+ encoding_indices, self.embeddings) # [b, t, h, w, c]
64
+ embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w]
65
+
66
+ commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
67
+
68
+ # EMA codebook update
69
+ if self.training:
70
+ n_total = encode_onehot.sum(dim=0)
71
+ encode_sum = flat_inputs.t() @ encode_onehot
72
+ if dist.is_initialized():
73
+ dist.all_reduce(n_total)
74
+ dist.all_reduce(encode_sum)
75
+
76
+ self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
77
+ self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
78
+
79
+ n = self.N.sum()
80
+ weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
81
+ encode_normalized = self.z_avg / weights.unsqueeze(1)
82
+ self.embeddings.data.copy_(encode_normalized)
83
+
84
+ y = self._tile(flat_inputs)
85
+ _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
86
+ if dist.is_initialized():
87
+ dist.broadcast(_k_rand, 0)
88
+
89
+ if not self.no_random_restart:
90
+ usage = (self.N.view(self.n_codes, 1)
91
+ >= self.restart_thres).float()
92
+ self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
93
+
94
+ # straight-through estimator: the forward pass uses the quantized embeddings, gradients flow back to z
+ embeddings_st = (embeddings - z).detach() + z
+ embeddings_st_exp = torch.exp(embeddings_st)  # currently unused below
96
+
97
+
98
+ avg_probs = torch.mean(encode_onehot, dim=0)
99
+ perplexity = torch.exp(-torch.sum(avg_probs *
100
+ torch.log(avg_probs + 1e-10)))
101
+
102
+ return dict(embeddings=embeddings_st, encodings=encoding_indices,
103
+ commitment_loss=commitment_loss, perplexity=perplexity)
104
+
105
+ def dictionary_lookup(self, encodings):
106
+ embeddings = F.embedding(encodings, self.embeddings)
107
+ return embeddings
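For reference, the EMA codebook update in forward() above corresponds to, with decay 0.99, smoothing eps = 1e-7 and K = n_codes:

N_i <- 0.99 N_i + 0.01 n_i
m_i <- 0.99 m_i + 0.01 * (sum of batch latents assigned to code i)
e_i <- m_i / ( N * (N_i + eps) / (N + K*eps) ),  with N = sum_j N_j

Here n_i is the per-batch count for code i, m_i is z_avg and e_i the stored embedding; unless no_random_restart is set, codes whose usage N_i falls below restart_thres are then re-seeded from randomly sampled batch latents.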
vq_gan_3d/model/lpips.py ADDED
@@ -0,0 +1,181 @@
1
+ """Adapted from https://github.com/SongweiGe/TATS"""
2
+
3
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
4
+
5
+
6
+ from collections import namedtuple
7
+ from torchvision import models
8
+ import torch.nn as nn
9
+ import torch
10
+ from tqdm import tqdm
11
+ import requests
12
+ import os
13
+ import hashlib
14
+ URL_MAP = {
15
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
16
+ }
17
+
18
+ CKPT_MAP = {
19
+ "vgg_lpips": "vgg.pth"
20
+ }
21
+
22
+ MD5_MAP = {
23
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
24
+ }
25
+
26
+
27
+ def download(url, local_path, chunk_size=1024):
28
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
29
+ with requests.get(url, stream=True) as r:
30
+ total_size = int(r.headers.get("content-length", 0))
31
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
32
+ with open(local_path, "wb") as f:
33
+ for data in r.iter_content(chunk_size=chunk_size):
34
+ if data:
35
+ f.write(data)
36
+ pbar.update(chunk_size)
37
+
38
+
39
+ def md5_hash(path):
40
+ with open(path, "rb") as f:
41
+ content = f.read()
42
+ return hashlib.md5(content).hexdigest()
43
+
44
+
45
+ def get_ckpt_path(name, root, check=False):
46
+ assert name in URL_MAP
47
+ path = os.path.join(root, CKPT_MAP[name])
48
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
49
+ print("Downloading {} model from {} to {}".format(
50
+ name, URL_MAP[name], path))
51
+ download(URL_MAP[name], path)
52
+ md5 = md5_hash(path)
53
+ assert md5 == MD5_MAP[name], md5
54
+ return path
55
+
56
+
57
+ class LPIPS(nn.Module):
58
+ # Learned perceptual metric
59
+ def __init__(self, use_dropout=True):
60
+ super().__init__()
61
+ self.scaling_layer = ScalingLayer()
62
+ self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels
63
+ self.net = vgg16(pretrained=False, requires_grad=True) # enabled grad
64
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
65
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
66
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
67
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
68
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
69
+ self.load_from_pretrained()
70
+ # for param in self.parameters():
71
+ # param.requires_grad = False
72
+
73
+ def load_from_pretrained(self, name="vgg_lpips"):
74
+ ckpt = get_ckpt_path(name, os.path.join(
75
+ os.path.dirname(os.path.abspath(__file__)), "cache"))
76
+ self.load_state_dict(torch.load(
77
+ ckpt, map_location=torch.device("cpu")), strict=False)
78
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
79
+
80
+ @classmethod
81
+ def from_pretrained(cls, name="vgg_lpips"):
82
+ if name != "vgg_lpips":
83
+ raise NotImplementedError
84
+ model = cls()
85
+ ckpt = get_ckpt_path(name, os.path.join(
86
+ os.path.dirname(os.path.abspath(__file__)), "cache"))
87
+ model.load_state_dict(torch.load(
88
+ ckpt, map_location=torch.device("cpu")), strict=False)
89
+ return model
90
+
91
+ def forward(self, input, target):
92
+ in0_input, in1_input = (self.scaling_layer(
93
+ input), self.scaling_layer(target))
94
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
95
+ feats0, feats1, diffs = {}, {}, {}
96
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
97
+ for kk in range(len(self.chns)):
98
+ feats0[kk], feats1[kk] = normalize_tensor(
99
+ outs0[kk]), normalize_tensor(outs1[kk])
100
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
101
+
102
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
103
+ for kk in range(len(self.chns))]
104
+ val = res[0]
105
+ for l in range(1, len(self.chns)):
106
+ val += res[l]
107
+ return val
108
+
109
+
110
+ class ScalingLayer(nn.Module):
111
+ def __init__(self):
112
+ super(ScalingLayer, self).__init__()
113
+ self.register_buffer('shift', torch.Tensor(
114
+ [-.030, -.088, -.188])[None, :, None, None])
115
+ self.register_buffer('scale', torch.Tensor(
116
+ [.458, .448, .450])[None, :, None, None])
117
+
118
+ def forward(self, inp):
119
+ return (inp - self.shift) / self.scale
120
+
121
+
122
+ class NetLinLayer(nn.Module):
123
+ """ A single linear layer which does a 1x1 conv """
124
+
125
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
126
+ super(NetLinLayer, self).__init__()
127
+ layers = [nn.Dropout(), ] if (use_dropout) else []
128
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1,
129
+ padding=0, bias=False), ]
130
+ self.model = nn.Sequential(*layers)
131
+
132
+
133
+ class vgg16(torch.nn.Module):
134
+ def __init__(self, requires_grad=False, pretrained=True):
135
+ super(vgg16, self).__init__()
136
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
137
+ self.slice1 = torch.nn.Sequential()
138
+ self.slice2 = torch.nn.Sequential()
139
+ self.slice3 = torch.nn.Sequential()
140
+ self.slice4 = torch.nn.Sequential()
141
+ self.slice5 = torch.nn.Sequential()
142
+ self.N_slices = 5
143
+ for x in range(4):
144
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
145
+ for x in range(4, 9):
146
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
147
+ for x in range(9, 16):
148
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
149
+ for x in range(16, 23):
150
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
151
+ for x in range(23, 30):
152
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
153
+ if not requires_grad:
154
+ for param in self.parameters():
155
+ param.requires_grad = False
156
+
157
+ def forward(self, X):
158
+ h = self.slice1(X)
159
+ h_relu1_2 = h
160
+ h = self.slice2(h)
161
+ h_relu2_2 = h
162
+ h = self.slice3(h)
163
+ h_relu3_3 = h
164
+ h = self.slice4(h)
165
+ h_relu4_3 = h
166
+ h = self.slice5(h)
167
+ h_relu5_3 = h
168
+ vgg_outputs = namedtuple(
169
+ "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
170
+ out = vgg_outputs(h_relu1_2, h_relu2_2,
171
+ h_relu3_3, h_relu4_3, h_relu5_3)
172
+ return out
173
+
174
+
175
+ def normalize_tensor(x, eps=1e-10):
176
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
177
+ return x/(norm_factor+eps)
178
+
179
+
180
+ def spatial_average(x, keepdim=True):
181
+ return x.mean([2, 3], keepdim=keepdim)
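A minimal sketch of using this LPIPS module on its own, assuming the bundled cache/vgg.pth LFS file has been pulled (otherwise it is downloaded via URL_MAP); inputs are NCHW, 3-channel:

import torch
from vq_gan_3d.model.lpips import LPIPS

lp = LPIPS().eval()
a = torch.rand(1, 3, 64, 64)
b = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    dist = lp(a, b)   # perceptual distance, shape [1, 1, 1, 1]
print(dist.item())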
vq_gan_3d/model/vqgan.py ADDED
@@ -0,0 +1,565 @@
1
+ """Adapted from https://github.com/SongweiGe/TATS"""
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import math
5
+ import argparse
6
+ import numpy as np
7
+ import pickle as pkl
8
+ import random
9
+ import gc
10
+ import os
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.backends.cudnn as cudnn
15
+ import torch.distributed as dist
16
+
17
+ from vq_gan_3d.utils import shift_dim, adopt_weight, comp_getattr
18
+ from vq_gan_3d.model.lpips import LPIPS
19
+ from vq_gan_3d.model.codebook import Codebook
20
+
21
+
22
+ def silu(x):
23
+ return x*torch.sigmoid(x)
24
+
25
+
26
+ class SiLU(nn.Module):
27
+ def __init__(self):
28
+ super(SiLU, self).__init__()
29
+
30
+ def forward(self, x):
31
+ return silu(x)
32
+
33
+
34
+ def hinge_d_loss(logits_real, logits_fake):
35
+ loss_real = torch.mean(F.relu(1. - logits_real))
36
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
37
+ d_loss = 0.5 * (loss_real + loss_fake)
38
+ return d_loss
39
+
40
+
41
+ def vanilla_d_loss(logits_real, logits_fake):
42
+ d_loss = 0.5 * (
43
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
44
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
45
+ return d_loss
46
+
47
+
48
+ class MeanPooling(nn.Module):
49
+ def __init__(self, kernel_size=16):
50
+ super(MeanPooling, self).__init__()
51
+ # Define a 3D average pooling layer
52
+ self.pool = nn.AvgPool3d(kernel_size=kernel_size)
53
+
54
+ def forward(self, x):
55
+ # Apply average pooling
56
+ x = self.pool(x)
57
+ # Flatten the tensor to a single dimension per batch element
58
+ x = x.view(x.size(0), -1)
59
+ return x
60
+
61
+
62
+ class VQGAN(nn.Module):
63
+ def __init__(self):
64
+ super().__init__()
65
+
66
+ self._set_seed(0)
67
+ self.embedding_dim = 256
68
+ self.n_codes = 16384
69
+
70
+ self.encoder = Encoder(16, [4,4,4], 1, 'group', 'replicate', 32)
71
+ self.decoder = Decoder(16, [4,4,4], 1, 'group', 32)
72
+ self.enc_out_ch = self.encoder.out_channels
73
+ self.pre_vq_conv = SamePadConv3d(self.enc_out_ch, 256, 1, padding_type='replicate')
74
+ self.post_vq_conv = SamePadConv3d(256, self.enc_out_ch, 1)
75
+
76
+ self.codebook = Codebook(16384, 256, no_random_restart=False, restart_thres=False)
77
+
78
+ self.pooling = MeanPooling(kernel_size=4)
79
+
80
+ self.gan_feat_weight = 4
81
+ # TODO: Changed batchnorm from sync to normal
82
+ self.image_discriminator = NLayerDiscriminator(1, 64, 3, norm_layer=nn.BatchNorm2d)
83
+
84
+ self.disc_loss = hinge_d_loss
85
+ self.perceptual_model = LPIPS()
86
+ self.image_gan_weight = 1
87
+ self.perceptual_weight = 4
88
+ self.l1_weight = 4
89
+
90
+ def encode(self, x, include_embeddings=False, quantize=True):
91
+ h = self.pre_vq_conv(self.encoder(x))
92
+ if quantize:
93
+ vq_output = self.codebook(h)
94
+ if include_embeddings:
95
+ return vq_output['embeddings'], vq_output['encodings']
96
+ else:
97
+ return vq_output['encodings']
98
+ return h
99
+
100
+ def decode(self, latent, quantize=False):
101
+ if quantize:
102
+ vq_output = self.codebook(latent)
103
+ latent = vq_output['encodings']
104
+ h = F.embedding(latent, self.codebook.embeddings)
105
+ h = self.post_vq_conv(shift_dim(h, -1, 1))
106
+ return self.decoder(h)
107
+
108
+ def feature_extraction(self, x):
109
+ """Extract embeddings given a grid."""
110
+ h = self.encode(x, include_embeddings=False, quantize=False)
111
+ return self.pooling(h.permute(0, 2, 3, 4, 1))
112
+
113
+ def forward(self, global_step, x, optimizer_idx=None, log_image=False, gpu_id=0):
114
+ B, C, T, H, W = x.shape
115
+
116
+ z = self.pre_vq_conv(self.encoder(x))
117
+ vq_output = self.codebook(z)
118
+
119
+ #vq_output['embeddings'] = torch.exp(vq_output['embeddings'])
120
+ x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))
121
+
122
+ recon_loss = (F.l1_loss(x_recon, x) * self.l1_weight)
123
+
124
+ # Selects one random 2D image from each 3D Image
125
+ frame_idx = torch.randint(0, T, [B]).to(gpu_id)
126
+ frame_idx_selected = frame_idx.reshape(-1,
127
+ 1, 1, 1, 1).repeat(1, C, 1, H, W)
128
+ frames = torch.gather(x, 2, frame_idx_selected).squeeze(2)
129
+ frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2)
130
+
131
+ if log_image:
132
+ return frames, frames_recon, x, x_recon
133
+
134
+ if optimizer_idx == 0:
135
+ # Autoencoder - train the "generator"
136
+
137
+ # Perceptual loss
138
+ perceptual_loss = 0
139
+ if self.perceptual_weight > 0:
140
+ perceptual_loss = self.perceptual_model(
141
+ frames, frames_recon).mean() * self.perceptual_weight
142
+ # perceptual_loss = .123
143
+
144
+ # Discriminator loss (turned on after a certain epoch)
145
+ logits_image_fake, pred_image_fake = self.image_discriminator(
146
+ frames_recon)
147
+ g_image_loss = -torch.mean(logits_image_fake)
148
+ g_loss = self.image_gan_weight*g_image_loss
149
+ disc_factor = adopt_weight(
150
+ global_step, threshold=self.config['model']['discriminator_iter_start'])
151
+ aeloss = disc_factor * g_loss
152
+
153
+ # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator
154
+ image_gan_feat_loss = 0
155
+ feat_weights = 4.0 / (3 + 1)
156
+ if self.image_gan_weight > 0:
157
+ logits_image_real, pred_image_real = self.image_discriminator(
158
+ frames)
159
+ for i in range(len(pred_image_fake)-1):
160
+ image_gan_feat_loss += feat_weights * \
161
+ F.l1_loss(pred_image_fake[i], pred_image_real[i].detach(
162
+ )) * (self.image_gan_weight > 0)
163
+
164
+ gan_feat_loss = disc_factor * self.gan_feat_weight * \
165
+ (image_gan_feat_loss)
166
+
167
+ return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss, (g_image_loss, image_gan_feat_loss, vq_output['commitment_loss'], vq_output['perplexity'])
168
+
169
+ if optimizer_idx == 1:
170
+ # Train discriminator
171
+ logits_image_real, _ = self.image_discriminator(frames.detach())
172
+
173
+ logits_image_fake, _ = self.image_discriminator(
174
+ frames_recon.detach())
175
+
176
+ d_image_loss = self.disc_loss(logits_image_real, logits_image_fake)
177
+ disc_factor = adopt_weight(
178
+ global_step, threshold=self.config['model']['discriminator_iter_start'])
179
+ discloss = disc_factor * \
180
+ (self.image_gan_weight*d_image_loss)
181
+
182
+ return discloss
183
+
184
+ perceptual_loss = self.perceptual_model(
185
+ frames, frames_recon) * self.perceptual_weight
186
+ return recon_loss, x_recon, vq_output, perceptual_loss
187
+
188
+ def load_checkpoint(self, ckpt_path):
189
+ # load checkpoint file
190
+ ckpt_dict = torch.load(ckpt_path, map_location='cpu', weights_only=False)
191
+
192
+ # load hyparameters
193
+ self.config = ckpt_dict['hparams']['_content']
194
+ self.embedding_dim = self.config['model']['embedding_dim']
195
+ self.n_codes = self.config['model']['n_codes']
196
+
197
+ # instantiate modules
198
+ self.encoder = Encoder(
199
+ self.config['model']['n_hiddens'],
200
+ self.config['model']['downsample'],
201
+ self.config['dataset']['image_channels'],
202
+ self.config['model']['norm_type'],
203
+ self.config['model']['padding_type'],
204
+ self.config['model']['num_groups'],
205
+ )
206
+ self.decoder = Decoder(
207
+ self.config['model']['n_hiddens'],
208
+ self.config['model']['downsample'],
209
+ self.config['dataset']['image_channels'],
210
+ self.config['model']['norm_type'],
211
+ self.config['model']['num_groups']
212
+ )
213
+ self.enc_out_ch = self.encoder.out_channels
214
+ self.pre_vq_conv = SamePadConv3d(self.enc_out_ch, self.embedding_dim, 1, padding_type=self.config['model']['padding_type'])
215
+ self.post_vq_conv = SamePadConv3d(self.embedding_dim, self.enc_out_ch, 1)
216
+ self.codebook = Codebook(
217
+ self.n_codes,
218
+ self.embedding_dim,
219
+ no_random_restart=self.config['model']['no_random_restart'],
220
+ restart_thres=False
221
+ )
222
+ self.gan_feat_weight = self.config['model']['gan_feat_weight']
223
+ # TODO: Changed batchnorm from sync to normal
224
+ self.image_discriminator = NLayerDiscriminator(
225
+ self.config['dataset']['image_channels'],
226
+ self.config['model']['disc_channels'],
227
+ self.config['model']['disc_layers'],
228
+ norm_layer=nn.BatchNorm2d
229
+ )
230
+ self.disc_loss = hinge_d_loss
231
+ self.perceptual_model = LPIPS()
232
+ self.image_gan_weight = self.config['model']['gan_feat_weight']
233
+ self.perceptual_weight = self.config['model']['perceptual_weight']
234
+ self.l1_weight = self.config['model']['l1_weight']
235
+
236
+ # restore model weights
237
+ self.load_state_dict(ckpt_dict["MODEL_STATE"], strict=True)
238
+
239
+ # load RNG states each time the model and states are loaded from checkpoint
240
+ if 'rng' in self.config:
241
+ rng = self.config['rng']
242
+ for key, value in rng.items():
243
+ if key =='torch_state':
244
+ torch.set_rng_state(value.cpu())
245
+ elif key =='cuda_state':
246
+ torch.cuda.set_rng_state(value.cpu())
247
+ elif key =='numpy_state':
248
+ np.random.set_state(value)
249
+ elif key =='python_state':
250
+ random.setstate(value)
251
+ else:
252
+ print('unrecognized state')
253
+
254
+ def log_images(self, batch, **kwargs):
255
+ log = dict()
256
+ x = batch['data']
257
+ x = x.to(next(self.parameters()).device)
+ frames, frames_rec, _, _ = self(0, x, log_image=True)  # forward expects global_step as first argument
259
+ log["inputs"] = frames
260
+ log["reconstructions"] = frames_rec
261
+ #log['mean_org'] = batch['mean_org']
262
+ #log['std_org'] = batch['std_org']
263
+ return log
264
+
265
+ def _set_seed(self, value):
266
+ print('Random Seed:', value)
267
+ random.seed(value)
268
+ torch.manual_seed(value)
269
+ torch.cuda.manual_seed(value)
270
+ torch.cuda.manual_seed_all(value)
271
+ np.random.seed(value)
272
+ cudnn.deterministic = True
273
+ cudnn.benchmark = True
274
+ cudnn.enabled = True
275
+
276
+
277
+ def Normalize(in_channels, norm_type='group', num_groups=32):
278
+ assert norm_type in ['group', 'batch']
279
+ if norm_type == 'group':
280
+ # TODO Changed num_groups from 32 to 8
281
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
282
+ elif norm_type == 'batch':
283
+ return torch.nn.SyncBatchNorm(in_channels)
284
+
285
+
286
+ class Encoder(nn.Module):
287
+ def __init__(self, n_hiddens = 16, downsample = [2,2,2] , image_channel=64, norm_type='group', padding_type='replicate', num_groups=32):
288
+ super().__init__()
289
+ n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
290
+ self.conv_blocks = nn.ModuleList()
291
+ max_ds = n_times_downsample.max()
292
+
293
+ self.conv_first = SamePadConv3d(
294
+ image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)
295
+
296
+ for i in range(max_ds):
297
+ block = nn.Module()
298
+ in_channels = n_hiddens * 2**i
299
+ out_channels = n_hiddens * 2**(i+1)
300
+ stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
301
+ block.down = SamePadConv3d(
302
+ in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
303
+ block.res = ResBlock(
304
+ out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
305
+ self.conv_blocks.append(block)
306
+ n_times_downsample -= 1
307
+
308
+ self.final_block = nn.Sequential(
309
+ Normalize(out_channels, norm_type, num_groups=num_groups),
310
+ SiLU()
311
+ )
312
+
313
+ self.out_channels = out_channels
314
+
315
+ def forward(self, x):
316
+ h = self.conv_first(x)
317
+ for block in self.conv_blocks:
318
+ h = block.down(h)
319
+ h = block.res(h)
320
+ h = self.final_block(h)
321
+ return h
322
+
323
+
324
+ class Decoder(nn.Module):
325
+ def __init__(self, n_hiddens = 16, upsample= [4,4,4], image_channel=1, norm_type='group', num_groups=1):
326
+ super().__init__()
327
+
328
+ n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
329
+ print('n_times_upsample :', n_times_upsample)
330
+ max_us = n_times_upsample.max()
331
+ print('max_us :', max_us)
332
+
333
+
334
+ in_channels = n_hiddens*2**max_us
335
+ self.final_block = nn.Sequential(
336
+ Normalize(in_channels, norm_type, num_groups=num_groups),
337
+ SiLU()
338
+ )
339
+
340
+ self.conv_blocks = nn.ModuleList()
341
+ for i in range(max_us):
342
+ block = nn.Module()
343
+ in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1)
344
+ out_channels = n_hiddens*2**(max_us-i)
345
+ us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
346
+ block.up = SamePadConvTranspose3d(
347
+ in_channels, out_channels, 4, stride=us)
348
+ block.res1 = ResBlock(
349
+ out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
350
+ block.res2 = ResBlock(
351
+ out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
352
+ self.conv_blocks.append(block)
353
+ n_times_upsample -= 1
354
+
355
+ self.conv_last = SamePadConv3d(
356
+ out_channels, image_channel, kernel_size=3)
357
+
358
+
359
+ def forward(self, x):
360
+ h = self.final_block(x)
361
+ for i, block in enumerate(self.conv_blocks):
362
+ h = block.up(h)
363
+ h = block.res1(h)
364
+ h = block.res2(h)
365
+ h = self.conv_last(h)
366
+ return h
367
+
368
+
369
+ class ResBlock(nn.Module):
370
+ def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32):
371
+ super().__init__()
372
+ self.in_channels = in_channels
373
+ out_channels = in_channels if out_channels is None else out_channels
374
+ self.out_channels = out_channels
375
+ self.use_conv_shortcut = conv_shortcut
376
+
377
+ self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups)
378
+ self.conv1 = SamePadConv3d(
379
+ in_channels, out_channels, kernel_size=3, padding_type=padding_type)
380
+ self.dropout = torch.nn.Dropout(dropout)
381
+ self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups)
382
+ self.conv2 = SamePadConv3d(
383
+ out_channels, out_channels, kernel_size=3, padding_type=padding_type)
384
+ if self.in_channels != self.out_channels:
385
+ self.conv_shortcut = SamePadConv3d(
386
+ in_channels, out_channels, kernel_size=3, padding_type=padding_type)
387
+
388
+ def forward(self, x):
389
+ h = x
390
+ h = self.norm1(h)
391
+ h = silu(h)
392
+ h = self.conv1(h)
393
+ h = self.norm2(h)
394
+ h = silu(h)
395
+ h = self.conv2(h)
396
+
397
+ if self.in_channels != self.out_channels:
398
+ x = self.conv_shortcut(x)
399
+
400
+ return x+h
401
+
402
+
403
+ # Does not support dilation
404
+ class SamePadConv3d(nn.Module):
405
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
406
+ super().__init__()
407
+ if isinstance(kernel_size, int):
408
+ kernel_size = (kernel_size,) * 3
409
+ if isinstance(stride, int):
410
+ stride = (stride,) * 3
411
+
412
+ # assumes that the input shape is divisible by stride
413
+ total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
414
+ pad_input = []
415
+ for p in total_pad[::-1]: # reverse since F.pad starts from last dim
416
+ pad_input.append((p // 2 + p % 2, p // 2))
417
+ pad_input = sum(pad_input, tuple())
418
+ self.pad_input = pad_input
419
+ self.padding_type = padding_type
420
+
421
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
422
+ stride=stride, padding=0, bias=bias)
423
+
424
+ def forward(self, x):
425
+ return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
426
+
427
+
428
+ class SamePadConvTranspose3d(nn.Module):
429
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
430
+ super().__init__()
431
+ if isinstance(kernel_size, int):
432
+ kernel_size = (kernel_size,) * 3
433
+ if isinstance(stride, int):
434
+ stride = (stride,) * 3
435
+
436
+ total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
437
+ pad_input = []
438
+ for p in total_pad[::-1]: # reverse since F.pad starts from last dim
439
+ pad_input.append((p // 2 + p % 2, p // 2))
440
+ pad_input = sum(pad_input, tuple())
441
+ self.pad_input = pad_input
442
+ self.padding_type = padding_type
443
+
444
+ self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
445
+ stride=stride, bias=bias,
446
+ padding=tuple([k - 1 for k in kernel_size]))
447
+
448
+ def forward(self, x):
449
+ return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
450
+
451
+
452
+ class NLayerDiscriminator(nn.Module):
453
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
454
+ # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True):
455
+ super(NLayerDiscriminator, self).__init__()
456
+ self.getIntermFeat = getIntermFeat
457
+ self.n_layers = n_layers
458
+
459
+ kw = 4
460
+ padw = int(np.ceil((kw-1.0)/2))
461
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw,
462
+ stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
463
+
464
+ nf = ndf
465
+ for n in range(1, n_layers):
466
+ nf_prev = nf
467
+ nf = min(nf * 2, 512)
468
+ sequence += [[
469
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
470
+ norm_layer(nf), nn.LeakyReLU(0.2, True)
471
+ ]]
472
+
473
+ nf_prev = nf
474
+ nf = min(nf * 2, 512)
475
+ sequence += [[
476
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
477
+ norm_layer(nf),
478
+ nn.LeakyReLU(0.2, True)
479
+ ]]
480
+
481
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw,
482
+ stride=1, padding=padw)]]
483
+
484
+ if use_sigmoid:
485
+ sequence += [[nn.Sigmoid()]]
486
+
487
+ if getIntermFeat:
488
+ for n in range(len(sequence)):
489
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
490
+ else:
491
+ sequence_stream = []
492
+ for n in range(len(sequence)):
493
+ sequence_stream += sequence[n]
494
+ self.model = nn.Sequential(*sequence_stream)
495
+
496
+ def forward(self, input):
497
+ if self.getIntermFeat:
498
+ res = [input]
499
+ for n in range(self.n_layers+2):
500
+ model = getattr(self, 'model'+str(n))
501
+ res.append(model(res[-1]))
502
+ return res[-1], res[1:]
503
+ else:
504
+ return self.model(input), None
505
+
506
+
507
+ class NLayerDiscriminator3D(nn.Module):
508
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
509
+ super(NLayerDiscriminator3D, self).__init__()
510
+ self.getIntermFeat = getIntermFeat
511
+ self.n_layers = n_layers
512
+
513
+ kw = 4
514
+ padw = int(np.ceil((kw-1.0)/2))
515
+ sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw,
516
+ stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
517
+
518
+ nf = ndf
519
+ for n in range(1, n_layers):
520
+ nf_prev = nf
521
+ nf = min(nf * 2, 512)
522
+ sequence += [[
523
+ nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
524
+ norm_layer(nf), nn.LeakyReLU(0.2, True)
525
+ ]]
526
+
527
+ nf_prev = nf
528
+ nf = min(nf * 2, 512)
529
+ sequence += [[
530
+ nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
531
+ norm_layer(nf),
532
+ nn.LeakyReLU(0.2, True)
533
+ ]]
534
+
535
+ sequence += [[nn.Conv3d(nf, 1, kernel_size=kw,
536
+ stride=1, padding=padw)]]
537
+
538
+ if use_sigmoid:
539
+ sequence += [[nn.Sigmoid()]]
540
+
541
+ if getIntermFeat:
542
+ for n in range(len(sequence)):
543
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
544
+ else:
545
+ sequence_stream = []
546
+ for n in range(len(sequence)):
547
+ sequence_stream += sequence[n]
548
+ self.model = nn.Sequential(*sequence_stream)
549
+
550
+ def forward(self, input):
551
+ if self.getIntermFeat:
552
+ res = [input]
553
+ for n in range(self.n_layers+2):
554
+ model = getattr(self, 'model'+str(n))
555
+ res.append(model(res[-1]))
556
+ return res[-1], res[1:]
557
+ else:
558
+ return self.model(input), None
559
+
560
+
561
+ def load_VQGAN(folder="../data/checkpoints/pretrained", ckpt_filename="VQGAN_43.pt"):
562
+ model = VQGAN()
563
+ model.load_checkpoint(os.path.join(folder, ckpt_filename))
564
+ model.eval()
565
+ return model
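load_VQGAN above is the entry point app.py uses. Besides encode/decode (exercised in the sketch after app.py), the model exposes feature_extraction for pooled continuous latents; a minimal sketch, assuming the LFS checkpoint under vq_gan_3d/weights/ has been pulled and using a random placeholder volume:

import torch
from vq_gan_3d.model.vqgan import load_VQGAN

model = load_VQGAN(folder="vq_gan_3d/weights", ckpt_filename="3DGrid-VQGAN_43.pt")
x = torch.randn(1, 1, 128, 128, 128)   # placeholder; real volumes come from the preprocessing in app.py
with torch.no_grad():
    emb = model.feature_extraction(x)  # pre-quantization latents, mean-pooled and flattened per volume
print(emb.shape)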
vq_gan_3d/utils.py ADDED
@@ -0,0 +1,165 @@
1
+ """ Adapted from https://github.com/SongweiGe/TATS"""
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import warnings
5
+ import torch
6
+ import imageio
7
+
8
+ import math
9
+ import numpy as np
10
+
11
+ import sys
12
+ import pdb as pdb_original
13
+ # import SimpleITK as sitk
14
+ import logging
15
+
16
+ import imageio.core.util
17
+ logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)
18
+
19
+
20
+ def get_single_device(cpu=True):
+ if cpu:
+ return torch.device('cpu')
+ elif torch.cuda.is_available():
+ return torch.device('cuda')
+ elif hasattr(torch, 'xpu') and torch.xpu.is_available():
+ return torch.device('xpu')
+ elif torch.backends.mps.is_available():
+ return torch.device('mps')
+ return None
30
+
31
+
32
+ class ForkedPdb(pdb_original.Pdb):
33
+ """A Pdb subclass that may be used
34
+ from a forked multiprocessing child
35
+
36
+ """
37
+
38
+ def interaction(self, *args, **kwargs):
39
+ _stdin = sys.stdin
40
+ try:
41
+ sys.stdin = open('/dev/stdin')
42
+ pdb_original.Pdb.interaction(self, *args, **kwargs)
43
+ finally:
44
+ sys.stdin = _stdin
45
+
46
+
47
+ # Shifts src_tf dim to dest dim
48
+ # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c)
49
+ def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
50
+ n_dims = len(x.shape)
51
+ if src_dim < 0:
52
+ src_dim = n_dims + src_dim
53
+ if dest_dim < 0:
54
+ dest_dim = n_dims + dest_dim
55
+
56
+ assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
57
+
58
+ dims = list(range(n_dims))
59
+ del dims[src_dim]
60
+
61
+ permutation = []
62
+ ctr = 0
63
+ for i in range(n_dims):
64
+ if i == dest_dim:
65
+ permutation.append(src_dim)
66
+ else:
67
+ permutation.append(dims[ctr])
68
+ ctr += 1
69
+ x = x.permute(permutation)
70
+ if make_contiguous:
71
+ x = x.contiguous()
72
+ return x
73
+
74
+
75
+ # reshapes tensor start from dim i (inclusive)
76
+ # to dim j (exclusive) to the desired shape
77
+ # e.g. if x.shape = (b, thw, c) then
78
+ # view_range(x, 1, 2, (t, h, w)) returns
79
+ # x of shape (b, t, h, w, c)
80
+ def view_range(x, i, j, shape):
81
+ shape = tuple(shape)
82
+
83
+ n_dims = len(x.shape)
84
+ if i < 0:
85
+ i = n_dims + i
86
+
87
+ if j is None:
88
+ j = n_dims
89
+ elif j < 0:
90
+ j = n_dims + j
91
+
92
+ assert 0 <= i < j <= n_dims
93
+
94
+ x_shape = x.shape
95
+ target_shape = x_shape[:i] + shape + x_shape[j:]
96
+ return x.view(target_shape)
97
+
98
+
99
+ def accuracy(output, target, topk=(1,)):
100
+ """Computes the accuracy over the k top predictions for the specified values of k"""
101
+ with torch.no_grad():
102
+ maxk = max(topk)
103
+ batch_size = target.size(0)
104
+
105
+ _, pred = output.topk(maxk, 1, True, True)
106
+ pred = pred.t()
107
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
108
+
109
+ res = []
110
+ for k in topk:
111
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
112
+ res.append(correct_k.mul_(100.0 / batch_size))
113
+ return res
114
+
115
+
116
+ def tensor_slice(x, begin, size):
117
+ assert all([b >= 0 for b in begin])
118
+ size = [l - b if s == -1 else s
119
+ for s, b, l in zip(size, begin, x.shape)]
120
+ assert all([s >= 0 for s in size])
121
+
122
+ slices = [slice(b, b + s) for b, s in zip(begin, size)]
123
+ return x[slices]
124
+
125
+
126
+ def adopt_weight(global_step, threshold=0, value=0.):
127
+ weight = 1
128
+ if global_step < threshold:
129
+ weight = value
130
+ return weight
131
+
132
+ def comp_getattr(args, attr_name, default=None):
133
+ if hasattr(args, attr_name):
134
+ return getattr(args, attr_name)
135
+ else:
136
+ return default
137
+
138
+
139
+ def visualize_tensors(t, name=None, nest=0):
140
+ if name is not None:
141
+ print(name, "current nest: ", nest)
142
+ print("type: ", type(t))
143
+ if 'dict' in str(type(t)):
144
+ print(t.keys())
145
+ for k in t.keys():
146
+ if t[k] is None:
147
+ print(k, "None")
148
+ else:
149
+ if 'Tensor' in str(type(t[k])):
150
+ print(k, t[k].shape)
151
+ elif 'dict' in str(type(t[k])):
152
+ print(k, 'dict')
153
+ visualize_tensors(t[k], name, nest + 1)
154
+ elif 'list' in str(type(t[k])):
155
+ print(k, len(t[k]))
156
+ visualize_tensors(t[k], name, nest + 1)
157
+ elif 'list' in str(type(t)):
158
+ print("list length: ", len(t))
159
+ for t2 in t:
160
+ visualize_tensors(t2, name, nest + 1)
161
+ elif 'Tensor' in str(type(t)):
162
+ print(t.shape)
163
+ else:
164
+ print(t)
165
+ return ""
vq_gan_3d/weights/3DGrid-VQGAN_43.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22dbe6ccab1ee629c51e6b98bcc109dc898e7f033feb5c1f5a271d1ee3d0ab83
3
+ size 260643354