import json
import urllib.request

import cv2
import gradio as gr
import numpy as np
import timm
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from efficientnet_pytorch import EfficientNet
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

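# The imports above assume these third-party packages are installed:
# gradio, transformers, timm, efficientnet_pytorch, torch, torchvision,
# Pillow, numpy, and opencv-python (cv2). Exact version pins are not
# specified by this script.
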
# Model registry. "type" selects the loading path: "hf" uses the Hugging Face
# transformers API (and is visualized with the ViT encoder's attention), while
# "timm" and "efficientnet" models are visualized with Grad-CAM.
MODEL_CONFIGS = {
    "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224"},
    "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224"},
    "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224"},
    "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny"},
    "ConvNeXt-Nano": {"type": "timm", "id": "convnext_nano"},
    "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0"},
    "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1"},
    "ResNet-50": {"type": "timm", "id": "resnet50"},
    "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},
    "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224"},
    "MobileViT-Small": {"type": "timm", "id": "mobilevit_s"},
    "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small"},
    "RegNetY-002": {"type": "timm", "id": "regnety_002"}
}

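# Extending the registry only requires a new entry whose "type" is one of the
# three loaders handled in load_model(). For example (illustrative entry, not
# part of the app's default model list):
#   MODEL_CONFIGS["ResNet-18"] = {"type": "timm", "id": "resnet18"}
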
# Human-readable ImageNet-1k class names (index -> label). Fetched once over
# the network when the script starts.
IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
with urllib.request.urlopen(IMAGENET_URL) as url:
    IMAGENET_LABELS = json.load(url)

# Cache of already-instantiated models so each one is loaded at most once.
loaded_models = {}

def load_model(model_name):
    """Load (and cache) a model and, for HF checkpoints, its feature extractor."""
    if model_name in loaded_models:
        return loaded_models[model_name]

    config = MODEL_CONFIGS[model_name]
    if config["type"] == "hf":
        extractor = AutoFeatureExtractor.from_pretrained(config["id"])
        model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
    elif config["type"] == "timm":
        extractor = None
        model = timm.create_model(config["id"], pretrained=True)
    elif config["type"] == "efficientnet":
        extractor = None
        model = EfficientNet.from_pretrained(config["id"])
    else:
        raise ValueError(f"Unknown model type: {config['type']}")

    model.eval()
    # Keep parameter gradients enabled so the Grad-CAM backward pass can reach
    # the convolutional feature maps.
    for param in model.parameters():
        param.requires_grad = True

    loaded_models[model_name] = (model, extractor)
    return model, extractor

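# Example usage (not executed at import time; "ResNet-50" is one of the keys
# defined in MODEL_CONFIGS above):
#   model, extractor = load_model("ResNet-50")  # extractor is None for timm models
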
def add_adversarial_noise(image, epsilon):
    """Perturb the image with i.i.d. Gaussian noise of scale epsilon.

    Note: this is random noise for a robustness check, not a gradient-based
    adversarial attack such as FGSM.
    """
    img_array = np.array(image).astype(np.float32) / 255.0
    noise = np.random.randn(*img_array.shape) * epsilon
    noisy_img = np.clip(img_array + noise, 0, 1)
    return Image.fromarray((noisy_img * 255).astype(np.uint8))

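# Grad-CAM (Selvaraju et al., 2017), as implemented below: for a chosen class
# score y_c and the activations A_k of the last Conv2d layer, the channel
# weights are the spatially averaged gradients
#   alpha_k = mean_ij( d y_c / d A_k[i, j] )
# and the localization map is ReLU(sum_k alpha_k * A_k), normalized to [0, 1]
# and resized to the 224x224 input resolution.
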
def get_gradcam_for_class(model, image_tensor, class_idx):
    """Compute a Grad-CAM heatmap for a specific class index."""
    grad = None
    fmap = None

    def forward_hook(module, input, output):
        nonlocal fmap
        fmap = output.detach()

    def backward_hook(module, grad_in, grad_out):
        nonlocal grad
        grad = grad_out[0].detach()

    # Use the last Conv2d layer in registration order as the Grad-CAM target.
    last_conv = None
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            last_conv = module
            break
    if last_conv is None:
        return np.ones((224, 224))

    handle_fwd = last_conv.register_forward_hook(forward_hook)
    handle_bwd = last_conv.register_full_backward_hook(backward_hook)

    out = model(image_tensor)
    score = out[0, class_idx]
    model.zero_grad()
    score.backward()

    handle_fwd.remove()
    handle_bwd.remove()

    if grad is None or fmap is None:
        return np.ones((224, 224))

    weights = grad.mean(dim=(2, 3), keepdim=True)
    cam = (weights * fmap).sum(dim=1, keepdim=True)
    cam = F.relu(cam)
    cam = cam.squeeze().cpu().numpy()
    cam = cv2.resize(cam, (224, 224))
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

    return cam

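# The 14x14 attention grid below assumes 224x224 inputs with 16x16 patches
# (196 patch tokens plus one [CLS] token), which holds for every "hf" entry
# in MODEL_CONFIGS.
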
def vit_attention_for_class(model, extractor, image, class_idx):
    """Return (attention map, probability) for a given class of an HF ViT/DeiT model.

    The map is the last layer's [CLS]-to-patch attention averaged over heads;
    the backward pass is run on the class score, but the map itself is not
    gradient-weighted.
    """
    inputs = extractor(images=image, return_tensors="pt")
    inputs['pixel_values'].requires_grad = True
    outputs = model(**inputs)

    score = outputs.logits[0, class_idx]
    model.zero_grad()
    score.backward()

    if hasattr(outputs, 'attentions') and outputs.attentions is not None:
        attn = outputs.attentions[-1]   # last encoder layer
        attn = attn.mean(1)             # average over heads
        attn = attn[:, 0, 1:]           # [CLS] token's attention to the patches
        attn_map = attn.reshape(1, 14, 14)
        attn_map = attn_map.squeeze().detach().cpu().numpy()
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
        prob = F.softmax(outputs.logits, dim=-1)[0, class_idx].item()
        return attn_map, prob

    return np.ones((14, 14)), 0.0

def get_gradcam(model, image_tensor):
    """Grad-CAM heatmap for the model's top-1 predicted class."""
    with torch.no_grad():
        class_idx = model(image_tensor).argmax(dim=1).item()
    return get_gradcam_for_class(model, image_tensor, class_idx)

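# Despite its name, the helper below does not perform a full attention rollout
# across layers; it returns only the last layer's [CLS]-to-patch attention,
# averaged over heads, which is what the top-5 prediction view displays.
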
def vit_attention_rollout(outputs):
    if not hasattr(outputs, 'attentions') or outputs.attentions is None:
        return np.ones((14, 14))
    attn = outputs.attentions[-1]
    attn = attn.mean(1)
    attn = attn[:, 0, 1:]
    attn_map = attn.reshape(1, 14, 14)
    attn_map = attn_map.squeeze().detach().cpu().numpy()
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    return attn_map

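# The attention/Grad-CAM map is rendered with a JET colormap and alpha-blended
# onto the original image (40% heatmap, 60% image).
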
def overlay_attention(pil_img, attention_map):
    heatmap = (attention_map * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = cv2.resize(heatmap, pil_img.size)
    heatmap_pil = Image.fromarray(heatmap)
    blended = Image.blend(pil_img.convert("RGB"), heatmap_pil, alpha=0.4)
    return blended

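# Main prediction path: optionally perturb the image, run the selected model,
# report the top-5 classes, and overlay an attention map (ViT attention for
# "hf" models, Grad-CAM for the top-1 class otherwise).
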
def predict(image, model_name, noise_level):
    try:
        if image is None:
            return {"Error": "Please upload an image"}, None, None

        if model_name is None:
            return {"Error": "Please select a model"}, None, None

        # Normalize to 3-channel RGB so grayscale/RGBA uploads do not break
        # the 3-channel normalization below.
        image = image.convert("RGB")

        if noise_level > 0:
            image = add_adversarial_noise(image, noise_level)

        model, extractor = load_model(model_name)
        transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])

        if MODEL_CONFIGS[model_name]["type"] == "hf":
            with torch.no_grad():
                inputs = extractor(images=image, return_tensors="pt")
                outputs = model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)[0]
                top5_prob, top5_idx = torch.topk(probs, k=5)
                top5_labels = [model.config.id2label[idx.item()] for idx in top5_idx]
            att_map = vit_attention_rollout(outputs)
        else:
            x = transform(image).unsqueeze(0)
            x.requires_grad = True

            with torch.no_grad():
                outputs = model(x.detach())
                probs = F.softmax(outputs, dim=-1)[0]
                top5_prob, top5_idx = torch.topk(probs, k=5)
                top5_labels = [IMAGENET_LABELS[idx.item()] for idx in top5_idx]

            # Grad-CAM needs gradients, so it runs outside the no_grad block.
            att_map = get_gradcam(model, x)

        overlay = overlay_attention(image, att_map)
        result = {label: float(prob) for label, prob in zip(top5_labels, top5_prob)}

        return result, overlay, image

    except Exception as e:
        import traceback
        print(f"Error: {traceback.format_exc()}")
        return {"Error": str(e)}, None, None

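# Class-specific attention: the user's free-text query is matched by
# case-insensitive substring against the model's label set (HF id2label for
# "hf" models, IMAGENET_LABELS otherwise), and the first match is visualized.
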
def get_class_specific_attention(image, model_name, class_query):
    try:
        if image is None:
            return None, "Please upload an image first"

        if not class_query or class_query.strip() == "":
            return None, "Please enter a class name"

        image = image.convert("RGB")
        class_query_lower = class_query.lower().strip()
        matching_idx = None
        matched_label = None
        confidence = 0.0

        model, extractor = load_model(model_name)

        if MODEL_CONFIGS[model_name]["type"] == "hf":
            for idx, label in model.config.id2label.items():
                if class_query_lower in label.lower():
                    matching_idx = idx
                    matched_label = label
                    break

            if matching_idx is None:
                return None, f"Class '{class_query}' not found in model labels."

            att_map, confidence = vit_attention_for_class(model, extractor, image, matching_idx)

        else:
            for idx, label in enumerate(IMAGENET_LABELS):
                if class_query_lower in label.lower():
                    matching_idx = idx
                    matched_label = label
                    break

            if matching_idx is None:
                return None, f"Class '{class_query}' not found in ImageNet labels."

            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
            ])
            x = transform(image).unsqueeze(0)
            x.requires_grad = True
            att_map = get_gradcam_for_class(model, x, matching_idx)
            with torch.no_grad():
                outputs = model(x)
                confidence = F.softmax(outputs, dim=-1)[0, matching_idx].item()

        overlay = overlay_attention(image, att_map)
        return overlay, f"Attention map generated for class: '{matched_label}' (Index: {matching_idx}, Confidence: {confidence:.2f})"

    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return None, f"Error generating attention map: {str(e)}"

# Query strings offered as quick-fill examples in the UI; each is matched by
# substring against the label set at lookup time.
SAMPLE_CLASSES = [
    "cat", "dog", "tiger", "lion", "elephant",
    "car", "truck", "airplane", "ship", "train",
    "pizza", "hamburger", "coffee", "banana", "apple",
    "chair", "table", "laptop", "keyboard", "mouse",
    "person", "bicycle", "building", "tree", "flower"
]

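# Gradio UI: a top row for uploading an image, choosing a model, and adding
# noise, followed by a second section for class-specific attention queries.
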
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Enhanced Multi-Model Image Classifier")
    gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 13+ Models")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CONFIGS.keys()),
                label="Select Model",
                value="DeiT-Tiny"
            )

            gr.Markdown("### Adversarial Noise")
            noise_slider = gr.Slider(
                minimum=0,
                maximum=0.3,
                value=0,
                step=0.01,
                label="Noise Level (ε)",
                info="Add random noise to test model robustness"
            )

            run_button = gr.Button("Run Model", variant="primary")

        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="Top 5 Predictions")
            output_image = gr.Image(label="Attention Map (Top Prediction)")
            processed_image = gr.Image(label="Processed Image (with noise if applied)", visible=False)

    gr.Markdown("---")
    gr.Markdown("### Class-Specific Attention Visualization")
    gr.Markdown("*Type any class name to see where the model looks for that specific object*")

    with gr.Row():
        with gr.Column(scale=1):
            class_input = gr.Textbox(
                label="Enter Class Name",
                placeholder="e.g., cat, dog, car, pizza...",
                info="Type any ImageNet class name"
            )
            class_button = gr.Button("Generate Class-Specific Attention", variant="primary")
            gr.Markdown("**Sample classes to try:**")
            sample_buttons = gr.Radio(
                choices=SAMPLE_CLASSES,
                label="Click to auto-fill",
                interactive=True
            )

        with gr.Column(scale=2):
            class_output_image = gr.Image(label="Class-Specific Attention Map")
            class_status = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("---")
    gr.Markdown("""
    ### Tips:
    - **Adversarial Noise**: Adjust the slider to add random noise and see how robust the model is
    - **Class-Specific Attention**: Type any ImageNet class to visualize what the model looks for
    - **Model Variety**: Try different architectures (ViT, CNN, Hybrid) to compare their behavior
    """)

    run_button.click(
        predict,
        inputs=[input_image, model_dropdown, noise_slider],
        outputs=[output_label, output_image, processed_image]
    )

    sample_buttons.change(
        lambda x: x,
        inputs=[sample_buttons],
        outputs=[class_input]
    )

    class_button.click(
        get_class_specific_attention,
        inputs=[input_image, model_dropdown, class_input],
        outputs=[class_output_image, class_status]
    )


if __name__ == "__main__":
    demo.launch()
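# When running this script directly, demo.launch() starts a local server.
# Gradio's launch() also accepts options such as share=True (temporary public
# link) or server_port=...; neither is enabled by default here.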