import math
from typing import Optional, Sequence

import torch
import torch.nn as nn


class Generator(nn.Module):
    """Conditional MLP generator: (noise, class label) -> image tensor.

    The class label is embedded, concatenated to the noise vector, and passed
    through a 4-layer MLP whose final ``Tanh`` bounds every output value to
    ``[-1, 1]``. The flat output is reshaped to ``(batch, *img_shape)``.

    Args:
        noise_dim: Size of each noise vector ``z``.
        num_classes: Number of conditioning classes; labels index an
            ``nn.Embedding`` of this size, and each embedding also has
            ``num_classes`` features.
        img_dim: Number of values in one generated image (flattened).
        img_shape: Optional per-sample output shape, e.g. ``(3, 32, 32)``.
            Its element product must equal ``img_dim``. When omitted,
            ``img_dim`` must be a perfect square and the output is a
            single-channel square image ``(1, side, side)``. The defaults
            therefore give ``(1, 28, 28)``, as before.

    Raises:
        ValueError: If ``img_shape`` does not match ``img_dim``, or if
            ``img_shape`` is omitted and ``img_dim`` is not a perfect square.
    """

    def __init__(
        self,
        noise_dim: int = 100,
        num_classes: int = 10,
        img_dim: int = 28 * 28,
        img_shape: Optional[Sequence[int]] = None,
    ) -> None:
        super().__init__()
        self.img_shape = self._resolve_img_shape(img_dim, img_shape)

        # One learned vector per class; its width equals num_classes.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(noise_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh(),  # bounds outputs to [-1, 1]
        )

    @staticmethod
    def _resolve_img_shape(img_dim, img_shape):
        """Return the per-sample output shape as a tuple, validated against img_dim."""
        if img_shape is not None:
            shape = tuple(int(s) for s in img_shape)
            if math.prod(shape) != img_dim:
                raise ValueError(
                    f"img_shape {shape} has {math.prod(shape)} elements, "
                    f"but img_dim is {img_dim}"
                )
            return shape
        side = math.isqrt(img_dim) if img_dim > 0 else 0
        if side <= 0 or side * side != img_dim:
            raise ValueError(
                f"img_dim={img_dim} is not a perfect square; "
                "pass img_shape explicitly"
            )
        return (1, side, side)

    def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Generate a batch of images conditioned on class labels.

        Args:
            z: Noise of shape ``(batch, noise_dim)``; it is concatenated with
                the label embedding along dim 1.
            labels: Integer class indices, presumably of shape ``(batch,)``
                so the embedding is ``(batch, num_classes)``. ``nn.Embedding``
                requires an integer dtype.

        Returns:
            Tensor of shape ``(batch, *self.img_shape)`` with values in
            ``[-1, 1]``.
        """
        label_input = self.label_emb(labels)
        x = torch.cat([z, label_input], dim=1)
        x = self.model(x)
        # Keep the batch dimension explicit so a shape mismatch raises
        # instead of silently re-batching the output.
        return x.view(x.size(0), *self.img_shape)