| import torch | |
| from torch import nn | |
| from torch.nn import Module | |
| from models.StyleCLIP.models.stylegan2.model import EqualLinear, PixelNorm | |
class Mapper(Module):
    """MLP that transforms latent codes.

    The network is a ``PixelNorm`` followed by ``num_layers`` ``EqualLinear``
    layers. Each layer maps ``latent_dim`` features to ``latent_dim``
    features with ``lr_mul=0.01`` and ``activation='fused_lrelu'``. The layers
    are chained in an ``nn.Sequential`` stored on ``self.mapping``.

    Args:
        opts: Options object. It is stored on ``self.opts`` and not otherwise
            read by this class.
        latent_dim: Input, hidden and output feature width. Defaults to 512,
            the value previously hard-coded.
        num_layers: Number of ``EqualLinear`` layers after the ``PixelNorm``.
            Defaults to 4, the value previously hard-coded.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(self, opts, latent_dim=512, num_layers=4):
        super(Mapper, self).__init__()
        if num_layers < 0:
            # A negative count would silently yield a PixelNorm-only network.
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")
        self.opts = opts
        layers = [PixelNorm()]
        # lr_mul / activation are passed through unchanged to EqualLinear;
        # their exact semantics are defined by that project class.
        layers.extend(
            EqualLinear(
                latent_dim, latent_dim, lr_mul=0.01, activation='fused_lrelu'
            )
            for _ in range(num_layers)
        )
        self.mapping = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the mapping network to ``x`` and return the result.

        NOTE(review): presumably the last dimension of ``x`` must equal
        ``latent_dim`` for the ``EqualLinear`` layers to apply. Confirm this
        against the ``EqualLinear`` implementation.
        """
        return self.mapping(x)
class SingleMapper(Module):
    """Wrap a single ``Mapper`` and apply it to the entire input.

    Args:
        opts: Options object. It is kept on ``self.opts`` and handed to the
            wrapped ``Mapper``.
    """

    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        self.mapping = Mapper(opts)

    def forward(self, x):
        """Return the wrapped mapper's output for ``x``."""
        return self.mapping(x)
class LevelsMapper(Module):
    """Apply separate ``Mapper`` networks to three slices of the input.

    The input is sliced along dim 1 into three groups:

    - coarse: indices 0-3
    - medium: indices 4-7
    - fine: index 8 onward

    Each group goes through its own ``Mapper`` unless the corresponding
    ``opts.no_coarse_mapper`` / ``opts.no_medium_mapper`` /
    ``opts.no_fine_mapper`` flag is truthy. A disabled group is replaced with
    zeros of the same shape. The groups are concatenated back along dim 1 in
    coarse, medium, fine order.

    NOTE: the coarse submodule attribute is spelled ``course_mapping``. The
    name is kept as-is because submodule attribute names become ``state_dict``
    keys. Renaming it would presumably break loading existing checkpoints.
    """

    # One entry per level, in coarse -> medium -> fine order. Each entry is
    # (submodule attribute, opts flag that disables it, dim-1 slice).
    _LEVEL_SPECS = (
        ('course_mapping', 'no_coarse_mapper', slice(None, 4)),
        ('medium_mapping', 'no_medium_mapper', slice(4, 8)),
        ('fine_mapping', 'no_fine_mapper', slice(8, None)),
    )

    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        for attr_name, disable_flag, _ in self._LEVEL_SPECS:
            if getattr(opts, disable_flag):
                continue
            # Module.__setattr__ registers the Mapper as a named submodule.
            setattr(self, attr_name, Mapper(opts))

    def forward(self, x):
        """Map each enabled level group and zero the disabled ones.

        Args:
            x: Tensor indexed as ``x[:, levels, :]``. Assumed to be
                (batch, n_levels, dim); TODO confirm against callers.

        Returns:
            The per-group results concatenated along dim 1.
        """
        pieces = []
        for attr_name, disable_flag, levels in self._LEVEL_SPECS:
            chunk = x[:, levels, :]
            # Flags are read at call time, so later changes to opts apply.
            if getattr(self.opts, disable_flag):
                pieces.append(torch.zeros_like(chunk))
            else:
                pieces.append(getattr(self, attr_name)(chunk))
        return torch.cat(pieces, dim=1)