GenAIDevTOProd committed on
Commit
0129c5b
·
verified ·
1 Parent(s): 63a4d81

Upload 2 files

Browse files

# CIFAR-10 Visual Analyzer 🔍📷

This app classifies uploaded images into one of the 10 CIFAR-10 classes using a fine-tuned ResNet-18 model, and optionally provides a Fourier Transform visualization of RGB channels.

## Model
- Base: ResNet-18
- Trained on: CIFAR-10 (with FFT exploration)
- Accuracy: ~78% (CPU-trained, 12 epochs)

## Features
- **Image Classification** into: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
- **FFT Mode**: Visualize frequency domain per RGB channel

## Usage
1. Upload any image
2. Choose between:
- **Raw**: Classify with the model
- **FFT**: Visualize RGB channel frequency maps

---

Built for educational and demonstrative purposes.

cifar10_fouriervision_gradioapp.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Cifar10-FourierVision-GradioApp.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1uw8cWaCxnSHf2CYhgeF_HYYdMeGP3odV
8
+ """
9
+
10
+ import gradio as gr
11
+ import torch
12
+ import torchvision.transforms as transforms
13
+ import torchvision.models as models
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+ from PIL import Image
17
+
18
# CIFAR-10 class names, indexed by the model's output logit position.
# NOTE(review): assumed to follow the standard CIFAR-10 label order used at
# training time -- confirm against the training notebook.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Checkpoint location. The default is the original Colab path; set the
# CIFAR10_MODEL_PATH environment variable to run the app anywhere else
# (the `or` also covers the variable being set to an empty string).
import os

MODEL_PATH = (
    os.environ.get("CIFAR10_MODEL_PATH")
    or "/content/sample_data/resnet18_fft_cifar10.pth"
)

# Build a randomly initialised ResNet-18 and replace its final layer with a
# 10-way classifier. `weights=None` is the current spelling of the deprecated
# `pretrained=False` argument.
resnet18 = models.resnet18(weights=None)
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, len(classes))

# torch.load unpickles the file, so MODEL_PATH must point at a trusted
# checkpoint. Tensors are mapped to CPU so no GPU is required.
resnet18.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device('cpu'))
)
resnet18.eval()  # inference mode: freezes dropout / batch-norm statistics
27
+
28
# Preprocessing applied to every uploaded image before it reaches the model.
# NOTE(review): presumably mirrors the training-time preprocessing -- confirm.
_preprocess_steps = [
    # Bring every upload to the fixed 224x224 input size used by this app.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5.
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocess_steps)
34
+
35
# FFT Visualizer
def apply_fft_visualization(image: Image.Image):
    """Plot the log-magnitude 2-D FFT spectrum of each RGB channel.

    The image is converted to RGB, resized to 32x32, scaled to [0, 1], and
    transformed with a 2-D FFT per colour channel. The zero-frequency
    component is shifted to the centre and the magnitude is compressed with
    ``log1p`` so that weak high-frequency content stays visible.

    Args:
        image: Input PIL image. Any mode is accepted; it is converted to RGB.

    Returns:
        A Matplotlib ``Figure`` with three side-by-side panels titled
        'Red', 'Green' and 'Blue'.
    """
    # Imported locally so this block is self-contained. A Figure built
    # directly (instead of via pyplot) is not kept alive by pyplot's global
    # figure registry, so repeated requests do not accumulate open figures.
    from matplotlib.figure import Figure

    # Force three channels so grayscale / palette uploads do not break the
    # per-channel indexing below; alpha, if present, is dropped.
    rgb = np.asarray(image.convert("RGB").resize((32, 32))) / 255.0

    # Transform all three channels in one call by restricting the FFT and the
    # centring shift to the two spatial axes.
    spectrum = np.fft.fftshift(np.fft.fft2(rgb, axes=(0, 1)), axes=(0, 1))
    magnitude = np.log1p(np.abs(spectrum))

    fig = Figure(figsize=(12, 4))
    panels = fig.subplots(1, 3)
    for idx, (panel, title) in enumerate(zip(panels, ('Red', 'Green', 'Blue'))):
        panel.imshow(magnitude[:, :, idx], cmap='inferno')
        panel.set_title(title)
        panel.axis('off')
    fig.tight_layout()
    return fig
53
+
54
# Prediction Function
def predict(img: Image.Image, mode="Raw"):
    """Classify an image or visualise its FFT, depending on ``mode``.

    Args:
        img: Uploaded PIL image. ``None`` means nothing was uploaded.
        mode: ``"FFT"`` returns the frequency-domain plot; any other value
            (default ``"Raw"``) runs the ResNet-18 classifier.

    Returns:
        A ``(label, figure)`` pair matching the two Gradio outputs:
        ``(None, figure)`` in FFT mode, ``(class_name, None)`` otherwise.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an opaque traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    if mode == "FFT":
        return None, apply_fft_visualization(img)

    # The normalisation and the ResNet stem both expect exactly three
    # channels, so coerce grayscale / RGBA inputs to RGB before preprocessing.
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = resnet18(batch)
    class_idx = int(logits.argmax(dim=1).item())
    return classes[class_idx], None
65
+
66
# Gradio App
# Input widgets: the uploaded picture (delivered to `predict` as a PIL image)
# and the mode selector.
image_input = gr.Image(type="pil", label="Upload Image")
mode_input = gr.Radio(["Raw", "FFT"], label="Mode", value="Raw")

# Output widgets, in the same order as the tuple returned by `predict`.
prediction_output = gr.Label(label="Prediction")
fft_output = gr.Plot(label="FFT Visualization")

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, mode_input],
    outputs=[prediction_output, fft_output],
    title="CIFAR-10 Visual Analyzer (ResNet18)",
    description="Upload an image and choose mode: Raw classification (ResNet18) or visualize FFT of RGB channels.\n\nDisclaimer: This model is trained on CIFAR-10 and works best on low-res, centered images.",
)
demo.launch()
80
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ matplotlib
5
+ numpy
6
+ Pillow