# Digitgenerator / app.py
# Last commit: 360fe9d ("Update app.py") by rohitkhadka
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from model import Generator
# Load the trained generator once, at import time, and switch it to inference
# mode so every request reuses the same weights.
device = torch.device("cpu")
generator = Generator()
# The checkpoint is a state_dict (it is passed to load_state_dict), so
# weights_only=True is sufficient and avoids arbitrary unpickling.
# NOTE(review): the weights_only argument requires PyTorch >= 1.13.
generator.load_state_dict(
    torch.load("generator_digit.pth", map_location=device, weights_only=True)
)
# Keep the module on the same device the weights were mapped to, so changing
# `device` above moves the whole model rather than only the loaded tensors.
generator.to(device)
# eval() switches layers such as dropout/batch-norm (if the model has any)
# to inference behaviour.
generator.eval()
def generate_images(digit):
    """Generate five samples of *digit* with the GAN and plot them in one row.

    Args:
        digit: Class label to condition the generator on. Gradio's Slider may
            deliver it as a float (e.g. ``3.0``), so it is converted to a
            Python ``int`` before building the label tensor.

    Returns:
        A matplotlib Figure holding one row of five grayscale images, returned
        to the ``gr.Plot`` output component.

    Raises:
        ValueError: If *digit* is not a whole number between 0 and 9.
    """
    num_images = 5
    # Latent size fed to the generator; presumably its training noise
    # dimension — confirm against model.py if the model changes.
    latent_dim = 100

    digit_value = int(digit)
    # Reject fractional or out-of-range labels instead of letting them fail
    # deep inside the model (or silently truncate).
    if digit_value != digit or not 0 <= digit_value <= 9:
        raise ValueError(f"digit must be a whole number from 0 to 9, got {digit!r}")

    noise = torch.randn(num_images, latent_dim, device=device)
    # Class labels are passed as int64, the dtype torch expects for index-style
    # inputs; building the tensor from a float Slider value would yield float32.
    labels = torch.full((num_images,), digit_value, dtype=torch.long, device=device)
    with torch.no_grad():
        # squeeze() drops singleton dims (e.g. a channel dim of size 1) so each
        # image can be passed to imshow; .cpu() is needed before .numpy() if
        # the model ever runs on an accelerator.
        images = generator(noise, labels).squeeze().cpu().numpy()

    fig, axs = plt.subplots(1, num_images, figsize=(10, 2))
    for ax, image in zip(axs, images):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
    # Unregister the figure from pyplot's global figure manager so repeated
    # requests do not accumulate open figures; the Figure object itself is
    # still returned and can be rendered/saved afterwards.
    plt.close(fig)
    return fig
# Gradio UI: a 0-9 slider selects the digit; the generated samples are shown
# as a matplotlib plot.
demo = gr.Interface(
    fn=generate_images,
    inputs=gr.Slider(0, 9, step=1, label="Digit (0–9)"),
    outputs=gr.Plot(label="Generated Images"),
    title="MNIST Digit Generator",
    description="Generates 5 handwritten images of the selected digit using a trained GAN.",
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module (tests, Gradio reload mode) does not block on launch().
if __name__ == "__main__":
    demo.launch()