rohitkhadka commited on
Commit
7b1f8f1
·
verified ·
1 Parent(s): ed003b0

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +34 -0
  2. generator_digit.pth +3 -0
  3. model.py +23 -0
  4. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ from model import Generator
6
+
7
+ # Load the model
8
+ device = torch.device("cpu")
9
+ generator = Generator()
10
+ generator.load_state_dict(torch.load("generator_digit.pth", map_location=device))
11
+ generator.eval()
12
+
13
+ def generate_images(digit):
14
+ noise = torch.randn(5, 100)
15
+ labels = torch.tensor([digit] * 5)
16
+ with torch.no_grad():
17
+ images = generator(noise, labels).squeeze().numpy()
18
+
19
+ # Create a grid of images (5 horizontal)
20
+ fig, axs = plt.subplots(1, 5, figsize=(10, 2))
21
+ for i in range(5):
22
+ axs[i].imshow(images[i], cmap='gray')
23
+ axs[i].axis('off')
24
+ return fig
25
+
26
+ interface = gr.Interface(
27
+ fn=generate_images,
28
+ inputs=gr.inputs.Slider(0, 9, step=1, label="Digit (0–9)"),
29
+ outputs=gr.outputs.Image(type="plot"),
30
+ title="MNIST Digit Generator",
31
+ description="Generate 5 images of the digit using a conditional GAN trained on MNIST."
32
+ )
33
+
34
+ interface.launch()
generator_digit.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72679173523a721a7ecf0499d0b3eda442158eddcb50c4989b4570ed5976ca0b
3
+ size 5959844
model.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Generator(nn.Module):
5
+ def __init__(self, noise_dim=100, num_classes=10, img_dim=28*28):
6
+ super(Generator, self).__init__()
7
+ self.label_emb = nn.Embedding(num_classes, num_classes)
8
+ self.model = nn.Sequential(
9
+ nn.Linear(noise_dim + num_classes, 256),
10
+ nn.ReLU(),
11
+ nn.Linear(256, 512),
12
+ nn.ReLU(),
13
+ nn.Linear(512, 1024),
14
+ nn.ReLU(),
15
+ nn.Linear(1024, img_dim),
16
+ nn.Tanh()
17
+ )
18
+
19
+ def forward(self, z, labels):
20
+ label_input = self.label_emb(labels)
21
+ x = torch.cat([z, label_input], dim=1)
22
+ x = self.model(x)
23
+ return x.view(-1, 1, 28, 28)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ gradio
3
+ matplotlib