import gradio as gr
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from efficientnet_pytorch import EfficientNet
import timm
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.transforms as T
import urllib.request
import json
import cv2

# ---------------------------
# Model Configs
# ---------------------------
MODEL_CONFIGS = {
    "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224"},
    "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224"},
    "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224"},
    "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny"},
    "ConvNeXt-Nano": {"type": "timm", "id": "convnext_nano"},
    "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0"},
    "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1"},
    "ResNet-50": {"type": "timm", "id": "resnet50"},
    "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},
    "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224"},
    "MobileViT-Small": {"type": "timm", "id": "mobilevit_s"},
    "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small"},
    "RegNetY-002": {"type": "timm", "id": "regnety_002"},
}

# ---------------------------
# ImageNet Labels
# ---------------------------
IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
with urllib.request.urlopen(IMAGENET_URL) as url:
    IMAGENET_LABELS = json.load(url)

# ---------------------------
# Lazy Load
# ---------------------------
loaded_models = {}

def load_model(model_name):
    """Load and cache a model (plus its feature extractor for HF models) on first use."""
    if model_name in loaded_models:
        return loaded_models[model_name]

    config = MODEL_CONFIGS[model_name]
    if config["type"] == "hf":
        extractor = AutoFeatureExtractor.from_pretrained(config["id"])
        model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
    elif config["type"] == "timm":
        extractor = None
        model = timm.create_model(config["id"], pretrained=True)
    elif config["type"] == "efficientnet":
        extractor = None
        model = EfficientNet.from_pretrained(config["id"])
    else:
        raise ValueError(f"Unknown model type: {config['type']}")

    model.eval()
    # Keep parameter gradients enabled so the Grad-CAM backward passes below work.
    for param in model.parameters():
        param.requires_grad = True

    loaded_models[model_name] = (model, extractor)
    return model, extractor

# ---------------------------
# Adversarial Noise
# ---------------------------
def add_adversarial_noise(image, epsilon):
    """Add random Gaussian noise of magnitude epsilon (in [0, 1] pixel scale).

    This is a simple robustness probe with random noise, not a gradient-based
    adversarial attack.
    """
    img_array = np.array(image).astype(np.float32) / 255.0
    noise = np.random.randn(*img_array.shape) * epsilon
    noisy_img = np.clip(img_array + noise, 0, 1)
    return Image.fromarray((noisy_img * 255).astype(np.uint8))
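# ---------------------------
# FGSM Sketch (illustrative only)
# ---------------------------
# The helper above perturbs the image with *random* noise, which probes robustness
# but is not a true adversarial example. The sketch below shows what a minimal
# FGSM-style perturbation could look like for the timm/EfficientNet path (plain
# tensor in, logits out). It is an assumption-labelled example and is not wired
# into the UI; the function name and the in-line ImageNet normalization are
# illustrative choices, not part of the original app.
def fgsm_attack(model, image, target_idx, epsilon=0.03):
    """Return a PIL image nudged in the direction that increases the loss for
    target_idx (one signed-gradient step, FGSM). Sketch only; not used by the UI."""
    preprocess = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
    ])
    x = preprocess(image.convert("RGB")).unsqueeze(0)
    x.requires_grad = True
    # Normalize inside the forward pass so the perturbation stays in pixel space.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    logits = model((x - mean) / std)
    loss = F.cross_entropy(logits, torch.tensor([target_idx]))
    model.zero_grad()
    loss.backward()
    # One epsilon-sized step along the gradient sign, clipped back to [0, 1].
    x_adv = torch.clamp(x + epsilon * x.grad.sign(), 0, 1).detach()
    adv = (x_adv.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return Image.fromarray(adv)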
# ---------------------------
# Grad-CAM for Class-Specific Attention
# ---------------------------
def get_gradcam_for_class(model, image_tensor, class_idx):
    """Grad-CAM heatmap (224x224, normalized to [0, 1]) for a given class index."""
    grad = None
    fmap = None

    def forward_hook(module, input, output):
        nonlocal fmap
        fmap = output.detach()

    def backward_hook(module, grad_in, grad_out):
        nonlocal grad
        grad = grad_out[0].detach()

    # Hook the last Conv2d layer in the network.
    last_conv = None
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            last_conv = module
            break
    if last_conv is None:
        return np.ones((224, 224))

    handle_fwd = last_conv.register_forward_hook(forward_hook)
    handle_bwd = last_conv.register_full_backward_hook(backward_hook)

    out = model(image_tensor)
    score = out[0, class_idx]
    model.zero_grad()
    score.backward()

    handle_fwd.remove()
    handle_bwd.remove()

    if grad is None or fmap is None:
        return np.ones((224, 224))

    # Channel weights = spatially averaged gradients; CAM = weighted sum of feature maps.
    weights = grad.mean(dim=(2, 3), keepdim=True)
    cam = (weights * fmap).sum(dim=1, keepdim=True)
    cam = F.relu(cam)
    cam = cam.squeeze().cpu().numpy()
    cam = cv2.resize(cam, (224, 224))
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return cam

# ---------------------------
# ViT Attention for Class-Specific
# ---------------------------
def vit_attention_for_class(model, extractor, image, class_idx):
    """Return (attention map, probability of class_idx) for a HF ViT/DeiT model.

    Note: self-attention is not class-conditional, so the map is the last layer's
    CLS attention averaged over heads; class_idx only selects the reported probability.
    """
    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    prob = F.softmax(outputs.logits, dim=-1)[0, class_idx].item()
    if hasattr(outputs, 'attentions') and outputs.attentions is not None:
        attn = outputs.attentions[-1]        # (batch, heads, tokens, tokens)
        attn = attn.mean(1)                  # average over heads
        attn = attn[:, 0, 1:]                # CLS token's attention to the image patches
        attn_map = attn.reshape(1, 14, 14)   # assumes a 224 / 16 = 14x14 patch grid
        attn_map = attn_map.squeeze().detach().cpu().numpy()
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
        return attn_map, prob
    return np.ones((14, 14)), prob

# ---------------------------
# Grad-CAM Helper for CNNs
# ---------------------------
def get_gradcam(model, image_tensor):
    """Grad-CAM for the model's own top-1 prediction (delegates to the class-specific version)."""
    with torch.no_grad():
        out = model(image_tensor)
    class_idx = out.argmax(dim=1).item()
    return get_gradcam_for_class(model, image_tensor, class_idx)

# ---------------------------
# ViT Attention Rollout
# ---------------------------
def vit_attention_rollout(outputs):
    """Attention map from the last layer's CLS attention, averaged over heads.

    Despite the name, this is not a full multi-layer rollout; see the sketch below.
    """
    if not hasattr(outputs, 'attentions') or outputs.attentions is None:
        return np.ones((14, 14))
    attn = outputs.attentions[-1]
    attn = attn.mean(1)
    attn = attn[:, 0, 1:]
    attn_map = attn.reshape(1, 14, 14)
    attn_map = attn_map.squeeze().detach().cpu().numpy()
    attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
    return attn_map

# ---------------------------
# Overlay Attention on Image
# ---------------------------
def overlay_attention(pil_img, attention_map):
    """Blend a JET-colormapped heatmap (40% opacity) over the original image."""
    heatmap = (attention_map * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = cv2.resize(heatmap, pil_img.size)
    heatmap_pil = Image.fromarray(heatmap)
    blended = Image.blend(pil_img.convert("RGB"), heatmap_pil, alpha=0.4)
    return blended
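# ---------------------------
# True Attention Rollout (illustrative only)
# ---------------------------
# `vit_attention_rollout` above visualizes only the last layer's CLS attention.
# Attention rollout in the sense of Abnar & Zuidema (2020) instead multiplies the
# residual-adjusted attention matrices of *all* layers. A minimal sketch is given
# below, assuming the same 14x14 patch grid as the bundled patch-16/224 models;
# it is not called anywhere in the app.
def attention_rollout_all_layers(attentions, grid_size=14):
    """attentions: tuple of (batch, heads, tokens, tokens) tensors, e.g.
    `outputs.attentions` from a model loaded with output_attentions=True."""
    rollout = None
    for layer_attn in attentions:
        attn = layer_attn.mean(dim=1)                   # average over heads
        attn = attn + torch.eye(attn.size(-1))          # account for residual connections
        attn = attn / attn.sum(dim=-1, keepdim=True)    # re-normalize rows
        rollout = attn if rollout is None else torch.matmul(attn, rollout)
    cls_to_patches = rollout[0, 0, 1:]                  # CLS row, drop the CLS column
    attn_map = cls_to_patches.reshape(grid_size, grid_size).detach().cpu().numpy()
    return (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)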
# ---------------------------
# Main Prediction Function
# ---------------------------
def predict(image, model_name, noise_level):
    try:
        if image is None:
            return {"Error": "Please upload an image"}, None, None
        if model_name is None:
            return {"Error": "Please select a model"}, None, None

        image = image.convert("RGB")  # guard against RGBA / grayscale uploads
        if noise_level > 0:
            image = add_adversarial_noise(image, noise_level)

        model, extractor = load_model(model_name)
        transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        if MODEL_CONFIGS[model_name]["type"] == "hf":
            with torch.no_grad():
                inputs = extractor(images=image, return_tensors="pt")
                outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=-1)[0]
            top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [model.config.id2label[idx.item()] for idx in top5_idx]
            att_map = vit_attention_rollout(outputs)
        else:
            x = transform(image).unsqueeze(0)
            x.requires_grad = True
            with torch.no_grad():
                outputs = model(x.detach())
            probs = F.softmax(outputs, dim=-1)[0]
            top5_prob, top5_idx = torch.topk(probs, k=5)
            top5_labels = [IMAGENET_LABELS[idx.item()] for idx in top5_idx]
            att_map = get_gradcam(model, x)

        overlay = overlay_attention(image, att_map)
        result = {label: float(prob) for label, prob in zip(top5_labels, top5_prob)}
        return result, overlay, image
    except Exception as e:
        import traceback
        print(f"Error: {traceback.format_exc()}")
        return {"Error": str(e)}, None, None

# ---------------------------
# Class-Specific Attention with Confidence
# ---------------------------
def get_class_specific_attention(image, model_name, class_query):
    try:
        if image is None:
            return None, "Please upload an image first"
        if not class_query or class_query.strip() == "":
            return None, "Please enter a class name"

        image = image.convert("RGB")
        class_query_lower = class_query.lower().strip()
        matching_idx = None
        matched_label = None
        confidence = 0.0

        model, extractor = load_model(model_name)

        if MODEL_CONFIGS[model_name]["type"] == "hf":
            # First label containing the query (substring match) wins.
            for idx, label in model.config.id2label.items():
                if class_query_lower in label.lower():
                    matching_idx = idx
                    matched_label = label
                    break
            if matching_idx is None:
                return None, f"Class '{class_query}' not found in model labels."
            att_map, confidence = vit_attention_for_class(model, extractor, image, matching_idx)
        else:
            for idx, label in enumerate(IMAGENET_LABELS):
                if class_query_lower in label.lower():
                    matching_idx = idx
                    matched_label = label
                    break
            if matching_idx is None:
                return None, f"Class '{class_query}' not found in ImageNet labels."

            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            x = transform(image).unsqueeze(0)
            x.requires_grad = True
            att_map = get_gradcam_for_class(model, x, matching_idx)
            with torch.no_grad():
                outputs = model(x)
                confidence = F.softmax(outputs, dim=-1)[0, matching_idx].item()

        overlay = overlay_attention(image, att_map)
        return overlay, f"✓ Attention map generated for class: '{matched_label}' (Index: {matching_idx}, Confidence: {confidence:.2f})"
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return None, f"Error generating attention map: {str(e)}"
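# ---------------------------
# Programmatic Usage (illustrative only)
# ---------------------------
# A minimal sketch of exercising the pipeline without the Gradio UI, e.g. as a quick
# smoke test. The file name "example.jpg" is an assumption; nothing in the app calls
# this function.
def run_example(image_path="example.jpg", model_name="ResNet-50"):
    image = Image.open(image_path).convert("RGB")
    scores, attention_overlay, processed = predict(image, model_name, noise_level=0.0)
    print(scores)  # top-5 {label: probability} dict
    class_overlay, status = get_class_specific_attention(image, model_name, "cat")
    print(status)
    return attention_overlay, class_overlay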
# ---------------------------
# Sample Classes
# ---------------------------
SAMPLE_CLASSES = [
    "cat", "dog", "tiger", "lion", "elephant",
    "car", "truck", "airplane", "ship", "train",
    "pizza", "hamburger", "coffee", "banana", "apple",
    "chair", "table", "laptop", "keyboard", "mouse",
    "person", "bicycle", "building", "tree", "flower"
]

# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Enhanced Multi-Model Image Classifier")
    gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 13+ Models")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="📸 Upload Image")
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CONFIGS.keys()),
                label="🤖 Select Model",
                value="DeiT-Tiny"
            )
            gr.Markdown("### 🎭 Adversarial Noise")
            noise_slider = gr.Slider(
                minimum=0,
                maximum=0.3,
                value=0,
                step=0.01,
                label="Noise Level (ε)",
                info="Add random noise to test model robustness"
            )
            run_button = gr.Button("🚀 Run Model", variant="primary")
        with gr.Column(scale=2):
            output_label = gr.Label(num_top_classes=5, label="🎯 Top 5 Predictions")
            output_image = gr.Image(label="🔍 Attention Map (Top Prediction)")
            processed_image = gr.Image(label="🖼️ Processed Image (with noise if applied)", visible=False)

    gr.Markdown("---")
    gr.Markdown("### 🎨 Class-Specific Attention Visualization")
    gr.Markdown("*Type any class name to see where the model looks for that specific object*")

    with gr.Row():
        with gr.Column(scale=1):
            class_input = gr.Textbox(
                label="🔍 Enter Class Name",
                placeholder="e.g., cat, dog, car, pizza...",
                info="Type any ImageNet class name"
            )
            class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")
            gr.Markdown("**💡 Sample classes to try:**")
            sample_buttons = gr.Radio(
                choices=SAMPLE_CLASSES,
                label="Click to auto-fill",
                interactive=True
            )
        with gr.Column(scale=2):
            class_output_image = gr.Image(label="🔍 Class-Specific Attention Map")
            class_status = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("---")
    gr.Markdown("""
    ### 💡 Tips:
    - **Adversarial Noise**: Adjust the slider to add random noise and see how robust the model is
    - **Class-Specific Attention**: Type any ImageNet class to visualize what the model looks for
    - **Model Variety**: Try different architectures (ViT, CNN, Hybrid) to compare their behavior
    """)

    run_button.click(
        predict,
        inputs=[input_image, model_dropdown, noise_slider],
        outputs=[output_label, output_image, processed_image]
    )
    sample_buttons.change(
        lambda x: x,
        inputs=[sample_buttons],
        outputs=[class_input]
    )
    class_button.click(
        get_class_specific_attention,
        inputs=[input_image, model_dropdown, class_input],
        outputs=[class_output_image, class_status]
    )

if __name__ == "__main__":
    demo.launch()