rohitkhadka commited on
Commit
360fe9d
·
verified ·
1 Parent(s): 7b1f8f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -34
app.py CHANGED
@@ -1,34 +1,35 @@
1
- import torch
2
- import numpy as np
3
- import gradio as gr
4
- import matplotlib.pyplot as plt
5
- from model import Generator
6
-
7
- # Load the model
8
- device = torch.device("cpu")
9
- generator = Generator()
10
- generator.load_state_dict(torch.load("generator_digit.pth", map_location=device))
11
- generator.eval()
12
-
13
- def generate_images(digit):
14
- noise = torch.randn(5, 100)
15
- labels = torch.tensor([digit] * 5)
16
- with torch.no_grad():
17
- images = generator(noise, labels).squeeze().numpy()
18
-
19
- # Create a grid of images (5 horizontal)
20
- fig, axs = plt.subplots(1, 5, figsize=(10, 2))
21
- for i in range(5):
22
- axs[i].imshow(images[i], cmap='gray')
23
- axs[i].axis('off')
24
- return fig
25
-
26
- interface = gr.Interface(
27
- fn=generate_images,
28
- inputs=gr.inputs.Slider(0, 9, step=1, label="Digit (0–9)"),
29
- outputs=gr.outputs.Image(type="plot"),
30
- title="MNIST Digit Generator",
31
- description="Generate 5 images of the digit using a conditional GAN trained on MNIST."
32
- )
33
-
34
- interface.launch()
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ from model import Generator
6
+
7
+ # Load the model
8
+ device = torch.device("cpu")
9
+ generator = Generator()
10
+ generator.load_state_dict(torch.load("generator_digit.pth", map_location=device))
11
+ generator.eval()
12
+
13
+ def generate_images(digit):
14
+ noise = torch.randn(5, 100)
15
+ labels = torch.tensor([digit] * 5)
16
+ with torch.no_grad():
17
+ images = generator(noise, labels).squeeze().numpy()
18
+
19
+ # Plot the 5 images in one figure
20
+ fig, axs = plt.subplots(1, 5, figsize=(10, 2))
21
+ for i in range(5):
22
+ axs[i].imshow(images[i], cmap='gray')
23
+ axs[i].axis('off')
24
+ return fig
25
+
26
+ # Gradio Interface using modern syntax
27
+ demo = gr.Interface(
28
+ fn=generate_images,
29
+ inputs=gr.Slider(0, 9, step=1, label="Digit (0–9)"),
30
+ outputs=gr.Plot(label="Generated Images"),
31
+ title="MNIST Digit Generator",
32
+ description="Generates 5 handwritten images of the selected digit using a trained GAN."
33
+ )
34
+
35
+ demo.launch()