Digitgenerator / model.py
rohitkhadka's picture
Upload 4 files
7b1f8f1 verified
raw
history blame contribute delete
748 Bytes
import math

import torch
import torch.nn as nn
class Generator(nn.Module):
    """Label-conditioned MLP generator.

    A noise vector is concatenated with a learned embedding of its class
    label and mapped through a 4-layer MLP (ReLU hidden activations) to
    ``img_dim`` values squashed into [-1, 1] by ``Tanh``.  The flat output
    is then reshaped to one image per sample.

    The defaults (100-d noise, 10 classes, 28x28 output) look sized for
    MNIST-style digits; that is an assumption from the constants only.
    """

    def __init__(self, noise_dim=100, num_classes=10, img_dim=28 * 28,
                 img_shape=None):
        """Build the generator.

        Args:
            noise_dim: Length of each noise vector ``z``.
            num_classes: Number of distinct labels.  This is also used as the
                size of the label embedding.
            img_dim: Number of output values per sample.
            img_shape: Optional per-sample output shape, e.g. ``(3, 32, 32)``.
                Its element product must equal ``img_dim``.  When omitted, a
                single-channel square image ``(1, s, s)`` with
                ``s * s == img_dim`` is used.  The defaults give
                ``(1, 28, 28)``.

        Raises:
            ValueError: If ``img_shape`` is omitted and ``img_dim`` is not a
                perfect square, or if ``img_shape`` does not contain exactly
                ``img_dim`` values.
        """
        super().__init__()

        if img_shape is None:
            side = math.isqrt(img_dim)
            if side * side != img_dim:
                raise ValueError(
                    f"img_dim={img_dim} is not a perfect square; "
                    "pass img_shape explicitly"
                )
            img_shape = (1, side, side)
        else:
            img_shape = tuple(img_shape)
            if math.prod(img_shape) != img_dim:
                raise ValueError(
                    f"img_shape={img_shape} holds {math.prod(img_shape)} "
                    f"values, expected img_dim={img_dim}"
                )
        # Plain attribute, not a parameter/buffer, so state_dict is unchanged.
        self.img_shape = img_shape

        # Attribute names and layer order are kept exactly as before so that
        # existing state_dict keys ("label_emb.weight", "model.0.weight", ...)
        # still match.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(noise_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh(),
        )

    def forward(self, z, labels):
        """Generate one image per (noise, label) pair.

        Args:
            z: Noise tensor, expected shape ``(batch, noise_dim)``.  It is
                concatenated with the label embedding along ``dim=1``.
            labels: Integer class indices in ``[0, num_classes)``, presumably
                shaped ``(batch,)``.  ``nn.Embedding`` requires integer
                indices.

        Returns:
            Tensor of shape ``(batch, *img_shape)`` with values in [-1, 1].
        """
        label_input = self.label_emb(labels)
        x = torch.cat([z, label_input], dim=1)
        x = self.model(x)
        # Instances pickled whole before ``img_shape`` existed lack the
        # attribute; they always produced 1x28x28 images.
        img_shape = getattr(self, "img_shape", (1, 28, 28))
        # -1 lets view infer the batch size from the flat MLP output.
        return x.view(-1, *img_shape)