Spaces:
Build error
Build error
| from typing import Iterable | |
| import torch | |
| from torch import nn | |
class NoiseRegularizer(nn.Module):
    """Penalise one-pixel spatial autocorrelation of noise maps at several scales.

    At each scale, the squared mean of ``noise * roll(noise, 1)`` along each of
    the two spatial axes (dims 2 and 3) is added to the loss. The map is then
    2x2 average-pooled. This repeats until the smaller spatial side is 8 or
    less.
    """

    def forward(self, noises: Iterable[torch.Tensor]):
        """Return the summed autocorrelation penalty over all tensors in ``noises``.

        Args:
            noises: 4-D tensors whose dims 2 and 3 are the spatial axes. The
                leading two dims are preserved when pooling (presumably batch
                and channel; TODO confirm against callers). Height and width
                must be even at every level that gets pooled, which holds for
                power-of-two sizes.

        Returns:
            A scalar tensor, or the integer ``0`` when ``noises`` is empty.
        """
        loss = 0
        for noise in noises:
            height, width = noise.shape[2], noise.shape[3]
            while True:
                loss = (
                    loss
                    + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                    + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
                )
                if min(height, width) <= 8:
                    break
                # 2x2 average pooling via reshape + mean. Keep the two leading
                # dims as-is instead of forcing them to 1. Forcing them to 1
                # made batched/multi-channel input either fail to reshape or,
                # when the element count happened to match, get silently
                # scrambled.
                lead0, lead1 = noise.shape[0], noise.shape[1]
                noise = noise.reshape(
                    [lead0, lead1, height // 2, 2, width // 2, 2]
                ).mean([3, 5])
                height //= 2
                width //= 2
        return loss
def normalize(noises: Iterable[torch.Tensor]):
    """Standardise every tensor in ``noises`` in place to zero mean and unit std.

    Both statistics are taken over all elements of a tensor before it is
    modified. ``std`` uses torch's default estimator. The in-place write goes
    through ``.data``, so autograd does not record it. A tensor whose std is
    0 (constant or single-element) becomes non-finite.

    Returns:
        None; the tensors are mutated in place.
    """
    for tensor in noises:
        mu, sigma = tensor.mean(), tensor.std()
        tensor.data.sub_(mu).div_(sigma)