rohitkhadka commited on
Commit
68e3ff0
·
verified ·
1 Parent(s): d5708ad

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +26 -0
  2. generator_digit.pth +3 -0
  3. model.py +23 -0
  4. requirements.txt +3 -3
app.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from model import Generator
6
+
7
+ # Load model
8
+ device = torch.device("cpu")
9
+ generator = Generator()
10
+ generator.load_state_dict(torch.load("generator_digit.pth", map_location=device))
11
+ generator.eval()
12
+
13
+ st.title("🧠 MNIST Digit Generator")
14
+ digit = st.selectbox("Select a digit (0-9):", list(range(10)))
15
+
16
+ if st.button("Generate 5 Images"):
17
+ noise = torch.randn(5, 100)
18
+ labels = torch.tensor([digit] * 5)
19
+ with torch.no_grad():
20
+ generated = generator(noise, labels)
21
+
22
+ fig, axs = plt.subplots(1, 5, figsize=(10, 2))
23
+ for i in range(5):
24
+ axs[i].imshow(generated[i].squeeze(), cmap='gray')
25
+ axs[i].axis('off')
26
+ st.pyplot(fig)
generator_digit.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72679173523a721a7ecf0499d0b3eda442158eddcb50c4989b4570ed5976ca0b
3
+ size 5959844
model.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Generator(nn.Module):
5
+ def __init__(self, noise_dim=100, num_classes=10, img_dim=28*28):
6
+ super(Generator, self).__init__()
7
+ self.label_emb = nn.Embedding(num_classes, num_classes)
8
+ self.model = nn.Sequential(
9
+ nn.Linear(noise_dim + num_classes, 256),
10
+ nn.ReLU(),
11
+ nn.Linear(256, 512),
12
+ nn.ReLU(),
13
+ nn.Linear(512, 1024),
14
+ nn.ReLU(),
15
+ nn.Linear(1024, img_dim),
16
+ nn.Tanh()
17
+ )
18
+
19
+ def forward(self, z, labels):
20
+ label_input = self.label_emb(labels)
21
+ x = torch.cat([z, label_input], dim=1)
22
+ x = self.model(x)
23
+ return x.view(-1, 1, 28, 28)
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- altair
2
- pandas
3
- streamlit
 
1
+ torch
2
+ streamlit
3
+ matplotlib