DEIT / app.py
Godreign's picture
rollback to old UI
9b1dcad verified
raw
history blame
17.3 kB
import gradio as gr
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from efficientnet_pytorch import EfficientNet
import timm
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.transforms as T
import urllib.request
import json
import cv2
# ---------------------------
# Model Configs
# ---------------------------
# Display name -> {"type": loader backend, "id": checkpoint identifier}.
# Backends: "hf" (transformers AutoModelForImageClassification), "timm"
# (timm.create_model) and "efficientnet" (efficientnet_pytorch).
# Insertion order is the order shown in the model dropdown.
MODEL_CONFIGS = {
    display_name: {"type": backend, "id": checkpoint}
    for display_name, backend, checkpoint in (
        ("DeiT-Tiny", "hf", "facebook/deit-tiny-patch16-224"),
        ("DeiT-Small", "hf", "facebook/deit-small-patch16-224"),
        ("ViT-Base", "hf", "google/vit-base-patch16-224"),
        ("ConvNeXt-Tiny", "timm", "convnext_tiny"),
        ("ConvNeXt-Nano", "timm", "convnext_nano"),
        ("EfficientNet-B0", "efficientnet", "efficientnet-b0"),
        ("EfficientNet-B1", "efficientnet", "efficientnet-b1"),
        ("ResNet-50", "timm", "resnet50"),
        ("MobileNet-V2", "timm", "mobilenetv2_100"),
        ("MaxViT-Tiny", "timm", "maxvit_tiny_tf_224"),
        ("MobileViT-Small", "timm", "mobilevit_s"),
        ("EdgeNeXt-Small", "timm", "edgenext_small"),
        ("RegNetY-002", "timm", "regnety_002"),
    )
}
# ---------------------------
# ImageNet Labels
# ---------------------------
# Human-readable ImageNet class names; list position == class index. Used to
# label predictions from the non-HF (timm / EfficientNet) models.
IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# This runs at import time: without a timeout, a stalled connection to GitHub
# would block application start-up indefinitely instead of failing.
with urllib.request.urlopen(IMAGENET_URL, timeout=30) as response:
    IMAGENET_LABELS = json.load(response)
# ---------------------------
# Lazy Load
# ---------------------------
# Process-wide cache: model_name -> (model, extractor). Each model is built at
# most once.
loaded_models = {}


def load_model(model_name):
    """Return ``(model, extractor)`` for ``model_name``, loading it on first use.

    The model is created according to ``MODEL_CONFIGS[model_name]``, put in
    eval mode, and every parameter gets ``requires_grad = True`` so the
    gradient-based attention maps can back-propagate through it. The result
    is cached in ``loaded_models``.

    Args:
        model_name: key of ``MODEL_CONFIGS``.

    Returns:
        ``(model, extractor)``. ``extractor`` is the Hugging Face feature
        extractor for ``"hf"`` models and ``None`` for all other backends.

    Raises:
        KeyError: ``model_name`` is not in ``MODEL_CONFIGS``.
        ValueError: the config names an unsupported backend ``"type"``.
    """
    cached = loaded_models.get(model_name)
    if cached is not None:
        return cached

    config = MODEL_CONFIGS[model_name]
    backend, checkpoint = config["type"], config["id"]
    extractor = None
    if backend == "hf":
        extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
        # Attention weights are needed for the ViT attention maps.
        model = AutoModelForImageClassification.from_pretrained(
            checkpoint, output_attentions=True
        )
    elif backend == "timm":
        model = timm.create_model(checkpoint, pretrained=True)
    elif backend == "efficientnet":
        model = EfficientNet.from_pretrained(checkpoint)
    else:
        # Without this, an unknown type would surface as an UnboundLocalError
        # on `model` below.
        raise ValueError(f"Unsupported model type {backend!r} for {model_name!r}")

    model.eval()
    # Enable gradients for class-specific attention (shared by all backends).
    for param in model.parameters():
        param.requires_grad = True

    loaded_models[model_name] = (model, extractor)
    return model, extractor
# ---------------------------
# Adversarial Noise
# ---------------------------
def add_adversarial_noise(image, epsilon, rng=None):
    """Return a copy of ``image`` with zero-mean Gaussian pixel noise added.

    Despite the name, this is random (not gradient-based) noise. Channel
    values are scaled to [0, 1], perturbed by ``epsilon * N(0, 1)``, clipped
    back to [0, 1] and rounded to 8-bit.

    Args:
        image: PIL image. Modes other than L/RGB/RGBA (e.g. palette "P") are
            converted to RGB first, so noise perturbs colours rather than
            palette indices.
        epsilon: noise standard deviation in [0, 1] intensity units.
        rng: optional ``numpy.random.Generator`` for reproducible noise. When
            omitted, the global NumPy RNG is used (the default).

    Returns:
        A new PIL image with the same size as ``image``.
    """
    if image.mode not in ("L", "RGB", "RGBA"):
        image = image.convert("RGB")
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    if rng is None:
        noise = np.random.randn(*pixels.shape)
    else:
        noise = rng.standard_normal(pixels.shape)
    noisy = np.clip(pixels + noise * epsilon, 0.0, 1.0)
    # Round rather than truncate. A bare astype() floors, which biases every
    # pixel downward and can knock an unperturbed value down by one level
    # through float round-off.
    return Image.fromarray(np.rint(noisy * 255.0).astype(np.uint8))
# ---------------------------
# Grad-CAM for Class-Specific Attention
# ---------------------------
def get_gradcam_for_class(model, image_tensor, class_idx=None):
    """Compute a Grad-CAM heat-map of ``model`` for class ``class_idx``.

    The target layer is the last ``torch.nn.Conv2d`` in module registration
    order. This is assumed, not verified, to be the last conv executed. The
    function back-propagates logit ``[0, class_idx]``, weights that layer's
    feature maps by their spatially averaged gradients, applies ReLU and
    min-max normalises the result.

    Args:
        model: classifier whose output is indexed as ``out[0, class]``
            (presumably ``(batch, num_classes)`` logits).
        image_tensor: input for ``model``. Only batch item 0 is explained.
        class_idx: class to explain. ``None`` explains the top-scoring class.

    Returns:
        ``np.ndarray`` of shape ``(224, 224)`` scaled to [0, 1], or an
        all-ones map if the model has no Conv2d layer or the hooks did not fire.
    """
    target_layer = next(
        (m for m in reversed(list(model.modules())) if isinstance(m, torch.nn.Conv2d)),
        None,
    )
    if target_layer is None:
        return np.ones((224, 224))

    captured = {}

    def save_activation(module, inputs, output):
        captured["fmap"] = output.detach()

    def save_gradient(module, grad_input, grad_output):
        captured["grad"] = grad_output[0].detach()

    handles = [
        target_layer.register_forward_hook(save_activation),
        target_layer.register_full_backward_hook(save_gradient),
    ]
    try:
        logits = model(image_tensor)
        if class_idx is None:
            class_idx = logits[0].argmax().item()
        model.zero_grad()
        logits[0, class_idx].backward()
    finally:
        # Models are cached across requests. Hooks leaked by a failed call
        # would keep firing on every later pass, so always remove them.
        for handle in handles:
            handle.remove()
        # Free parameter gradients so the cached model does not hold them.
        model.zero_grad(set_to_none=True)

    fmap, grad = captured.get("fmap"), captured.get("grad")
    if fmap is None or grad is None:
        return np.ones((224, 224))

    weights = grad.mean(dim=(2, 3), keepdim=True)
    cam = F.relu((weights * fmap).sum(dim=1, keepdim=True))
    cam = cam.squeeze().cpu().numpy()
    cam = cv2.resize(cam, (224, 224))
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
# ---------------------------
# ViT Attention for Class-Specific
# ---------------------------
def vit_attention_for_class(model, extractor, image, class_idx):
    """Return a gradient-weighted CLS attention map of a ViT for ``class_idx``.

    The model is run on ``extractor(images=image, return_tensors="pt")``.
    The logit ``[0, class_idx]`` is back-propagated, and the last layer's
    attention probabilities are multiplied by their gradients. The positive
    part is averaged over heads, and the CLS-token row over the patch tokens
    becomes the map. Because the gradients depend on the class, so does the
    map. If no usable gradient is available (none recorded, or no positive
    contribution), the plain head-averaged CLS attention is used instead.

    Patch tokens are taken to be the last ``side * side`` tokens, with
    ``side = isqrt(num_tokens - 1)``. A 224px / patch-16 model with one CLS
    token (197 tokens) therefore gives a 14x14 map. A model with an extra
    leading token, e.g. a distillation token, still gets its patch grid.

    Args:
        model: called as ``model(**inputs)``. Its output must expose
            ``logits`` and, for a real map, ``attentions`` (a sequence of
            tensors, assumed ``(batch, heads, tokens, tokens)``).
        extractor: callable producing the model inputs as PyTorch tensors.
        image: image accepted by ``extractor``.
        class_idx: logit index to explain.

    Returns:
        ``np.ndarray`` of shape ``(side, side)`` scaled to [0, 1], or
        ``np.ones((14, 14))`` when the model returned no attentions.
    """
    import math

    inputs = extractor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    attentions = getattr(outputs, "attentions", None)
    if not attentions:
        return np.ones((14, 14))

    last_attn = attentions[-1]
    if last_attn.requires_grad:
        # Non-leaf tensor: autograd discards its .grad unless asked to keep it.
        last_attn.retain_grad()
    model.zero_grad()
    try:
        outputs.logits[0, class_idx].backward()
    finally:
        # Free parameter gradients so the cached model does not hold them.
        model.zero_grad(set_to_none=True)

    side = math.isqrt(last_attn.shape[-1] - 1)
    if side == 0:
        return np.ones((14, 14))
    n_patches = side * side

    attn = last_attn.detach()[0]  # (heads, tokens, tokens) for batch item 0
    cls_attn = attn.mean(0)[0, -n_patches:]
    grad = last_attn.grad
    if grad is not None:
        weighted = (grad[0] * attn).clamp(min=0).mean(0)[0, -n_patches:]
        if weighted.max() > 0:
            cls_attn = weighted

    attn_map = cls_attn.reshape(side, side).cpu().numpy()
    return (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
# ---------------------------
# Grad-CAM Helper for CNNs (Top Prediction)
# ---------------------------
def get_gradcam(model, image_tensor):
    """Compute a Grad-CAM heat-map for the top-scoring class of ``model``.

    The class explained is the argmax of batch item 0's logits from this
    function's own forward pass. The target layer is the last
    ``torch.nn.Conv2d`` in module registration order (assumed to be the last
    conv executed). Its feature maps are weighted by spatially averaged
    gradients, passed through ReLU and min-max normalised.

    Args:
        model: classifier whose output is indexed as ``out[0, class]``
            (presumably ``(batch, num_classes)`` logits).
        image_tensor: input for ``model``. Only batch item 0 is explained.

    Returns:
        ``np.ndarray`` of shape ``(224, 224)`` scaled to [0, 1], or an
        all-ones map if the model has no Conv2d layer or the hooks did not fire.
    """
    last_conv = None
    for module in reversed(list(model.modules())):
        if isinstance(module, torch.nn.Conv2d):
            last_conv = module
            break
    if last_conv is None:
        return np.ones((224, 224))

    state = {"fmap": None, "grad": None}

    def forward_hook(module, inputs, output):
        state["fmap"] = output.detach()

    def backward_hook(module, grad_in, grad_out):
        state["grad"] = grad_out[0].detach()

    handle_fwd = last_conv.register_forward_hook(forward_hook)
    handle_bwd = last_conv.register_full_backward_hook(backward_hook)
    try:
        out = model(image_tensor)
        class_idx = out[0].argmax().item()
        model.zero_grad()
        out[0, class_idx].backward()
    finally:
        # Models are cached across requests. Hooks leaked by a failed call
        # would keep firing on every later pass, so always remove them.
        handle_fwd.remove()
        handle_bwd.remove()
        # Free parameter gradients so the cached model does not hold them.
        model.zero_grad(set_to_none=True)

    if state["grad"] is None or state["fmap"] is None:
        return np.ones((224, 224))

    weights = state["grad"].mean(dim=(2, 3), keepdim=True)
    cam = F.relu((weights * state["fmap"]).sum(dim=1, keepdim=True))
    cam = cam.squeeze().cpu().numpy()
    cam = cv2.resize(cam, (224, 224))
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
# ---------------------------
# ViT Attention Rollout Helper
# ---------------------------
def vit_attention_rollout(outputs):
    """Return the last layer's head-averaged CLS attention as a square map.

    Note: despite the name, this is not full attention rollout across layers.
    It only reads ``outputs.attentions[-1]``.

    Patch tokens are taken to be the last ``side * side`` tokens, with
    ``side = isqrt(num_tokens - 1)``. For example, 197 tokens give 14x14 and
    50 tokens give 7x7. A model with an extra leading token (e.g. a
    distillation token) still gets its patch grid, rather than failing in a
    fixed-size reshape.

    Args:
        outputs: model output exposing ``attentions``, a sequence of tensors
            assumed shaped ``(batch, heads, tokens, tokens)``. Only batch
            item 0 is used.

    Returns:
        ``np.ndarray`` of shape ``(side, side)`` scaled to [0, 1], or
        ``np.ones((14, 14))`` when no attentions are available.
    """
    import math

    attentions = getattr(outputs, "attentions", None)
    if not attentions:
        return np.ones((14, 14))

    last = attentions[-1].detach()
    side = math.isqrt(last.shape[-1] - 1)
    if side == 0:
        return np.ones((14, 14))

    cls_row = last[0].mean(0)[0, -side * side:]
    attn_map = cls_row.reshape(side, side).cpu().numpy()
    return (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
# ---------------------------
# Create Gradient Legend
# ---------------------------
def create_gradient_legend(width=400, height=60):
    """Render a JET colour bar labelled "Low Attention" / "High Attention".

    Column ``i`` gets colormap value ``floor(255 * i / width)`` from
    ``cv2.COLORMAP_JET`` (blue for low on the left, red for high on the
    right). This is the colormap the attention overlays use.

    Args:
        width: legend width in pixels (default 400).
        height: legend height in pixels (default 60).

    Returns:
        RGB PIL image of size ``(width, height)``.
    """
    from PIL import ImageDraw, ImageFont

    # One applyColorMap call on a single row replaces the per-column loop.
    ramp = (np.arange(width) * 255 // width).astype(np.uint8)[np.newaxis, :]
    row_bgr = cv2.applyColorMap(ramp, cv2.COLORMAP_JET)  # shape (1, width, 3), BGR
    bar_rgb = cv2.cvtColor(np.repeat(row_bgr, height, axis=0), cv2.COLOR_BGR2RGB)

    legend = Image.fromarray(bar_rgb)
    draw = ImageDraw.Draw(legend)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
    except (OSError, ImportError):
        # Font file absent, or Pillow built without FreeType: use the default font.
        font = ImageFont.load_default()
    draw.text((10, 20), "Low Attention", fill=(255, 255, 255), font=font)
    draw.text((width - 120, 20), "High Attention", fill=(255, 255, 255), font=font)
    return legend
def overlay_attention(pil_img, attention_map, alpha=0.4):
    """Blend a JET-coloured attention heat-map over ``pil_img``.

    Args:
        pil_img: PIL image to annotate. It is converted to RGB for blending.
        attention_map: 2-D array of scores expected in [0, 1]. It is resized
            to the image size, so any resolution works. Out-of-range values
            are clipped so they cannot wrap around in the uint8 cast.
        alpha: heat-map weight in the blend (0 = image only, 1 = heat-map only).

    Returns:
        A new RGB PIL image with the same size as ``pil_img``.
    """
    heatmap = (np.clip(attention_map, 0.0, 1.0) * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # PIL's size is (width, height), which is also cv2.resize's dsize order.
    heatmap = cv2.resize(heatmap, pil_img.size)
    return Image.blend(pil_img.convert("RGB"), Image.fromarray(heatmap), alpha=alpha)
# ---------------------------
# Main Prediction Function
# ---------------------------
def predict(image, model_name, noise_level):
    """Classify ``image`` with ``model_name`` and visualise where it looks.

    Args:
        image: PIL image from the upload widget, or ``None``.
        model_name: key of ``MODEL_CONFIGS``, or ``None``.
        noise_level: std-dev of the random noise added before inference
            (0 or falsy = no noise).

    Returns:
        ``(label, overlay, processed_image)``:
        - ``label``: a ``{class_name: probability}`` dict of the top-5 classes.
        - ``overlay``: the attention map (ViT) or Grad-CAM (CNN) blended onto
          the processed image.
        - ``processed_image``: the RGB (and possibly noisy) image actually
          classified.
        On failure, ``label`` is an ``"Error: ..."`` string and both images
        are ``None``. ``gr.Label`` can show a plain string, but a dict value
        must be a numeric confidence.
    """
    import traceback

    if image is None:
        return "Error: Please upload an image", None, None
    if model_name is None:
        return "Error: Please select a model", None, None
    try:
        # Uploads may be RGBA / greyscale / palette. The models expect 3 channels.
        image = image.convert("RGB")
        if noise_level and noise_level > 0:
            image = add_adversarial_noise(image, noise_level)
        model, extractor = load_model(model_name)

        if MODEL_CONFIGS[model_name]["type"] == "hf":
            with torch.no_grad():
                inputs = extractor(images=image, return_tensors="pt")
                outputs = model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)[0]
                top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [model.config.id2label[i.item()] for i in top5_idx]
            att_map = vit_attention_rollout(outputs)
        else:
            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])
            x = transform(image).unsqueeze(0)
            with torch.no_grad():
                probs = F.softmax(model(x), dim=-1)[0]
                top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [IMAGENET_LABELS[i.item()] for i in top5_idx]
            # Grad-CAM runs its own forward/backward pass through this input.
            x.requires_grad_(True)
            att_map = get_gradcam(model, x)

        overlay = overlay_attention(image, att_map)
        result = {label: float(prob) for label, prob in zip(top5_labels, top5_prob)}
        return result, overlay, image
    except Exception as e:
        print(f"Model '{model_name}' failed: {str(e)}\n{traceback.format_exc()}")
        return f"Error: {e}", None, None
# ---------------------------
# Class-Specific Attention
# ---------------------------
def _find_class_index(indexed_labels, query):
    """Return ``(index, label)`` of the best match for ``query``, else ``(None, None)``.

    ``query`` must already be lower-cased and stripped. Each label is split
    on commas into synonyms. Matches are ranked:
    1. exact synonym match;
    2. ``query`` as a whole word inside a synonym;
    3. plain substring of the label.
    Ties go to the lowest index. A query like "tiger" therefore resolves to
    the "tiger" class, not to whichever earlier label merely contains the
    text (such as "tiger shark").
    """
    import re

    whole_word = re.compile(rf"\b{re.escape(query)}\b")
    best = None  # (rank, index, label)
    for idx, label in indexed_labels:
        lowered = label.lower()
        synonyms = [s.strip() for s in lowered.split(",")]
        if query in synonyms:
            rank = 0
        elif any(whole_word.search(s) for s in synonyms):
            rank = 1
        elif query in lowered:
            rank = 2
        else:
            continue
        if best is None or (rank, idx) < best[:2]:
            best = (rank, idx, label)
    return (None, None) if best is None else (best[1], best[2])


def get_class_specific_attention(image, model_name, class_query):
    """Visualise where ``model_name`` looks for the class named ``class_query``.

    The query is matched against the model's own labels (``"hf"`` models) or
    ``IMAGENET_LABELS`` (all other backends). ViT models use
    ``vit_attention_for_class``; the others use ``get_gradcam_for_class``.

    Args:
        image: PIL image, or ``None``.
        model_name: key of ``MODEL_CONFIGS``, or ``None``.
        class_query: free-text class name (case-insensitive).

    Returns:
        ``(overlay, legend, status)``. ``overlay`` and ``legend`` are PIL
        images, or ``None`` on any validation failure or error. ``status`` is
        a message for the user.
    """
    import traceback

    try:
        if image is None:
            return None, None, "Please upload an image first"
        if not class_query or class_query.strip() == "":
            return None, None, "Please enter a class name"
        if model_name is None:
            return None, None, "Please select a model"

        query = class_query.lower().strip()
        model, extractor = load_model(model_name)
        is_hf = MODEL_CONFIGS[model_name]["type"] == "hf"

        if is_hf:
            matching_idx, matched_label = _find_class_index(model.config.id2label.items(), query)
            if matching_idx is None:
                return None, None, f"Class '{class_query}' not found in model labels. Try a different class name or check sample classes."
        else:
            matching_idx, matched_label = _find_class_index(enumerate(IMAGENET_LABELS), query)
            if matching_idx is None:
                return None, None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check sample classes."

        # Uploads may be RGBA / greyscale / palette. The models expect 3 channels.
        image = image.convert("RGB")
        if is_hf:
            att_map = vit_attention_for_class(model, extractor, image, matching_idx)
        else:
            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])
            x = transform(image).unsqueeze(0)
            x.requires_grad = True
            att_map = get_gradcam_for_class(model, x, matching_idx)

        overlay = overlay_attention(image, att_map)
        legend = create_gradient_legend()
        return overlay, legend, f"✓ Attention map generated for class: '{matched_label}' (Index: {matching_idx})"
    except Exception as e:
        print(traceback.format_exc())
        return None, None, f"Error generating attention map: {str(e)}"
# ---------------------------
# Sample Classes
# ---------------------------
# Quick-pick queries for the class-specific attention panel. Each entry should
# name (part of) an ImageNet-1k label. Generic words with no ImageNet class,
# such as "person" or "building", could only produce "not found". Words like
# "tree" could only hit unrelated labels by substring (e.g. "tree frog").
SAMPLE_CLASSES = [
    "cat", "dog", "tiger", "lion", "elephant",
    "car", "truck", "airliner", "ship", "train",
    "pizza", "cheeseburger", "coffee", "banana", "strawberry",
    "chair", "table", "laptop", "keyboard", "mouse",
    "zebra", "bicycle", "castle", "volcano", "daisy",
]
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Enhanced Multi-Model Image Classifier")
    gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 13+ Models")

    # Top-prediction panel: upload, model choice, noise, and results.
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="📸 Upload Image")
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CONFIGS.keys()),
                label="🤖 Select Model",
                value="DeiT-Tiny",
            )
            gr.Markdown("### 🎭 Adversarial Noise")
            noise_slider = gr.Slider(
                minimum=0,
                maximum=0.3,
                value=0,
                step=0.01,
                label="Noise Level (ε)",
                info="Add random noise to test model robustness",
            )
            run_button = gr.Button("🚀 Run Model", variant="primary")
        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="🎯 Top 5 Predictions")
            output_image = gr.Image(label="🔍 Attention Map (Top Prediction)")
            processed_image = gr.Image(
                label="🖼️ Processed Image (with noise if applied)", visible=False
            )

    gr.Markdown("---")
    gr.Markdown("### 🎨 Class-Specific Attention Visualization")
    gr.Markdown("*Type any class name to see where the model looks for that specific object*")

    # Class-specific attention panel.
    with gr.Row():
        with gr.Column(scale=1):
            class_input = gr.Textbox(
                label="🔍 Enter Class Name",
                placeholder="e.g., cat, dog, car, pizza...",
                info="Type any ImageNet class name",
            )
            class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")
            gr.Markdown("**💡 Sample classes to try:**")
            sample_buttons = gr.Radio(
                choices=SAMPLE_CLASSES,
                label="Click to auto-fill",
                interactive=True,
            )
        with gr.Column(scale=2):
            class_output_image = gr.Image(label="🔍 Class-Specific Attention Map")
            # Shows the colour-scale legend returned by get_class_specific_attention.
            # The click handler below lists it as an output, so it must exist.
            gradient_legend = gr.Image(label="Attention Scale", interactive=False)
            class_status = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("---")
    gr.Markdown(
        "### 💡 Tips:\n"
        "- **Adversarial Noise**: Adjust the slider to add random noise and see how robust the model is\n"
        "- **Class-Specific Attention**: Type any ImageNet class to visualize what the model looks for "
        "(e.g., \"tiger\", \"sports car\", \"pizza\")\n"
        "- **Model Variety**: Try different architectures (ViT, CNN, Hybrid) to compare their behavior\n"
    )

    # Event handlers
    run_button.click(
        predict,
        inputs=[input_image, model_dropdown, noise_slider],
        outputs=[output_label, output_image, processed_image],
    )
    # Selecting a sample class copies it into the class-name textbox.
    sample_buttons.change(
        lambda choice: choice,
        inputs=[sample_buttons],
        outputs=[class_input],
    )
    # Generate the class-specific attention map and its legend.
    class_button.click(
        get_class_specific_attention,
        inputs=[input_image, model_dropdown, class_input],
        outputs=[class_output_image, gradient_legend, class_status],
    )

if __name__ == "__main__":
    demo.launch()