import torch.nn as nn


def modulate(x, shift, scale):
    # Apply adaptive scale-and-shift modulation to the normalized features.
    return x * (1 + scale) + shift


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        # LayerNorm without learnable affine parameters; the scale and shift
        # are supplied by the conditioning signal instead.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        # Projects the conditioning signal c into per-channel shift and scale.
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
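
# Below is a minimal usage sketch. The sizes, tensor shapes, and variable
# names are placeholders chosen for illustration and are not taken from the
# code above; the only assumption is that the conditioning tensor c
# broadcasts against the token features x.

import torch

hidden_size, out_channels = 384, 8          # hypothetical sizes
layer = FinalLayer(hidden_size, out_channels)

x = torch.randn(2, 16, hidden_size)          # token features: (batch, tokens, hidden)
c = torch.randn(2, 1, hidden_size)           # conditioning, broadcast over the token dimension
out = layer(x, c)                            # -> shape (2, 16, 8)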