Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image tensor.

    The label is embedded, concatenated with the noise vector, and passed
    through a four-layer MLP ending in ``Tanh``, so output values lie in
    ``[-1, 1]``.

    Args:
        noise_dim: Length of each noise vector ``z``.
        num_classes: Number of distinct labels; the label embedding also has
            this many dimensions.
        img_dim: Number of output values per sample (``C * H * W``).
        img_shape: Optional per-sample output shape, e.g. ``(C, H, W)``.
            When omitted, ``img_dim`` must be a perfect square and a single
            channel is assumed, so the default ``img_dim=28*28`` gives
            ``(1, 28, 28)``.

    Raises:
        ValueError: If ``img_shape`` is omitted and ``img_dim`` is not a
            perfect square, or if ``img_shape`` does not hold exactly
            ``img_dim`` elements.
    """

    def __init__(
        self,
        noise_dim: int = 100,
        num_classes: int = 10,
        img_dim: int = 28 * 28,
        img_shape: "tuple[int, ...] | None" = None,
    ):
        super().__init__()
        import math

        if img_shape is None:
            side = math.isqrt(img_dim)
            if side * side != img_dim:
                raise ValueError(
                    f"img_dim={img_dim} is not a perfect square; "
                    f"pass img_shape explicitly"
                )
            img_shape = (1, side, side)
        img_shape = tuple(int(d) for d in img_shape)
        if math.prod(img_shape) != img_dim:
            raise ValueError(
                f"img_shape={img_shape} has {math.prod(img_shape)} elements, "
                f"expected img_dim={img_dim}"
            )
        # Plain tuple attribute (not a parameter/buffer): state_dict keys stay
        # the same, so previously saved checkpoints still load.
        self.img_shape = img_shape

        # Attribute names and layer order are unchanged for checkpoint
        # compatibility (keys: label_emb.*, model.0/2/4/6.*).
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(noise_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh(),
        )

    def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Generate images for the given noise vectors and class labels.

        Args:
            z: Noise tensor; assumed shape ``(batch, noise_dim)``.
            labels: Integer class indices; assumed shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, *self.img_shape)`` with values in
            ``[-1, 1]``.
        """
        label_input = self.label_emb(labels)
        x = torch.cat([z, label_input], dim=1)
        x = self.model(x)
        # Use the real batch size rather than -1 so a shape mismatch raises
        # instead of silently folding pixels into the batch dimension.
        return x.view(x.size(0), *self.img_shape)