| """Adapted from https://github.com/SongweiGe/TATS""" | |
| # Copyright (c) Meta Platforms, Inc. All Rights Reserved | |
| import math | |
| import argparse | |
| import numpy as np | |
| import pickle as pkl | |
| import random | |
| import gc | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.backends.cudnn as cudnn | |
| import torch.distributed as dist | |
| from vq_gan_3d.utils import shift_dim, adopt_weight, comp_getattr | |
| from vq_gan_3d.model.lpips import LPIPS | |
| from vq_gan_3d.model.codebook import Codebook | |
def silu(x):
    return x * torch.sigmoid(x)


class SiLU(nn.Module):
    def __init__(self):
        super(SiLU, self).__init__()

    def forward(self, x):
        return silu(x)
def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) +
        torch.mean(torch.nn.functional.softplus(logits_fake)))
    return d_loss
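
# Illustrative sanity check (not part of the original TATS code): both losses
# push real logits up and fake logits down; hinge clips the penalty outside a
# unit margin, while the vanilla (softplus) form decays smoothly.
def _demo_d_losses():
    logits_real = torch.tensor([2.0, 0.5, -1.0])
    logits_fake = torch.tensor([-2.0, 0.0, 1.5])
    print('hinge  :', hinge_d_loss(logits_real, logits_fake).item())
    print('vanilla:', vanilla_d_loss(logits_real, logits_fake).item())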
class MeanPooling(nn.Module):
    def __init__(self, kernel_size=16):
        super(MeanPooling, self).__init__()
        # Define a 3D average pooling layer
        self.pool = nn.AvgPool3d(kernel_size=kernel_size)

    def forward(self, x):
        # Apply average pooling
        x = self.pool(x)
        # Flatten the tensor to a single dimension per batch element
        x = x.view(x.size(0), -1)
        return x
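
# Usage sketch (illustrative; the shapes are assumptions): AvgPool3d with
# kernel_size=4 shrinks each spatial dim by 4, so a (2, 8, 16, 16, 16) input
# becomes (2, 8, 4, 4, 4) and flattens to (2, 512).
def _demo_mean_pooling():
    x = torch.randn(2, 8, 16, 16, 16)
    print(MeanPooling(kernel_size=4)(x).shape)  # torch.Size([2, 512])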
class VQGAN(nn.Module):
    def __init__(self):
        super().__init__()
        self._set_seed(0)
        self.embedding_dim = 256
        self.n_codes = 16384
        self.encoder = Encoder(16, [4, 4, 4], 1, 'group', 'replicate', 32)
        self.decoder = Decoder(16, [4, 4, 4], 1, 'group', 32)
        self.enc_out_ch = self.encoder.out_channels
        self.pre_vq_conv = SamePadConv3d(
            self.enc_out_ch, 256, 1, padding_type='replicate')
        self.post_vq_conv = SamePadConv3d(256, self.enc_out_ch, 1)
        self.codebook = Codebook(16384, 256, no_random_restart=False,
                                 restart_thres=False)
        self.pooling = MeanPooling(kernel_size=4)
        self.gan_feat_weight = 4
        # Step at which the GAN losses switch on; 0 enables them immediately.
        # Overwritten by load_checkpoint() when the checkpoint config
        # provides a value.
        self.discriminator_iter_start = 0
        # TODO: Changed batchnorm from sync to normal
        self.image_discriminator = NLayerDiscriminator(
            1, 64, 3, norm_layer=nn.BatchNorm2d)
        self.disc_loss = hinge_d_loss
        self.perceptual_model = LPIPS()
        self.image_gan_weight = 1
        self.perceptual_weight = 4
        self.l1_weight = 4
    def encode(self, x, include_embeddings=False, quantize=True):
        h = self.pre_vq_conv(self.encoder(x))
        if quantize:
            vq_output = self.codebook(h)
            if include_embeddings:
                return vq_output['embeddings'], vq_output['encodings']
            else:
                return vq_output['encodings']
        return h

    def decode(self, latent, quantize=False):
        if quantize:
            vq_output = self.codebook(latent)
            latent = vq_output['encodings']
        h = F.embedding(latent, self.codebook.embeddings)
        h = self.post_vq_conv(shift_dim(h, -1, 1))
        return self.decoder(h)

    def feature_extraction(self, x):
        """Extract a pooled embedding from the continuous (unquantized)
        encoder output of a grid."""
        h = self.encode(x, include_embeddings=False, quantize=False)
        # move channels last, so the pooling kernel runs over (H, W, C)
        return self.pooling(h.permute(0, 2, 3, 4, 1))
    def forward(self, global_step, x, optimizer_idx=None, log_image=False, gpu_id=0):
        B, C, T, H, W = x.shape
        z = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(z, gpu_id)
        x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))

        recon_loss = F.l1_loss(x_recon, x) * self.l1_weight

        # Select one random 2D frame from each 3D volume for the 2D
        # discriminator and the perceptual loss
        frame_idx = torch.randint(0, T, [B]).to(x.device)
        frame_idx_selected = frame_idx.reshape(-1, 1, 1, 1, 1).repeat(1, C, 1, H, W)
        frames = torch.gather(x, 2, frame_idx_selected).squeeze(2)
        frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2)

        if log_image:
            return frames, frames_recon, x, x_recon

        if optimizer_idx == 0:
            # Autoencoder - train the "generator"

            # Perceptual loss
            perceptual_loss = 0
            if self.perceptual_weight > 0:
                perceptual_loss = self.perceptual_model(
                    frames, frames_recon).mean() * self.perceptual_weight

            # Discriminator loss (turned on after discriminator_iter_start steps)
            logits_image_fake, pred_image_fake = self.image_discriminator(frames_recon)
            g_image_loss = -torch.mean(logits_image_fake)
            g_loss = self.image_gan_weight * g_image_loss
            disc_factor = adopt_weight(
                global_step, threshold=self.discriminator_iter_start)
            aeloss = disc_factor * g_loss

            # GAN feature matching loss - tune features such that the
            # discriminator produces the same intermediate activations for
            # real and reconstructed frames
            image_gan_feat_loss = 0
            feat_weights = 4.0 / (3 + 1)  # 3 = number of discriminator layers
            if self.image_gan_weight > 0:
                logits_image_real, pred_image_real = self.image_discriminator(frames)
                for i in range(len(pred_image_fake) - 1):
                    image_gan_feat_loss += feat_weights * F.l1_loss(
                        pred_image_fake[i], pred_image_real[i].detach())
            gan_feat_loss = disc_factor * self.gan_feat_weight * image_gan_feat_loss

            return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss, \
                (g_image_loss, image_gan_feat_loss,
                 vq_output['commitment_loss'], vq_output['perplexity'])

        if optimizer_idx == 1:
            # Train discriminator
            logits_image_real, _ = self.image_discriminator(frames.detach())
            logits_image_fake, _ = self.image_discriminator(frames_recon.detach())

            d_image_loss = self.disc_loss(logits_image_real, logits_image_fake)
            disc_factor = adopt_weight(
                global_step, threshold=self.discriminator_iter_start)
            discloss = disc_factor * (self.image_gan_weight * d_image_loss)

            return discloss

        perceptual_loss = self.perceptual_model(frames, frames_recon) * self.perceptual_weight
        return recon_loss, x_recon, vq_output, perceptual_loss
    def load_checkpoint(self, ckpt_path):
        # load checkpoint file
        ckpt_dict = torch.load(ckpt_path, map_location='cpu', weights_only=False)

        # load hyperparameters
        self.config = ckpt_dict['hparams']['_content']
        self.embedding_dim = self.config['model']['embedding_dim']
        self.n_codes = self.config['model']['n_codes']

        # instantiate modules
        self.encoder = Encoder(
            self.config['model']['n_hiddens'],
            self.config['model']['downsample'],
            self.config['dataset']['image_channels'],
            self.config['model']['norm_type'],
            self.config['model']['padding_type'],
            self.config['model']['num_groups'],
        )
        self.decoder = Decoder(
            self.config['model']['n_hiddens'],
            self.config['model']['downsample'],
            self.config['dataset']['image_channels'],
            self.config['model']['norm_type'],
            self.config['model']['num_groups']
        )
        self.enc_out_ch = self.encoder.out_channels
        self.pre_vq_conv = SamePadConv3d(
            self.enc_out_ch, self.embedding_dim, 1,
            padding_type=self.config['model']['padding_type'])
        self.post_vq_conv = SamePadConv3d(self.embedding_dim, self.enc_out_ch, 1)
        self.codebook = Codebook(
            self.n_codes,
            self.embedding_dim,
            no_random_restart=self.config['model']['no_random_restart'],
            restart_thres=False
        )
        self.gan_feat_weight = self.config['model']['gan_feat_weight']
        if 'discriminator_iter_start' in self.config['model']:
            self.discriminator_iter_start = self.config['model']['discriminator_iter_start']
        # TODO: Changed batchnorm from sync to normal
        self.image_discriminator = NLayerDiscriminator(
            self.config['dataset']['image_channels'],
            self.config['model']['disc_channels'],
            self.config['model']['disc_layers'],
            norm_layer=nn.BatchNorm2d
        )
        self.disc_loss = hinge_d_loss
        self.perceptual_model = LPIPS()
        self.image_gan_weight = self.config['model']['image_gan_weight']
        self.perceptual_weight = self.config['model']['perceptual_weight']
        self.l1_weight = self.config['model']['l1_weight']

        # restore model weights
        self.load_state_dict(ckpt_dict["MODEL_STATE"], strict=True)

        # restore RNG states each time the model is loaded from a checkpoint
        if 'rng' in self.config:
            rng = self.config['rng']
            for key, value in rng.items():
                if key == 'torch_state':
                    torch.set_rng_state(value.cpu())
                elif key == 'cuda_state':
                    torch.cuda.set_rng_state(value.cpu())
                elif key == 'numpy_state':
                    np.random.set_state(value)
                elif key == 'python_state':
                    random.setstate(value)
                else:
                    print('unrecognized state')
    def log_images(self, batch, **kwargs):
        log = dict()
        x = batch['data']
        x = x.to(next(self.parameters()).device)
        # global_step is unused on the log_image path, so 0 is a safe dummy
        frames, frames_rec, _, _ = self(0, x, log_image=True)
        log["inputs"] = frames
        log["reconstructions"] = frames_rec
        return log
    def _set_seed(self, value):
        print('Random Seed:', value)
        random.seed(value)
        torch.manual_seed(value)
        torch.cuda.manual_seed(value)
        torch.cuda.manual_seed_all(value)
        np.random.seed(value)
        cudnn.deterministic = True
        # benchmark=True would let cuDNN pick non-deterministic kernels,
        # defeating the deterministic flag above
        cudnn.benchmark = False
        cudnn.enabled = True
def Normalize(in_channels, norm_type='group', num_groups=32):
    assert norm_type in ['group', 'batch']
    if norm_type == 'group':
        return torch.nn.GroupNorm(
            num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    elif norm_type == 'batch':
        return torch.nn.SyncBatchNorm(in_channels)
class Encoder(nn.Module):
    def __init__(self, n_hiddens=16, downsample=[2, 2, 2], image_channel=64,
                 norm_type='group', padding_type='replicate', num_groups=32):
        super().__init__()
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        self.conv_blocks = nn.ModuleList()
        max_ds = n_times_downsample.max()

        self.conv_first = SamePadConv3d(
            image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)

        for i in range(max_ds):
            block = nn.Module()
            in_channels = n_hiddens * 2**i
            out_channels = n_hiddens * 2**(i + 1)
            # stride 2 only along the dims that still need downsampling
            stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
            block.down = SamePadConv3d(
                in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
            block.res = ResBlock(
                out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
            self.conv_blocks.append(block)
            n_times_downsample -= 1

        self.final_block = nn.Sequential(
            Normalize(out_channels, norm_type, num_groups=num_groups),
            SiLU()
        )

        self.out_channels = out_channels

    def forward(self, x):
        h = self.conv_first(x)
        for block in self.conv_blocks:
            h = block.down(h)
            h = block.res(h)
        h = self.final_block(h)
        return h
class Decoder(nn.Module):
    def __init__(self, n_hiddens=16, upsample=[4, 4, 4], image_channel=1,
                 norm_type='group', num_groups=1):
        super().__init__()
        n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
        max_us = n_times_upsample.max()

        in_channels = n_hiddens * 2**max_us
        self.final_block = nn.Sequential(
            Normalize(in_channels, norm_type, num_groups=num_groups),
            SiLU()
        )

        self.conv_blocks = nn.ModuleList()
        for i in range(max_us):
            block = nn.Module()
            in_channels = in_channels if i == 0 else n_hiddens * 2**(max_us - i + 1)
            out_channels = n_hiddens * 2**(max_us - i)
            # upsample by 2 only along the dims that still need it
            us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
            block.up = SamePadConvTranspose3d(
                in_channels, out_channels, 4, stride=us)
            block.res1 = ResBlock(
                out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
            block.res2 = ResBlock(
                out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
            self.conv_blocks.append(block)
            n_times_upsample -= 1

        self.conv_last = SamePadConv3d(
            out_channels, image_channel, kernel_size=3)

    def forward(self, x):
        h = self.final_block(x)
        for i, block in enumerate(self.conv_blocks):
            h = block.up(h)
            h = block.res1(h)
            h = block.res2(h)
        h = self.conv_last(h)
        return h
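
# Shape sketch (illustrative; the 32^3 input size is an assumption): with the
# constructor arguments used in VQGAN above, a (B, 1, 32, 32, 32) volume is
# encoded to a (B, 64, 8, 8, 8) latent grid and decoded back to the input
# shape. In the full model, pre_vq_conv / codebook / post_vq_conv sit between
# the two.
def _demo_coder_shapes():
    enc = Encoder(16, [4, 4, 4], 1, 'group', 'replicate', 32)
    dec = Decoder(16, [4, 4, 4], 1, 'group', 32)
    x = torch.randn(1, 1, 32, 32, 32)
    h = enc(x)
    print(h.shape)       # torch.Size([1, 64, 8, 8, 8])
    print(dec(h).shape)  # torch.Size([1, 1, 32, 32, 32])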
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, conv_shortcut=False,
                 dropout=0.0, norm_type='group', padding_type='replicate',
                 num_groups=32):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups)
        self.conv1 = SamePadConv3d(
            in_channels, out_channels, kernel_size=3, padding_type=padding_type)
        self.dropout = torch.nn.Dropout(dropout)
        # conv1 outputs out_channels, so norm2 must normalize out_channels
        self.norm2 = Normalize(out_channels, norm_type, num_groups=num_groups)
        self.conv2 = SamePadConv3d(
            out_channels, out_channels, kernel_size=3, padding_type=padding_type)
        if self.in_channels != self.out_channels:
            self.conv_shortcut = SamePadConv3d(
                in_channels, out_channels, kernel_size=3, padding_type=padding_type)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = silu(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x)

        return x + h
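
# Usage sketch (illustrative; channel counts are assumptions): when in and out
# channels differ, the skip path goes through the learned conv_shortcut, so
# the residual sum still lines up. num_groups must divide both channel counts.
def _demo_resblock():
    block = ResBlock(16, 32, num_groups=8)
    x = torch.randn(1, 16, 8, 8, 8)
    print(block(x).shape)  # torch.Size([1, 32, 8, 8, 8])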
# Does not support dilation
class SamePadConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        # assumes that the input shape is divisible by stride
        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input
        self.padding_type = padding_type

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, bias=bias)

    def forward(self, x):
        return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
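
# Shape sketch (illustrative; the 16^3 input size is an assumption): padding
# k - s per dim gives "same"-style behavior, so stride 1 preserves the spatial
# shape and stride 2 exactly halves it.
def _demo_same_pad_conv():
    conv = SamePadConv3d(1, 8, kernel_size=3)             # stride 1: shape kept
    down = SamePadConv3d(8, 16, kernel_size=4, stride=2)  # stride 2: halved
    x = torch.randn(1, 1, 16, 16, 16)
    h = conv(x)
    print(h.shape)        # torch.Size([1, 8, 16, 16, 16])
    print(down(h).shape)  # torch.Size([1, 16, 8, 8, 8])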
class SamePadConvTranspose3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input
        self.padding_type = padding_type

        self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
                                        stride=stride, bias=bias,
                                        padding=tuple([k - 1 for k in kernel_size]))

    def forward(self, x):
        return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
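
# Shape sketch (illustrative; the 8^3 input size is an assumption): the
# transposed counterpart exactly doubles the spatial dims at stride 2, the
# mirror image of SamePadConv3d above.
def _demo_same_pad_conv_transpose():
    up = SamePadConvTranspose3d(16, 8, kernel_size=4, stride=2)
    x = torch.randn(1, 16, 8, 8, 8)
    print(up(x).shape)  # torch.Size([1, 8, 16, 16, 16])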
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm,
                 use_sigmoid=False, getIntermFeat=True):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw,
                               stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw,
                                stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers + 2):
                model = getattr(self, 'model' + str(n))
                res.append(model(res[-1]))
            return res[-1], res[1:]
        else:
            return self.model(input), None
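
# Usage sketch (illustrative; the 64x64 input size is an assumption): the
# PatchGAN head returns per-patch logits plus all intermediate activations,
# which VQGAN.forward uses for the GAN feature-matching loss.
def _demo_discriminator():
    disc = NLayerDiscriminator(1, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
    logits, feats = disc(torch.randn(2, 1, 64, 64))
    print(logits.shape)  # torch.Size([2, 1, 11, 11]) patch logits
    print(len(feats))    # 5 activations, including the final logits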
class NLayerDiscriminator3D(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm,
                 use_sigmoid=False, getIntermFeat=True):
        super(NLayerDiscriminator3D, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw,
                               stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[nn.Conv3d(nf, 1, kernel_size=kw,
                                stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers + 2):
                model = getattr(self, 'model' + str(n))
                res.append(model(res[-1]))
            return res[-1], res[1:]
        else:
            return self.model(input), None
def load_VQGAN(folder="../data/checkpoints/pretrained", ckpt_filename="VQGAN_43.pt"):
    model = VQGAN()
    model.load_checkpoint(os.path.join(folder, ckpt_filename))
    model.eval()
    return model
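
# Usage sketch (illustrative): the default checkpoint path comes from
# load_VQGAN above; the input shape is an assumption, and running this
# requires the actual checkpoint file on disk.
if __name__ == "__main__":
    model = load_VQGAN()
    volume = torch.randn(1, 1, 64, 64, 64)  # (B, C, T, H, W)
    with torch.no_grad():
        features = model.feature_extraction(volume)
    print(features.shape)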