# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Improved diffusion model architecture proposed in the paper
"Analyzing and Improving the Training Dynamics of Diffusion Models"."""

import numpy as np
import torch
#----------------------------------------------------------------------------
# Cached construction of constant tensors.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    if dtype is None:
        dtype = ref.dtype
    if device is None:
        device = ref.device
    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
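# Illustrative usage (sketch; the dtype and shape below are assumed):
#
#   x = torch.randn(2, 3, dtype=torch.float16)
#   f = const_like(x, [1, 2, 1])    # float16 tensor on x's device
#   f = const_like(x, [1, 2, 1])    # second call reuses the cached tensor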
#----------------------------------------------------------------------------
# Normalize given tensor to unit magnitude with respect to the given
# dimensions. Default = all dimensions except the first.

def normalize(x, dim=None, eps=1e-4):
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)

class Normalize(torch.nn.Module):
    def __init__(self, dim=None, eps=1e-4):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        return normalize(x, dim=self.dim, eps=self.eps)
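# Illustrative check (sketch; the shape is assumed): after normalize(), each
# sample has roughly unit per-element magnitude, so its vector norm is close
# to the square root of the number of normalized elements.
#
#   x = 10 * torch.randn(4, 64)
#   y = normalize(x)
#   torch.linalg.vector_norm(y, dim=1)    # each entry ~= sqrt(64) = 8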
#----------------------------------------------------------------------------
# Upsample or downsample the given tensor with the given filter,
# or keep it as is.

def resample(x, f=[1, 1], mode='keep'):
    if mode == 'keep':
        return x
    f = np.float32(f)
    assert f.ndim == 1 and len(f) % 2 == 0
    pad = (len(f) - 1) // 2
    f = f / f.sum()
    f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
    f = const_like(x, f)
    c = x.shape[1]
    if mode == 'down':
        return torch.nn.functional.conv2d(x, f.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
    assert mode == 'up'
    return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
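# Illustrative usage (sketch; the NCHW shape is assumed): with the default
# box filter [1, 1], 'down' halves and 'up' doubles the spatial resolution.
#
#   x = torch.randn(1, 8, 32, 32)
#   resample(x, mode='down').shape    # -> [1, 8, 16, 16]
#   resample(x, mode='up').shape      # -> [1, 8, 64, 64]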
#----------------------------------------------------------------------------
# Magnitude-preserving SiLU (Equation 81).

def mp_silu(x):
    return torch.nn.functional.silu(x) / 0.596

class MPSiLU(torch.nn.Module):
    def forward(self, x):
        return mp_silu(x)
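# Illustrative check (sketch): for roughly standard-normal inputs, dividing by
# 0.596 restores unit second moment at the output.
#
#   x = torch.randn(1 << 20)
#   mp_silu(x).square().mean().sqrt()    # ~= 1.0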
#----------------------------------------------------------------------------
# Magnitude-preserving sum (Equation 88).

def mp_sum(a, b, t=0.5):
    return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
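# Illustrative check (sketch): blending two independent unit-magnitude signals
# with lerp() alone would shrink the result; the denominator undoes that.
#
#   a, b = torch.randn(2, 1 << 20)
#   a.lerp(b, 0.3).std()       # ~= 0.76
#   mp_sum(a, b, t=0.3).std()  # ~= 1.0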
#----------------------------------------------------------------------------
# Magnitude-preserving concatenation (Equation 103).

def mp_cat(a, b, dim=1, t=0.5):
    Na = a.shape[dim]
    Nb = b.shape[dim]
    C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
    wa = C / np.sqrt(Na) * (1 - t)
    wb = C / np.sqrt(Nb) * t
    return torch.cat([wa * a, wb * b], dim=dim)
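# Illustrative check (sketch; the channel counts are assumed): at t=0.5 both
# inputs contribute equally and the result keeps unit magnitude, even when
# their channel counts differ.
#
#   a = torch.randn(4, 16, 8, 8)
#   b = torch.randn(4, 48, 8, 8)
#   mp_cat(a, b).std()    # ~= 1.0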
#----------------------------------------------------------------------------
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
# with forced weight normalization (Equation 66).

class MPConv1D(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
        self.weight_norm_removed = False

    def forward(self, x, gain=1):
        assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
        w = self.weight * gain
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 3
        return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2,))

    def remove_weight_norm(self):
        w = self.weight.to(torch.float32)
        w = normalize(w) # traditional weight normalization
        w = w / np.sqrt(w[0].numel())
        w = w.to(self.weight.dtype)
        self.weight.data.copy_(w)
        self.weight_norm_removed = True
        return self
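# Illustrative usage (sketch; the channel counts and sequence length are
# assumed): remove_weight_norm() bakes the forced weight normalization into
# the stored weights once, after which forward() is a plain conv1d.
#
#   conv = MPConv1D(in_channels=32, out_channels=64, kernel_size=3)
#   conv.remove_weight_norm()
#   x = torch.randn(2, 32, 100)    # [batch, channels, time]
#   y = conv(x)                    # -> [2, 64, 100], roughly unit magnitude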