Spaces:
Sleeping
Sleeping
"""Streamlit app that samples digit images from a class-conditional generator.

The user picks a digit class (0-9) and presses a button. The app feeds random
noise plus the chosen label to ``Generator`` (from the local ``model`` module)
and shows the resulting images in a single row.
"""

import streamlit as st
import torch
import numpy as np
import matplotlib.pyplot as plt
from model import Generator

WEIGHTS_PATH = "generator_digit.pth"
# Length of the noise vector fed to the generator.
# NOTE(review): must match the latent size Generator was trained with
# (the original hard-coded 100) -- TODO confirm against model.py.
LATENT_DIM = 100
NUM_IMAGES = 5

device = torch.device("cpu")


@st.cache_resource
def load_generator(path: str = WEIGHTS_PATH) -> Generator:
    """Build the generator, load its trained weights, and switch it to eval mode.

    Cached with ``st.cache_resource`` so the weights are read once per server
    process. Without the cache they would be reloaded on every Streamlit rerun,
    because Streamlit re-executes this whole script on each widget interaction.

    Args:
        path: Location of the saved ``state_dict``.

    Returns:
        The ready-to-use generator on ``device``.
    """
    model = Generator()
    # A state_dict only contains tensors and plain containers, so restrict
    # unpickling to those (weights_only requires torch >= 1.13).
    state = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model


def generate_images(model, digit: int, count: int = NUM_IMAGES) -> torch.Tensor:
    """Sample ``count`` images conditioned on class ``digit``.

    Args:
        model: Generator called as ``model(noise, labels)``.
        digit: Class label to condition every sample on.
        count: Number of images to generate.

    Returns:
        Whatever ``model`` returns for the batch. Presumably this is one image
        per row; TODO confirm its shape against model.py.
    """
    noise = torch.randn(count, LATENT_DIM, device=device)
    # Same int64 label batch as torch.tensor([digit] * count).
    labels = torch.full((count,), digit, dtype=torch.long, device=device)
    with torch.no_grad():
        return model(noise, labels)


def plot_row(images: torch.Tensor):
    """Return a matplotlib figure showing ``images`` side by side in grayscale."""
    count = len(images)
    # squeeze=False keeps axs 2-D, so a single image still yields an iterable row.
    fig, axs = plt.subplots(1, count, figsize=(2 * count, 2), squeeze=False)
    for ax, img in zip(axs[0], images):
        ax.imshow(img.squeeze().cpu().numpy(), cmap="gray")
        ax.axis("off")
    return fig


generator = load_generator()

st.title("🧠 MNIST Digit Generator")
digit = st.selectbox("Select a digit (0-9):", list(range(10)))

if st.button("Generate 5 Images"):
    generated = generate_images(generator, digit)
    fig = plot_row(generated)
    st.pyplot(fig)
    # plt.subplots registers the figure in pyplot's global state. Close it
    # explicitly so repeated clicks don't accumulate open figures.
    plt.close(fig)