Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
# CIFAR-10 class names. The list is indexed by the model's predicted class
# id, so the order must match the label order used during training.
classes = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
# Build a ResNet18 with a 10-way classification head and load the CIFAR-10
# checkpoint into it.
#
# The backbone is created WITHOUT ImageNet weights: load_state_dict() is
# strict by default, so the checkpoint must supply every parameter and buffer
# anyway. Requesting pretrained weights would only trigger a pointless
# download at startup (and fail when there is no network access).
resnet18 = models.resnet18()
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, 10)  # Replace final layer
# NOTE(review): torch.load unpickles the file. That is fine for a checkpoint
# shipped with the app, but never point this at an untrusted path; consider
# weights_only=True once the minimum supported torch version allows it.
resnet18.load_state_dict(
    torch.load("resnet18_fft_cifar10.pth", map_location=torch.device('cpu'))
)
# Inference mode: freezes batch-norm statistics and disables dropout.
resnet18.eval()
# Preprocessing applied to each uploaded image before it reaches the model.
_preprocess_steps = [
    # Stretch/shrink to the 224x224 input size used for ResNet18.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor in channel-first layout.
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5 normalization.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocess_steps)
| # FFT Visualizer | |
def apply_fft_visualization(image: Image.Image):
    """Plot the log-magnitude 2-D FFT spectrum of each RGB channel.

    The image is converted to RGB, resized to 32x32 and scaled to [0, 1].
    Each channel is then transformed with a 2-D FFT, shifted so the zero
    frequency sits in the centre, and displayed as ``log1p(|F|)``.

    Args:
        image: Input PIL image in any mode; it is converted to RGB first.

    Returns:
        A matplotlib ``Figure`` with three side-by-side panels titled
        Red, Green and Blue.
    """
    # Force exactly three channels. Grayscale or palette images would
    # otherwise produce a 2-D array and break the per-channel handling.
    rgb = image.convert("RGB").resize((32, 32))
    img_np = np.asarray(rgb) / 255.0

    # Transform over the two spatial axes only, for all channels at once.
    spectrum = np.fft.fftshift(np.fft.fft2(img_np, axes=(0, 1)), axes=(0, 1))
    # log1p compresses the large dynamic range of the magnitudes.
    magnitude = np.log1p(np.abs(spectrum))

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for idx, title in enumerate(('Red', 'Green', 'Blue')):
        axs[idx].imshow(magnitude[:, :, idx], cmap='inferno')
        axs[idx].set_title(title)
        axs[idx].axis('off')

    # Lay out this figure explicitly instead of pyplot's global "current
    # figure", which may belong to a concurrent request.
    fig.tight_layout()
    # Drop the figure from pyplot's registry so figures do not pile up in a
    # long-running server; the returned Figure object can still be rendered.
    plt.close(fig)
    return fig
| # Prediction Function | |
def predict(img: Image.Image, mode="Raw"):
    """Classify an image or visualize its FFT, depending on ``mode``.

    Args:
        img: Uploaded PIL image, or ``None`` when nothing was uploaded.
        mode: ``"FFT"`` returns the FFT visualization; any other value
            (default ``"Raw"``) runs the classifier.

    Returns:
        A ``(label, figure)`` pair matching the two output components.
        Exactly one entry is populated per mode; both are ``None`` when no
        image was provided.
    """
    # Submitting without an upload passes None; clear both outputs instead
    # of raising inside the handler.
    if img is None:
        return None, None

    # Normalize to three channels once, for both modes. RGBA or grayscale
    # input would not fit the 3-channel normalization and the model's first
    # convolution, nor the per-channel FFT plot.
    img = img.convert("RGB")

    if mode == "FFT":
        return None, apply_fft_visualization(img)

    # Add the batch dimension the model expects.
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        outputs = resnet18(img_tensor)
    predicted = outputs.argmax(dim=1)
    label = classes[predicted.item()]
    return label, None
# Gradio App: two inputs (image + mode selector) feed `predict`, whose
# (label, figure) result is shown in a Label and a Plot component.
image_input = gr.Image(type="pil", label="Upload Image")
mode_input = gr.Radio(["Raw", "FFT"], label="Mode", value="Raw")

label_output = gr.Label(label="Prediction")
plot_output = gr.Plot(label="FFT Visualization")

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, mode_input],
    outputs=[label_output, plot_output],
    title="CIFAR-10 Visual Analyzer (ResNet18)",
    description=(
        "Upload an image and choose mode: Raw classification (ResNet18) or "
        "visualize FFT of RGB channels.\n\n"
        "Disclaimer: This model is trained on CIFAR-10 and works best on "
        "low-res, centered images."
    ),
)
demo.launch()