# DEIT / app.py
# Godreign's picture
# confidence value added
# aa41700 verified
import gradio as gr
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from efficientnet_pytorch import EfficientNet
import timm
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.transforms as T
import urllib.request
import json
import cv2
# ---------------------------
# Model Configs
# ---------------------------
# Selectable backbones: display name -> {"type": loader family, "id": checkpoint}.
# "type" picks the branch in load_model: "hf" (Hugging Face hub), "timm",
# or "efficientnet" (efficientnet_pytorch). Order here is the dropdown order.
MODEL_CONFIGS = {
    display_name: {"type": family, "id": checkpoint}
    for display_name, family, checkpoint in (
        ("DeiT-Tiny", "hf", "facebook/deit-tiny-patch16-224"),
        ("DeiT-Small", "hf", "facebook/deit-small-patch16-224"),
        ("ViT-Base", "hf", "google/vit-base-patch16-224"),
        ("ConvNeXt-Tiny", "timm", "convnext_tiny"),
        ("ConvNeXt-Nano", "timm", "convnext_nano"),
        ("EfficientNet-B0", "efficientnet", "efficientnet-b0"),
        ("EfficientNet-B1", "efficientnet", "efficientnet-b1"),
        ("ResNet-50", "timm", "resnet50"),
        ("MobileNet-V2", "timm", "mobilenetv2_100"),
        ("MaxViT-Tiny", "timm", "maxvit_tiny_tf_224"),
        ("MobileViT-Small", "timm", "mobilevit_s"),
        ("EdgeNeXt-Small", "timm", "edgenext_small"),
        ("RegNetY-002", "timm", "regnety_002"),
    )
}
# ---------------------------
# ImageNet Labels
# ---------------------------
IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# Human-readable ImageNet class names, indexed by class id; fetched once at
# import time. The explicit timeout keeps a stalled connection from hanging
# app start-up: without it urlopen falls back to the global socket default,
# which is "no timeout" unless something else configured it.
with urllib.request.urlopen(IMAGENET_URL, timeout=30) as url:
    IMAGENET_LABELS = json.load(url)
# ---------------------------
# Lazy Load
# ---------------------------
# Process-wide cache: display name -> (model, extractor). Models are shared by
# every request once loaded.
loaded_models = {}


def load_model(model_name):
    """Return ``(model, extractor)`` for ``model_name``, loading it on first use.

    The checkpoint is chosen from ``MODEL_CONFIGS[model_name]``. ``extractor`` is
    the Hugging Face feature extractor for ``"hf"`` entries and ``None`` for the
    ``"timm"`` / ``"efficientnet"`` families. The model is put in eval mode and
    cached in ``loaded_models`` so later calls return the same objects.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODEL_CONFIGS``.
        ValueError: if the config's ``"type"`` is not a supported loader family.
    """
    if model_name in loaded_models:
        return loaded_models[model_name]

    config = MODEL_CONFIGS[model_name]
    model_type, model_id = config["type"], config["id"]
    extractor = None
    if model_type == "hf":
        extractor = AutoFeatureExtractor.from_pretrained(model_id)
        # output_attentions=True so forward passes also return per-layer attention.
        model = AutoModelForImageClassification.from_pretrained(model_id, output_attentions=True)
    elif model_type == "timm":
        model = timm.create_model(model_id, pretrained=True)
    elif model_type == "efficientnet":
        model = EfficientNet.from_pretrained(model_id)
    else:
        # Fail loudly instead of reaching the code below with `model` unbound.
        raise ValueError(f"Unsupported model type {model_type!r} for model {model_name!r}")

    model.eval()
    # Keep parameters differentiable; the attention / Grad-CAM helpers
    # differentiate through the model.
    for param in model.parameters():
        param.requires_grad = True

    loaded_models[model_name] = (model, extractor)
    return model, extractor
# ---------------------------
# Adversarial Noise
# ---------------------------
def add_adversarial_noise(image, epsilon):
    """Return a copy of ``image`` with i.i.d. Gaussian pixel noise added.

    Despite the name this is *random* noise (no gradients are involved); it is
    used to probe robustness. Pixels are scaled to [0, 1], perturbed by
    ``N(0, epsilon**2)`` noise drawn from the global ``np.random`` state,
    clipped back to [0, 1] and re-quantized to uint8.

    Args:
        image: PIL image; assumed 8-bit per channel (TODO confirm for other modes).
        epsilon: standard deviation of the noise, in [0, 1] intensity units.

    Returns:
        A new PIL image created with ``Image.fromarray``.
    """
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    noise = np.random.randn(*pixels.shape) * epsilon
    noisy = np.clip(pixels + noise, 0.0, 1.0)
    # Round instead of truncating: astype(uint8) floors, so float error that
    # lands just below an integer would darken that pixel by one level.
    return Image.fromarray(np.rint(noisy * 255.0).astype(np.uint8))
# ---------------------------
# Grad-CAM for Class-Specific Attention
# ---------------------------
def get_gradcam_for_class(model, image_tensor, class_idx):
grad = None
fmap = None
def forward_hook(module, input, output):
nonlocal fmap
fmap = output.detach()
def backward_hook(module, grad_in, grad_out):
nonlocal grad
grad = grad_out[0].detach()
last_conv = None
for name, module in reversed(list(model.named_modules())):
if isinstance(module, torch.nn.Conv2d):
last_conv = module
break
if last_conv is None:
return np.ones((224, 224))
handle_fwd = last_conv.register_forward_hook(forward_hook)
handle_bwd = last_conv.register_full_backward_hook(backward_hook)
out = model(image_tensor)
score = out[0, class_idx]
model.zero_grad()
score.backward()
handle_fwd.remove()
handle_bwd.remove()
if grad is None or fmap is None:
return np.ones((224, 224))
weights = grad.mean(dim=(2, 3), keepdim=True)
cam = (weights * fmap).sum(dim=1, keepdim=True)
cam = F.relu(cam)
cam = cam.squeeze().cpu().numpy()
cam = cv2.resize(cam, (224, 224))
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
return cam
# ---------------------------
# ViT Attention for Class-Specific
# ---------------------------
def vit_attention_for_class(model, extractor, image, class_idx):
    """Return a class-specific attention map and the class probability for a ViT.

    The map is gradient-weighted attention from the last layer:
    - the last layer's attention is multiplied elementwise by
      d(logit[class_idx]) / d(attention);
    - negative contributions are clipped (ReLU) and heads are averaged;
    - the CLS row over the patch tokens is reshaped to a square grid and
      min-max normalized to [0, 1].

    The gradient weighting is what makes the map depend on ``class_idx``.
    Raw attention is the same for every class.

    Args:
        model: HF image-classification model; its output is expected to have
            ``logits`` and, when built with ``output_attentions=True``,
            ``attentions`` (assumed (batch, heads, tokens, tokens) — TODO confirm).
        extractor: callable returning a mapping with ``"pixel_values"``.
        image: image accepted by ``extractor``.
        class_idx: class index to explain.

    Returns:
        ``(attn_map, prob)``. ``attn_map`` is a square ``np.ndarray``
        (side = floor(sqrt(#non-CLS tokens)), 14 for 224px/16px patches).
        ``prob`` is the softmax probability of ``class_idx``. Without
        attentions the map is ``np.ones((14, 14))``, but ``prob`` is still real.
    """
    import math

    inputs = extractor(images=image, return_tensors="pt")
    # Make the input a graph leaf so attention stays differentiable even if
    # the parameters were frozen.
    inputs["pixel_values"].requires_grad_(True)
    with torch.enable_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    prob = F.softmax(logits.detach(), dim=-1)[0, class_idx].item()

    attentions = getattr(outputs, "attentions", None)
    if not attentions:
        return np.ones((14, 14)), prob

    attn = attentions[-1]
    if attn.requires_grad:
        (grad,) = torch.autograd.grad(logits[0, class_idx], attn)
        attn = F.relu(grad * attn)
    # Head-averaged CLS row for sample 0, excluding CLS->CLS.
    cls_row = attn[0].mean(dim=0)[0, 1:].detach()
    side = math.isqrt(cls_row.numel())
    if side == 0:
        return np.ones((14, 14)), prob
    # Patch tokens are the trailing side*side entries. Extra leading tokens
    # (e.g. a distillation token) are dropped.
    attn_map = cls_row[-side * side:].reshape(side, side).cpu().numpy()
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    return attn_map, prob
# ---------------------------
# Grad-CAM Helper for CNNs
# ---------------------------
def get_gradcam(model, image_tensor):
grad = None
fmap = None
def forward_hook(module, input, output):
nonlocal fmap
fmap = output.detach()
def backward_hook(module, grad_in, grad_out):
nonlocal grad
grad = grad_out[0].detach()
last_conv = None
for name, module in reversed(list(model.named_modules())):
if isinstance(module, torch.nn.Conv2d):
last_conv = module
break
if last_conv is None:
return np.ones((224, 224))
handle_fwd = last_conv.register_forward_hook(forward_hook)
handle_bwd = last_conv.register_full_backward_hook(backward_hook)
out = model(image_tensor)
class_idx = out.argmax(dim=1).item()
score = out[0, class_idx]
model.zero_grad()
score.backward()
handle_fwd.remove()
handle_bwd.remove()
if grad is None or fmap is None:
return np.ones((224, 224))
weights = grad.mean(dim=(2, 3), keepdim=True)
cam = (weights * fmap).sum(dim=1, keepdim=True)
cam = F.relu(cam)
cam = cam.squeeze().cpu().numpy()
cam = cv2.resize(cam, (224, 224))
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
return cam
# ---------------------------
# ViT Attention Rollout
# ---------------------------
def vit_attention_rollout(outputs):
    """Return an attention-rollout map for a ViT forward pass.

    Implements attention rollout (Abnar & Zuidema, 2020) on sample 0:
    - for every layer, heads are averaged and the identity is added to account
      for the residual connection;
    - rows are re-normalized to sum to 1;
    - the per-layer matrices are multiplied from the first layer to the last.

    The CLS row of the product over the patch tokens is reshaped to a square
    grid and min-max normalized to [0, 1].

    Args:
        outputs: model output; ``outputs.attentions`` is expected to be a
            sequence of tensors shaped (batch, heads, tokens, tokens) — TODO
            confirm for non-HF models.

    Returns:
        Square ``np.ndarray`` (side = floor(sqrt(#non-CLS tokens)), 14 for
        224px/16px patches), or ``np.ones((14, 14))`` when no attentions are
        available.
    """
    import math

    attentions = getattr(outputs, "attentions", None)
    if not attentions:
        return np.ones((14, 14))

    rollout = None
    for layer_attn in attentions:
        fused = layer_attn[0].detach().mean(dim=0)  # (tokens, tokens), heads averaged
        identity = torch.eye(fused.shape[-1], dtype=fused.dtype, device=fused.device)
        fused = fused + identity  # residual connection
        fused = fused / fused.sum(dim=-1, keepdim=True)
        rollout = fused if rollout is None else fused @ rollout

    cls_row = rollout[0, 1:]
    side = math.isqrt(cls_row.numel())
    if side == 0:
        return np.ones((14, 14))
    # Patch tokens are the trailing side*side entries. Extra leading tokens
    # (e.g. a distillation token) are dropped.
    attn_map = cls_row[-side * side:].reshape(side, side).cpu().numpy()
    return (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
# ---------------------------
# Overlay Attention on Image
# ---------------------------
def overlay_attention(pil_img, attention_map, alpha=0.4):
    """Blend a JET-colored heat map of ``attention_map`` over ``pil_img``.

    Args:
        pil_img: PIL image to draw on; converted to RGB for blending.
        attention_map: 2-D array of scores expected in [0, 1], at any
            resolution. Values outside [0, 1] are clipped and NaNs become 0.
        alpha: heat-map weight in the blend (0 = image only, 1 = heat map only).

    Returns:
        A new RGB PIL image the size of ``pil_img``.
    """
    # Clip before the uint8 cast so out-of-range scores cannot wrap around
    # (e.g. 1.01 * 255 -> 257 -> 1) and NaNs cannot produce garbage colors.
    scores = np.clip(np.nan_to_num(attention_map, nan=0.0), 0.0, 1.0)
    heatmap = (scores * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # OpenCV color maps are BGR; PIL expects RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # cv2.resize takes (width, height), which matches PIL's .size ordering.
    heatmap = cv2.resize(heatmap, pil_img.size)
    return Image.blend(pil_img.convert("RGB"), Image.fromarray(heatmap), alpha=alpha)
# ---------------------------
# Main Prediction Function
# ---------------------------
def predict(image, model_name, noise_level):
    """Classify ``image`` and visualize where the selected model looks.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was uploaded.
        model_name: key into ``MODEL_CONFIGS``.
        noise_level: std-dev of the random pixel noise added before inference;
            ``0`` or ``None`` disables it.

    Returns:
        ``(top5, overlay, processed_image)``:
        - ``top5`` maps label -> probability for the five most likely classes.
        - ``overlay`` is the attention (ViT) or Grad-CAM (CNN) heat map blended
          over the input.
        - ``processed_image`` is the RGB, possibly noisy, image that was
          classified.

        On failure ``top5`` is a plain ``"Error: ..."`` string and both images
        are ``None``. ``gr.Label`` renders strings as-is, whereas a
        ``{label: message}`` dict is not a valid confidence mapping.
    """
    if image is None:
        return "Error: Please upload an image", None, None
    if model_name is None:
        return "Error: Please select a model", None, None
    try:
        # The 3-channel Normalize below would fail on grayscale or RGBA uploads.
        image = image.convert("RGB")
        if noise_level and noise_level > 0:
            image = add_adversarial_noise(image, noise_level)
        model, extractor = load_model(model_name)
        model_type = MODEL_CONFIGS[model_name]["type"]

        if model_type == "hf":
            with torch.no_grad():
                inputs = extractor(images=image, return_tensors="pt")
                outputs = model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)[0]
                top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [model.config.id2label[idx.item()] for idx in top5_idx]
            att_map = vit_attention_rollout(outputs)
        else:
            mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            if model_type == "timm":
                # Use the checkpoint's own normalization. Not every timm weight
                # set was trained with ImageNet mean/std (presumably e.g. the
                # *_tf_* ports; see timm's pretrained config).
                data_cfg = timm.data.resolve_data_config({}, model=model)
                mean, std = list(data_cfg["mean"]), list(data_cfg["std"])
            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=mean, std=std),
            ])
            x = transform(image).unsqueeze(0)
            x.requires_grad = True
            with torch.no_grad():
                outputs = model(x.detach())
                probs = F.softmax(outputs, dim=-1)[0]
                top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [IMAGENET_LABELS[idx.item()] for idx in top5_idx]
            # Grad-CAM for the top-1 class; runs its own forward pass and gradient.
            att_map = get_gradcam(model, x)

        overlay = overlay_attention(image, att_map)
        result = {label: float(prob) for label, prob in zip(top5_labels, top5_prob)}
        return result, overlay, image
    except Exception as e:
        # UI boundary: log the full traceback, show a short message in the label.
        import traceback
        print(f"Error: {traceback.format_exc()}")
        return f"Error: {e}", None, None
# ---------------------------
# Class-Specific Attention with Confidence
# ---------------------------
def get_class_specific_attention(image, model_name, class_query):
try:
if image is None:
return None, "Please upload an image first"
if not class_query or class_query.strip() == "":
return None, "Please enter a class name"
class_query_lower = class_query.lower().strip()
matching_idx = None
matched_label = None
confidence = 0.0
model, extractor = load_model(model_name)
if MODEL_CONFIGS[model_name]["type"] == "hf":
for idx, label in model.config.id2label.items():
if class_query_lower in label.lower():
matching_idx = idx
matched_label = label
break
if matching_idx is None:
return None, f"Class '{class_query}' not found in model labels."
att_map, confidence = vit_attention_for_class(model, extractor, image, matching_idx)
else:
for idx, label in enumerate(IMAGENET_LABELS):
if class_query_lower in label.lower():
matching_idx = idx
matched_label = label
break
if matching_idx is None:
return None, f"Class '{class_query}' not found in ImageNet labels."
transform = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
x = transform(image).unsqueeze(0)
x.requires_grad = True
att_map = get_gradcam_for_class(model, x, matching_idx)
with torch.no_grad():
outputs = model(x)
confidence = F.softmax(outputs, dim=-1)[0, matching_idx].item()
overlay = overlay_attention(image, att_map)
return overlay, f"βœ“ Attention map generated for class: '{matched_label}' (Index: {matching_idx}, Confidence: {confidence:.2f})"
except Exception as e:
import traceback
print(traceback.format_exc())
return None, f"Error generating attention map: {str(e)}"
# ---------------------------
# Sample Classes
# ---------------------------
# Quick-pick suggestions shown as a Radio in the UI; picking one only copies
# the word into the class-name textbox.
# NOTE(review): some entries (e.g. "person", "building") may have no
# matching ImageNet class and will then report "not found" — confirm.
SAMPLE_CLASSES = (
    "cat dog tiger lion elephant "
    "car truck airplane ship train "
    "pizza hamburger coffee banana apple "
    "chair table laptop keyboard mouse "
    "person bicycle building tree flower"
).split()
# ---------------------------
# Gradio UI
# ---------------------------
# Build the Gradio interface. Components appear in the order they are created
# inside each Row/Column, so the statements below are order-sensitive.
# NOTE(review): several emoji/Greek characters in the labels below look
# mis-encoded (e.g. "πŸ“Έ", "Ξ΅"); confirm the file is saved and served as UTF-8.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Enhanced Multi-Model Image Classifier")
    gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 13+ Models")
    # Section 1: top-5 classification with an attention / Grad-CAM overlay.
    with gr.Row():
        # Left column: inputs and controls for predict().
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="πŸ“Έ Upload Image")
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CONFIGS.keys()),
                label="πŸ€– Select Model",
                value="DeiT-Tiny"
            )
            gr.Markdown("### 🎭 Adversarial Noise")
            # Passed to predict() as noise_level (std-dev of random pixel noise).
            noise_slider = gr.Slider(
                minimum=0,
                maximum=0.3,
                value=0,
                step=0.01,
                label="Noise Level (Ξ΅)",
                info="Add random noise to test model robustness"
            )
            run_button = gr.Button("πŸš€ Run Model", variant="primary")
        # Right column: receives predict()'s three return values.
        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="🎯 Top 5 Predictions")
            output_image = gr.Image(label="πŸ” Attention Map (Top Prediction)")
            # Hidden: holds the (possibly noisy) image predict() actually classified.
            processed_image = gr.Image(label="πŸ–ΌοΈ Processed Image (with noise if applied)", visible=False)
    gr.Markdown("---")
    # Section 2: attention map for a user-chosen class.
    gr.Markdown("### 🎨 Class-Specific Attention Visualization")
    gr.Markdown("*Type any class name to see where the model looks for that specific object*")
    with gr.Row():
        with gr.Column(scale=1):
            class_input = gr.Textbox(
                label="πŸ” Enter Class Name",
                placeholder="e.g., cat, dog, car, pizza...",
                info="Type any ImageNet class name"
            )
            class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")
            gr.Markdown("**πŸ’‘ Sample classes to try:**")
            sample_buttons = gr.Radio(
                choices=SAMPLE_CLASSES,
                label="Click to auto-fill",
                interactive=True
            )
        with gr.Column(scale=2):
            class_output_image = gr.Image(label="πŸ” Class-Specific Attention Map")
            class_status = gr.Textbox(label="Status", interactive=False)
    gr.Markdown("---")
    gr.Markdown("""
### πŸ’‘ Tips:
- **Adversarial Noise**: Adjust the slider to add random noise and see how robust the model is
- **Class-Specific Attention**: Type any ImageNet class to visualize what the model looks for
- **Model Variety**: Try different architectures (ViT, CNN, Hybrid) to compare their behavior
""")
    # Event wiring (must live inside the Blocks context).
    run_button.click(
        predict,
        inputs=[input_image, model_dropdown, noise_slider],
        outputs=[output_label, output_image, processed_image]
    )
    # Copy the selected sample class verbatim into the class-name textbox.
    sample_buttons.change(
        lambda x: x,
        inputs=[sample_buttons],
        outputs=[class_input]
    )
    class_button.click(
        get_class_specific_attention,
        inputs=[input_image, model_dropdown, class_input],
        outputs=[class_output_image, class_status]
    )
# Start the server only when run as a script.
if __name__ == "__main__":
    demo.launch()