rollback to old ui
app.py
CHANGED
@@ -11,29 +11,35 @@ import urllib.request
 import json
 import cv2
 
+# ---------------------------
 # Model Configs
+# ---------------------------
 MODEL_CONFIGS = {
-    "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224"
-    "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224"
-    "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224"
-    "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny"
-    "ConvNeXt-Nano": {"type": "timm", "id": "convnext_nano"
-    "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0"
-    "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1"
-    "ResNet-50": {"type": "timm", "id": "resnet50"
-    "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"
-    "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224"
-    "MobileViT-Small": {"type": "timm", "id": "mobilevit_s"
-    "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small"
-    "RegNetY-002": {"type": "timm", "id": "regnety_002"
+    "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224"},
+    "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224"},
+    "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224"},
+    "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny"},
+    "ConvNeXt-Nano": {"type": "timm", "id": "convnext_nano"},
+    "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0"},
+    "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1"},
+    "ResNet-50": {"type": "timm", "id": "resnet50"},
+    "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},
+    "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224"},
+    "MobileViT-Small": {"type": "timm", "id": "mobilevit_s"},
+    "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small"},
+    "RegNetY-002": {"type": "timm", "id": "regnety_002"}
 }
 
+# ---------------------------
 # ImageNet Labels
+# ---------------------------
 IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
 with urllib.request.urlopen(IMAGENET_URL) as url:
     IMAGENET_LABELS = json.load(url)
 
+# ---------------------------
 # Lazy Load
+# ---------------------------
 loaded_models = {}
 
 def load_model(model_name):
@@ -45,17 +51,20 @@ def load_model(model_name):
         extractor = AutoFeatureExtractor.from_pretrained(config["id"])
         model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
         model.eval()
+        # Enable gradients for class-specific attention
         for param in model.parameters():
             param.requires_grad = True
     elif config["type"] == "timm":
         model = timm.create_model(config["id"], pretrained=True)
         model.eval()
+        # Enable gradients for class-specific attention
         for param in model.parameters():
             param.requires_grad = True
         extractor = None
     elif config["type"] == "efficientnet":
         model = EfficientNet.from_pretrained(config["id"])
         model.eval()
+        # Enable gradients for class-specific attention
         for param in model.parameters():
             param.requires_grad = True
         extractor = None
@@ -63,14 +72,21 @@ def load_model(model_name):
     loaded_models[model_name] = (model, extractor)
     return model, extractor
 
+
+# ---------------------------
 # Adversarial Noise
+# ---------------------------
 def add_adversarial_noise(image, epsilon):
+    """Add random noise to image"""
     img_array = np.array(image).astype(np.float32) / 255.0
     noise = np.random.randn(*img_array.shape) * epsilon
     noisy_img = np.clip(img_array + noise, 0, 1)
     return Image.fromarray((noisy_img * 255).astype(np.uint8))
 
+
+# ---------------------------
 # Grad-CAM for Class-Specific Attention
+# ---------------------------
 def get_gradcam_for_class(model, image_tensor, class_idx):
     grad = None
     fmap = None
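A note on the add_adversarial_noise helper in the hunk above: despite the name, it perturbs the image with random Gaussian noise scaled by epsilon rather than computing a true adversarial example. For comparison, a minimal FGSM-style sketch (hypothetical, not part of app.py), assuming a plain PyTorch classifier that returns logits and an input tensor in [0, 1]:

import torch
import torch.nn.functional as F

def fgsm_perturb(model, x, class_idx, epsilon):
    # x: (1, 3, H, W) tensor in [0, 1]; class_idx: int label to attack (hypothetical helper)
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), torch.tensor([class_idx]))
    loss.backward()
    # Step in the direction that increases the loss, then clamp back to the valid pixel range
    return (x + epsilon * x.grad.sign()).clamp(0, 1).detach()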
@@ -83,6 +99,7 @@ def get_gradcam_for_class(model, image_tensor, class_idx):
         nonlocal grad
         grad = grad_out[0].detach()
 
+    # Find last conv layer
     last_conv = None
     for name, module in reversed(list(model.named_modules())):
         if isinstance(module, torch.nn.Conv2d):
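The middle of get_gradcam_for_class (old lines 89-110) falls between the hunks shown here; it presumably follows the standard Grad-CAM recipe of weighting the hooked feature maps by their average-pooled gradients before the resize/normalize step in the next hunk. A sketch of that step under that assumption (variable names follow the hooks above):

weights = grad.mean(dim=(2, 3), keepdim=True)    # global-average-pool the gradients per channel
cam = torch.relu((weights * fmap).sum(dim=1))    # channel-weighted sum, keep positive evidence only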
@@ -111,10 +128,15 @@ def get_gradcam_for_class(model, image_tensor, class_idx):
     cam = cam.squeeze().cpu().numpy()
     cam = cv2.resize(cam, (224, 224))
     cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
+
     return cam
 
+
+# ---------------------------
 # ViT Attention for Class-Specific
+# ---------------------------
 def vit_attention_for_class(model, extractor, image, class_idx):
+    """Get attention map for specific class in ViT"""
     inputs = extractor(images=image, return_tensors="pt")
     inputs['pixel_values'].requires_grad = True
     outputs = model(**inputs)
@@ -123,6 +145,7 @@ def vit_attention_for_class(model, extractor, image, class_idx):
     model.zero_grad()
     score.backward()
 
+    # Use last layer attention
     if hasattr(outputs, 'attentions') and outputs.attentions is not None:
         attn = outputs.attentions[-1]
         attn = attn.mean(1)
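How attn becomes the 14x14 map returned in the next hunk is likewise outside the lines shown. For a 224x224 input with 16x16 patches, the usual reduction is to take the CLS token's attention over the 196 patch tokens and reshape it to a 14x14 grid; a sketch under that assumption:

cls_attn = attn[0, 0, 1:]    # attention from CLS to each of the 196 patch tokens
att_map = cls_attn.reshape(14, 14).detach().cpu().numpy()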
@@ -134,7 +157,10 @@ def vit_attention_for_class(model, extractor, image, class_idx):
 
     return np.ones((14, 14))
 
+
+# ---------------------------
 # Grad-CAM Helper for CNNs (Top Prediction)
+# ---------------------------
 def get_gradcam(model, image_tensor):
     grad = None
     fmap = None
@@ -176,9 +202,13 @@ def get_gradcam(model, image_tensor):
     cam = cam.squeeze().cpu().numpy()
     cam = cv2.resize(cam, (224, 224))
     cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
+
     return cam
 
+
+# ---------------------------
 # ViT Attention Rollout Helper
+# ---------------------------
 def vit_attention_rollout(outputs):
     if not hasattr(outputs, 'attentions') or outputs.attentions is None:
         return np.ones((14, 14))
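The rollout computation itself (old lines 185-190) also sits between hunks. Attention rollout in the usual sense multiplies the head-averaged attention matrices of all layers, with the identity added to account for residual connections, and then reads off the CLS row. A sketch of that idea, assuming outputs.attentions is a tuple of (1, heads, tokens, tokens) tensors:

import torch

rollout = None
for layer_attn in outputs.attentions:
    a = layer_attn.mean(1)[0]                        # average heads -> (tokens, tokens)
    a = a + torch.eye(a.size(0))                     # add identity for the residual connection
    a = a / a.sum(dim=-1, keepdim=True)              # re-normalize rows
    rollout = a if rollout is None else a @ rollout
attn_map = rollout[0, 1:].detach().reshape(14, 14).numpy()   # CLS attention over the 14x14 patch grid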
@@ -191,12 +221,18 @@ def vit_attention_rollout(outputs):
     attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
     return attn_map
 
+
+# ---------------------------
 # Create Gradient Legend
+# ---------------------------
 def create_gradient_legend():
+    """Create a gradient legend image showing attention scale"""
     width, height = 400, 60
     gradient = np.zeros((height, width, 3), dtype=np.uint8)
 
+    # Create gradient from blue to red (matching COLORMAP_JET)
     for i in range(width):
+        # OpenCV's COLORMAP_JET: blue (low) -> cyan -> green -> yellow -> red (high)
         value = int(255 * i / width)
         color_single = np.array([[[value]]], dtype=np.uint8)
         color_rgb = cv2.applyColorMap(color_single, cv2.COLORMAP_JET)
@@ -204,20 +240,22 @@ def create_gradient_legend():
 
     gradient = cv2.cvtColor(gradient, cv2.COLOR_BGR2RGB)
 
+    # Convert to PIL and add text
     from PIL import ImageDraw, ImageFont
     gradient_pil = Image.fromarray(gradient)
     draw = ImageDraw.Draw(gradient_pil)
 
+    # Use default font
     try:
         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
     except:
         font = ImageFont.load_default()
 
+    # Add text labels
     draw.text((10, 20), "Low Attention", fill=(255, 255, 255), font=font)
     draw.text((width - 120, 20), "High Attention", fill=(255, 255, 255), font=font)
 
     return gradient_pil
-
 def overlay_attention(pil_img, attention_map):
     heatmap = (attention_map * 255).astype(np.uint8)
     heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
@@ -227,7 +265,10 @@ def overlay_attention(pil_img, attention_map):
     blended = Image.blend(pil_img.convert("RGB"), heatmap_pil, alpha=0.4)
     return blended
 
+
+# ---------------------------
 # Main Prediction Function
+# ---------------------------
 def predict(image, model_name, noise_level):
     try:
         if image is None:
@@ -236,9 +277,7 @@ def predict(image, model_name, noise_level):
         if model_name is None:
             return {"Error": "Please select a model"}, None, None
 
-        #
-        model_name = model_name.split(" - ")[0]
-
+        # Apply adversarial noise if requested
         if noise_level > 0:
             image = add_adversarial_noise(image, noise_level)
 
@@ -246,7 +285,8 @@ def predict(image, model_name, noise_level):
         transform = T.Compose([
             T.Resize((224, 224)),
             T.ToTensor(),
-            T.Normalize(mean=[0.485, 0.456, 0.406],
+            T.Normalize(mean=[0.485, 0.456, 0.406],
+                        std=[0.229, 0.224, 0.225])
         ])
 
         if MODEL_CONFIGS[model_name]["type"] == "hf":
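The forward pass and top-5 ranking inside predict (old lines 253-279) are not part of this diff. Based on the helpers visible elsewhere in the file, the timm/EfficientNet branch presumably does something along these lines (a hedged reconstruction, not the committed code):

x = transform(image).unsqueeze(0)
x.requires_grad = True
probs = torch.softmax(model(x), dim=1)[0]
top5 = torch.topk(probs, 5)
results = {IMAGENET_LABELS[int(i)]: float(p) for p, i in zip(top5.values, top5.indices)}
cam = get_gradcam(model, x)                          # Grad-CAM for the top prediction
overlay = overlay_attention(image.resize((224, 224)), cam)
return results, overlay, image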
@@ -280,7 +320,10 @@ def predict(image, model_name, noise_level):
         print(error_msg)
         return {"Error": str(e)}, None, None
 
+
+# ---------------------------
 # Class-Specific Attention
+# ---------------------------
 def get_class_specific_attention(image, model_name, class_query):
     try:
         if image is None:
@@ -289,9 +332,7 @@ def get_class_specific_attention(image, model_name, class_query):
         if not class_query or class_query.strip() == "":
             return None, None, "Please enter a class name"
 
-        #
-        model_name = model_name.split(" - ")[0]
-
+        # Find matching class
         class_query_lower = class_query.lower().strip()
         matching_idx = None
         matched_label = None
@@ -299,6 +340,7 @@ def get_class_specific_attention(image, model_name, class_query):
         model, extractor = load_model(model_name)
 
         if MODEL_CONFIGS[model_name]["type"] == "hf":
+            # Search in HF model labels
             for idx, label in model.config.id2label.items():
                 if class_query_lower in label.lower():
                     matching_idx = idx
@@ -308,9 +350,11 @@ def get_class_specific_attention(image, model_name, class_query):
             if matching_idx is None:
                 return None, None, f"Class '{class_query}' not found in model labels. Try a different class name or check sample classes."
 
+            # Get attention for this class
             att_map = vit_attention_for_class(model, extractor, image, matching_idx)
 
         else:
+            # Search in ImageNet labels
             for idx, label in enumerate(IMAGENET_LABELS):
                 if class_query_lower in label.lower():
                     matching_idx = idx
@@ -320,10 +364,12 @@ def get_class_specific_attention(image, model_name, class_query):
             if matching_idx is None:
                 return None, None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check sample classes."
 
+            # Get Grad-CAM for this class
             transform = T.Compose([
                 T.Resize((224, 224)),
                 T.ToTensor(),
-                T.Normalize(mean=[0.485, 0.456, 0.406],
+                T.Normalize(mean=[0.485, 0.456, 0.406],
+                            std=[0.229, 0.224, 0.225])
             ])
             x = transform(image).unsqueeze(0)
             x.requires_grad = True
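The T.Normalize call added here applies the standard ImageNet statistics channel-wise, x' = (x - mean) / std; for example, a red-channel value of 0.5 maps to (0.5 - 0.485) / 0.229 ≈ 0.066.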
@@ -339,7 +385,10 @@ def get_class_specific_attention(image, model_name, class_query):
         print(error_trace)
         return None, None, f"Error generating attention map: {str(e)}"
 
+
+# ---------------------------
 # Sample Classes
+# ---------------------------
 SAMPLE_CLASSES = [
     "cat", "dog", "tiger", "lion", "elephant",
     "car", "truck", "airplane", "ship", "train",
@@ -348,122 +397,91 @@ SAMPLE_CLASSES = [
     "person", "bicycle", "building", "tree", "flower"
 ]
 
-# Improved Gradio UI
-with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="gray", font=["Inter", "sans-serif"])) as demo:
-    gr.Markdown("""
-    # 🧠 Advanced Image Classification Studio
-    Explore state-of-the-art image classification with multiple models, adversarial testing, and attention visualization.
-    """, elem_classes=["header-text"])
-
-    with gr.Tabs() as tabs:
-        with gr.TabItem("Predict & Analyze"):
-            with gr.Row(variant="panel"):
-                with gr.Column(scale=1, min_width=300):
-                    gr.Markdown("### 📷 Input")
-                    input_image = gr.Image(type="pil", label="Upload Image", height=300, interactive=True)
-                    model_dropdown = gr.Dropdown(
-                        choices=[f"{name} - {MODEL_CONFIGS[name]['desc']}" for name in MODEL_CONFIGS.keys()],
-                        label="Select Model",
-                        value="DeiT-Tiny - Lightweight Vision Transformer",
-                        interactive=True,
-                        info="Choose from various architectures (Transformers, CNNs, Hybrids)"
-                    )
-                    with gr.Group():
-                        gr.Markdown("### Adversarial Testing")
-                        noise_slider = gr.Slider(
-                            minimum=0, maximum=0.3, value=0, step=0.01,
-                            label="Noise Level (ε)",
-                            info="Add noise to test model robustness",
-                            interactive=True
-                        )
-                    run_button = gr.Button("Run Prediction", variant="primary")
-
-                with gr.Column(scale=2):
-                    gr.Markdown("### Results")
-                    output_label = gr.Label(num_top_classes=5, label="Top 5 Predictions", show_label=True)
-                    with gr.Row():
-                        output_image = gr.Image(label="Attention Map (Top Prediction)", height=300)
-                        processed_image = gr.Image(label="Processed Image (with noise)", height=300, visible=False)
-
-        with gr.TabItem("🎨 Class-Specific Attention"):
-            gr.Markdown("### Visualize Model Attention for Specific Classes")
-            with gr.Row(variant="panel"):
-                with gr.Column(scale=1, min_width=300):
-                    class_input = gr.Textbox(
-                        label="Enter Class Name",
-                        placeholder="e.g., cat, dog, car, pizza...",
-                        info="Type any ImageNet class name",
-                        interactive=True
-                    )
-                    class_button = gr.Button("🎯 Generate Attention Map", variant="primary")
-                    with gr.Accordion("💡 Sample Classes", open=False):
-                        sample_buttons = gr.CheckboxGroup(
-                            choices=SAMPLE_CLASSES,
-                            label="Select or click to auto-fill",
-                            interactive=True
-                        )
-
-                with gr.Column(scale=2):
-                    class_output_image = gr.Image(label="Class-Specific Attention Map", height=300)
-                    gradient_legend = gr.Image(label="Attention Scale", interactive=False)
-                    class_status = gr.Textbox(label="Status", interactive=False, lines=2)
-
-        with gr.TabItem("ℹ️ About Models"):
-            gr.Markdown("""
-            ### Available Models
-            Explore different architectures and their strengths:
-            """)
-            for model_name, config in MODEL_CONFIGS.items():
-                with gr.Accordion(f"{model_name}", open=False):
-                    gr.Markdown(f"- **Type**: {config['type'].upper()}")
-                    gr.Markdown(f"- **Description**: {config['desc']}")
-                    gr.Markdown(f"- **Model ID**: {config['id']}")
-
+
+# ---------------------------
+# Gradio UI
+# ---------------------------
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🧠 Enhanced Multi-Model Image Classifier")
+    gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 13+ Models")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            input_image = gr.Image(type="pil", label="📸 Upload Image")
+            model_dropdown = gr.Dropdown(
+                choices=list(MODEL_CONFIGS.keys()),
+                label="🤖 Select Model",
+                value="DeiT-Tiny"
+            )
+
+            gr.Markdown("### Adversarial Noise")
+            noise_slider = gr.Slider(
+                minimum=0,
+                maximum=0.3,
+                value=0,
+                step=0.01,
+                label="Noise Level (ε)",
+                info="Add random noise to test model robustness"
+            )
+
+            run_button = gr.Button("Run Model", variant="primary")
+
+        with gr.Column(scale=2):
+            output_label = gr.Label(num_top_classes=5, label="🎯 Top 5 Predictions")
+            output_image = gr.Image(label="Attention Map (Top Prediction)")
+            processed_image = gr.Image(label="🖼️ Processed Image (with noise if applied)", visible=False)
+
+    gr.Markdown("---")
+    gr.Markdown("### 🎨 Class-Specific Attention Visualization")
+    gr.Markdown("*Type any class name to see where the model looks for that specific object*")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            class_input = gr.Textbox(
+                label="Enter Class Name",
+                placeholder="e.g., cat, dog, car, pizza...",
+                info="Type any ImageNet class name"
+            )
+            class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")
+            gr.Markdown("**💡 Sample classes to try:**")
+            sample_buttons = gr.Radio(
+                choices=SAMPLE_CLASSES,
+                label="Click to auto-fill",
+                interactive=True
+            )
+
+        with gr.Column(scale=2):
+            class_output_image = gr.Image(label="Class-Specific Attention Map")
+            class_status = gr.Textbox(label="Status", interactive=False)
+
+    gr.Markdown("---")
     gr.Markdown("""
-
-
-
-
-
-
-
-
-    # Event Handlers
-    def update_class_input(selected_classes):
-        return selected_classes[0] if selected_classes else ""
-
+    ### 💡 Tips:
+    - **Adversarial Noise**: Adjust the slider to add random noise and see how robust the model is
+    - **Class-Specific Attention**: Type any ImageNet class to visualize what the model looks for (e.g., "tiger", "sports car", "pizza")
+    - **Model Variety**: Try different architectures (ViT, CNN, Hybrid) to compare their behavior
+    """)
+
+    # Event handlers
     run_button.click(
-
+        predict,
         inputs=[input_image, model_dropdown, noise_slider],
-        outputs=[output_label, output_image, processed_image]
-        show_progress=True
+        outputs=[output_label, output_image, processed_image]
     )
-
+
+    # When user selects a sample class, update the text input
     sample_buttons.change(
-
+        lambda x: x,
        inputs=[sample_buttons],
        outputs=[class_input]
    )
-
+
+    # Generate attention map
    class_button.click(
-
+        get_class_specific_attention,
        inputs=[input_image, model_dropdown, class_input],
-        outputs=[class_output_image, gradient_legend, class_status]
-        show_progress=True
+        outputs=[class_output_image, gradient_legend, class_status]
    )
 
-    # Add custom CSS for improved styling
-    gr.HTML("""
-    <style>
-    .header-text { font-size: 2rem; font-weight: bold; color: #1E3A8A; margin-bottom: 1rem; }
-    .footer-text { font-size: 0.9rem; color: #4B5563; }
-    .gr-button { transition: all 0.3s ease; }
-    .gr-button:hover { transform: scale(1.05); }
-    .gr-panel { border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
-    .gr-image { border-radius: 8px; }
-    .gr-accordion { margin-bottom: 1rem; }
-    </style>
-    """)
-
 if __name__ == "__main__":
     demo.launch()
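One wiring detail worth flagging in the rolled-back UI: class_button.click still lists gradient_legend among its outputs (new line 483), but the component of that name was only created in the removed tabbed layout (old line 408); the new layout defines only class_output_image and class_status. A minimal sketch of one way to keep the handler's three outputs valid, assuming the legend image is simply re-added next to the status box (a suggestion, not part of this commit):

        with gr.Column(scale=2):
            class_output_image = gr.Image(label="Class-Specific Attention Map")
            gradient_legend = gr.Image(label="Attention Scale", interactive=False)
            class_status = gr.Textbox(label="Status", interactive=False)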