# FourierVision / app.py
# GenAIDevTOProd's picture
# Update app.py
# 2f23199 verified
import gradio as gr
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# CIFAR-10 class names; predict() indexes this list with the model's
# predicted class id, so the order must match the checkpoint's label order.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Build a ResNet18 with a head sized to the class list, then load the
# fine-tuned checkpoint.
# No ImageNet weights are requested here: load_state_dict() is strict by
# default and overwrites every parameter and buffer, so downloading pretrained
# weights first would only add startup time and a network dependency (and the
# `pretrained=` argument is deprecated in recent torchvision releases).
resnet18 = models.resnet18()
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, len(classes))  # Replace final layer
resnet18.load_state_dict(
    torch.load("resnet18_fft_cifar10.pth", map_location=torch.device('cpu'))
)
resnet18.eval()  # inference behavior for BatchNorm layers

# Image transform applied to every image before classification
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet18 expects 224x224
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# FFT Visualizer
def apply_fft_visualization(image: Image.Image):
    """Plot the log-magnitude 2-D FFT spectrum of each RGB channel.

    The image is converted to RGB, resized to 32x32 and scaled by 1/255.
    Each channel is transformed with ``np.fft.fft2``, shifted with
    ``np.fft.fftshift`` so the zero-frequency term is centered, and
    compressed with ``np.log1p`` of the magnitude.

    Args:
        image: Input PIL image. Any mode that ``Image.convert("RGB")``
            accepts is supported (grayscale, palette, RGBA, ...).

    Returns:
        A Matplotlib ``Figure`` with three side-by-side panels titled
        'Red', 'Green' and 'Blue', axes hidden.
    """
    # Local import keeps this block self-contained. A standalone Figure is not
    # registered with pyplot's global figure manager, so figures created per
    # request are garbage-collected instead of accumulating in a long-running
    # server process.
    from matplotlib.figure import Figure

    # Force three channels so non-RGB inputs can be indexed per channel below.
    pixels = np.asarray(image.convert("RGB").resize((32, 32))) / 255.0

    magnitudes = [
        np.log1p(np.abs(np.fft.fftshift(np.fft.fft2(pixels[:, :, channel]))))
        for channel in range(3)
    ]

    fig = Figure(figsize=(12, 4))
    axes = fig.subplots(1, 3)
    for ax, title, magnitude in zip(axes, ('Red', 'Green', 'Blue'), magnitudes):
        ax.imshow(magnitude, cmap='inferno')
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    return fig
# Prediction Function
def predict(img: Image.Image, mode="Raw"):
    """Classify an image or visualize its FFT, depending on ``mode``.

    Args:
        img: Uploaded PIL image, or ``None`` when nothing was uploaded.
        mode: ``"FFT"`` returns the per-channel FFT figure; any other value
            (default ``"Raw"``) runs the ResNet18 classifier.

    Returns:
        A ``(label, figure)`` pair matching the two interface outputs:
        ``(None, figure)`` in FFT mode, ``(class_name, None)`` in
        classification mode, and ``(None, None)`` when ``img`` is ``None``.
    """
    # Submitting without an upload passes None; return empty outputs instead
    # of failing inside resize/transform.
    if img is None:
        return None, None

    if mode == "FFT":
        return None, apply_fft_visualization(img)

    # Normalize uses three-channel statistics and the network expects three
    # input channels, so coerce grayscale/RGBA/palette images to RGB first.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = resnet18(batch)
    class_index = logits.argmax(dim=1).item()
    return classes[class_index], None
# Gradio App
# Input widgets: the uploaded image (delivered as a PIL image) and the mode.
_inputs = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Radio(["Raw", "FFT"], label="Mode", value="Raw"),
]

# Output widgets, in the order of the pair returned by predict().
_outputs = [
    gr.Label(label="Prediction"),
    gr.Plot(label="FFT Visualization"),
]

_description = "Upload an image and choose mode: Raw classification (ResNet18) or visualize FFT of RGB channels.\n\nDisclaimer: This model is trained on CIFAR-10 and works best on low-res, centered images."

# Build the interface and start serving immediately.
demo = gr.Interface(
    fn=predict,
    inputs=_inputs,
    outputs=_outputs,
    title="CIFAR-10 Visual Analyzer (ResNet18)",
    description=_description,
)
demo.launch()