Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| # ---------------------------------------------------------------------------- | |
| # Improved preconditioning proposed in the paper "Elucidating the Design | |
| # Space of Diffusion-Based Generative Models" (EDM). | |
class EDMPrecond(torch.nn.Module):
    """Wrap a network with EDM-style input/output preconditioning.

    ``forward`` scales the noisy input by ``c_in``, calls the wrapped network
    as ``network(c_in * x, c_noise, conditioning=..., **network_kwargs)`` and
    returns ``c_skip * x + c_out * F_x`` where ``F_x`` is the network output.

    ``sigma_min`` and ``sigma_max`` are stored as attributes only; nothing in
    this class reads them.
    """

    def __init__(
        self,
        network,
        label_dim=0,  # Number of class labels, 0 = unconditional.
        sigma_min=0,  # Minimum supported noise level.
        sigma_max=float("inf"),  # Maximum supported noise level.
        sigma_data=0.5,  # Expected standard deviation of the training data.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.network = network

    def forward(self, x, sigma, conditioning=None, **network_kwargs):
        """Return the preconditioned output ``D_x`` for input ``x``.

        Args:
            x: Input tensor; cast to float32. Its leading axis is treated as
                the batch axis that ``sigma`` is broadcast against.
            sigma: Noise level(s). May be a Python scalar or a tensor with
                either one element or one element per batch entry.
            conditioning: Optional conditioning tensor. Ignored (replaced by
                ``None``) when ``label_dim == 0``; replaced by a
                ``[1, label_dim]`` zero tensor when omitted and
                ``label_dim != 0``; otherwise cast to float32.
            **network_kwargs: Forwarded unchanged to the wrapped network.

        Returns:
            A float32 tensor ``c_skip * x + c_out * F_x``.
        """
        x = x.to(torch.float32)
        # as_tensor accepts Python scalars as well as tensors (matching
        # round_sigma) and places sigma on x's device so the products below
        # never mix devices.
        sigma = torch.as_tensor(sigma, dtype=torch.float32, device=x.device)
        # One trailing singleton axis per non-batch axis of x, so the
        # coefficients broadcast per sample for inputs of any rank instead of
        # only 4-D ones.
        sigma = sigma.reshape(-1, *([1] * (x.ndim - 1)))

        if self.label_dim == 0:
            conditioning = None
        elif conditioning is None:
            conditioning = torch.zeros([1, self.label_dim], device=x.device)
        else:
            conditioning = conditioning.to(torch.float32)

        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.log() / 4

        F_x = self.network(
            (c_in * x),
            c_noise.flatten(),  # The network receives a 1-D noise embedding input.
            conditioning=conditioning,
            **network_kwargs,
        )
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        """Return ``sigma`` converted to a tensor; no rounding is applied."""
        return torch.as_tensor(sigma)
class DDPMPrecond(nn.Module):
    """Pass-through wrapper: applies a network to a batch with no rescaling.

    The module holds no parameters or state of its own; the network is
    supplied on every call rather than stored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, network, batch):
        """Call ``network(batch)`` and return its result unchanged."""
        return network(batch)