"""Streamlit app that generates MNIST-style digit images with a conditional GAN.

The user picks a digit (0-9); on button press the app samples random noise,
feeds it together with the digit label to the trained ``Generator``, and shows
the resulting images side by side.
"""
import streamlit as st
import torch
import numpy as np
import matplotlib.pyplot as plt

from model import Generator

# Size of the noise vector fed to the generator. Presumably this matches the
# latent size used at training time — TODO confirm against model.Generator.
LATENT_DIM = 100
# How many images are produced per button press.
NUM_IMAGES = 5
CHECKPOINT_PATH = "generator_digit.pth"

device = torch.device("cpu")


@st.cache_resource
def load_generator(path: str = CHECKPOINT_PATH) -> Generator:
    """Build the Generator, load its weights from *path*, and switch to eval mode.

    Streamlit re-executes this whole script on every widget interaction, so
    the model is cached with ``st.cache_resource``; otherwise the checkpoint
    would be re-read from disk on every click.

    NOTE(review): ``torch.load`` unpickles the file. If the checkpoint could
    ever come from an untrusted source, pass ``weights_only=True``
    (torch >= 1.13) — confirm the installed torch version first.
    """
    model = Generator()
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()  # inference only: fixes dropout / batch-norm behaviour
    return model


generator = load_generator()

st.title("🧠 MNIST Digit Generator")

digit = st.selectbox("Select a digit (0-9):", list(range(10)))

if st.button(f"Generate {NUM_IMAGES} Images"):
    noise = torch.randn(NUM_IMAGES, LATENT_DIM, device=device)
    # One class label per sample; integer (long) labels, as the original
    # torch.tensor([digit] * 5) produced.
    labels = torch.full((NUM_IMAGES,), digit, dtype=torch.long, device=device)
    with torch.no_grad():
        generated = generator(noise, labels)

    # Move to host memory and convert explicitly instead of handing tensors
    # to matplotlib. Each sample is squeezed to 2-D below; presumably the
    # output is (N, 1, 28, 28) — TODO confirm against model.Generator.
    images = generated.cpu().numpy()

    # squeeze=False keeps `axs` a 2-D array even when NUM_IMAGES == 1.
    fig, axs = plt.subplots(1, NUM_IMAGES, figsize=(2 * NUM_IMAGES, 2), squeeze=False)
    for ax, img in zip(axs[0], images):
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; without this the
    # long-running Streamlit server accumulates one figure per click.
    plt.close(fig)