Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from model import Generator | |
# Load the trained generator once at import time and put it in inference
# mode (eval() disables dropout / uses running batch-norm statistics, if the
# architecture has such layers).
device = torch.device("cpu")
generator = Generator()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, and refuses to execute arbitrary pickled
# code. It is the default from PyTorch 2.6 on; the argument needs torch >= 1.13.
generator.load_state_dict(
    torch.load("generator_digit.pth", map_location=device, weights_only=True)
)
generator.to(device)
generator.eval()
def generate_images(digit, num_images=5):
    """Sample images of one digit class from the GAN and plot them in a row.

    Args:
        digit: Class label to condition the generator on. Gradio's Slider may
            deliver it as a float (e.g. ``3.0``), so it is cast to ``int``.
        num_images: Number of samples to draw; defaults to 5, matching the
            UI description.

    Returns:
        A ``matplotlib.figure.Figure`` with ``num_images`` grayscale panels.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    # Build the Figure directly instead of via pyplot so it is not registered
    # in pyplot's global figure manager; otherwise every request would leak a
    # figure in this long-running server process.
    from matplotlib.figure import Figure

    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    digit = int(digit)
    # Latent width 100 comes from the original code; presumably it must match
    # model.Generator's input size — confirm there.
    noise = torch.randn(num_images, 100, device=device)
    # Integer (long) labels: a float tensor would break if the Generator looks
    # labels up in an nn.Embedding (assumption — verify in model.Generator).
    labels = torch.full((num_images,), digit, dtype=torch.long, device=device)

    with torch.no_grad():
        output = generator(noise, labels)
    # squeeze(1) drops only a singleton channel axis, e.g. (N, 1, H, W) ->
    # (N, H, W); unlike squeeze() it never drops the batch axis when N == 1.
    # NOTE(review): assumes channel-first output — confirm against the model.
    images = output.squeeze(1).cpu().numpy()

    fig = Figure(figsize=(2 * num_images, 2))
    # squeeze=False keeps a 2-D axes array even when num_images == 1.
    axes = fig.subplots(1, num_images, squeeze=False)[0]
    for ax, image in zip(axes, images):
        ax.imshow(image, cmap="gray")
        ax.axis("off")
    return fig
# Gradio Interface using modern syntax: a single digit slider in, one
# matplotlib plot of generated samples out.
demo = gr.Interface(
    fn=generate_images,
    inputs=gr.Slider(0, 9, step=1, label="Digit (0–9)"),
    outputs=gr.Plot(label="Generated Images"),
    title="MNIST Digit Generator",
    description="Generates 5 handwritten images of the selected digit using a trained GAN.",
)

# Only start the server when run as a script (e.g. `python app.py`), so the
# module can be imported — for instance by Gradio's reload mode, which looks
# up `demo` itself — without launching a second server.
if __name__ == "__main__":
    demo.launch()