Godreign committed
Commit 9b1dcad · verified · 1 Parent(s): a7838bb

rollback to old ui

Files changed (1)
  1. app.py +144 -126
app.py CHANGED
@@ -11,29 +11,35 @@ import urllib.request
11
  import json
12
  import cv2
13
 
 
14
  # Model Configs
 
15
  MODEL_CONFIGS = {
16
- "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224", "desc": "Lightweight Vision Transformer"},
17
- "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224", "desc": "Small Vision Transformer"},
18
- "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224", "desc": "Base Vision Transformer"},
19
- "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny", "desc": "Tiny ConvNeXt CNN"},
20
- "ConvNeXt-Nano": {"type": "timm", "id": "convnext_nano", "desc": "Nano ConvNeXt CNN"},
21
- "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0", "desc": "EfficientNet B0"},
22
- "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1", "desc": "EfficientNet B1"},
23
- "ResNet-50": {"type": "timm", "id": "resnet50", "desc": "Classic ResNet-50 CNN"},
24
- "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100", "desc": "Lightweight MobileNet-V2"},
25
- "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224", "desc": "Tiny MaxViT Hybrid"},
26
- "MobileViT-Small": {"type": "timm", "id": "mobilevit_s", "desc": "Small MobileViT"},
27
- "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small", "desc": "Small EdgeNeXt"},
28
- "RegNetY-002": {"type": "timm", "id": "regnety_002", "desc": "RegNetY-002 CNN"}
29
  }
30
 
 
31
  # ImageNet Labels
 
32
  IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
33
  with urllib.request.urlopen(IMAGENET_URL) as url:
34
  IMAGENET_LABELS = json.load(url)
35
 
 
36
  # Lazy Load
 
37
  loaded_models = {}
38
 
39
  def load_model(model_name):
@@ -45,17 +51,20 @@ def load_model(model_name):
45
  extractor = AutoFeatureExtractor.from_pretrained(config["id"])
46
  model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
47
  model.eval()
 
48
  for param in model.parameters():
49
  param.requires_grad = True
50
  elif config["type"] == "timm":
51
  model = timm.create_model(config["id"], pretrained=True)
52
  model.eval()
 
53
  for param in model.parameters():
54
  param.requires_grad = True
55
  extractor = None
56
  elif config["type"] == "efficientnet":
57
  model = EfficientNet.from_pretrained(config["id"])
58
  model.eval()
 
59
  for param in model.parameters():
60
  param.requires_grad = True
61
  extractor = None
@@ -63,14 +72,21 @@ def load_model(model_name):
63
  loaded_models[model_name] = (model, extractor)
64
  return model, extractor
65
 
 
 
66
  # Adversarial Noise
 
67
  def add_adversarial_noise(image, epsilon):
 
68
  img_array = np.array(image).astype(np.float32) / 255.0
69
  noise = np.random.randn(*img_array.shape) * epsilon
70
  noisy_img = np.clip(img_array + noise, 0, 1)
71
  return Image.fromarray((noisy_img * 255).astype(np.uint8))
72
 
 
 
73
  # Grad-CAM for Class-Specific Attention
 
74
  def get_gradcam_for_class(model, image_tensor, class_idx):
75
  grad = None
76
  fmap = None
@@ -83,6 +99,7 @@ def get_gradcam_for_class(model, image_tensor, class_idx):
83
  nonlocal grad
84
  grad = grad_out[0].detach()
85
 
 
86
  last_conv = None
87
  for name, module in reversed(list(model.named_modules())):
88
  if isinstance(module, torch.nn.Conv2d):
@@ -111,10 +128,15 @@ def get_gradcam_for_class(model, image_tensor, class_idx):
111
  cam = cam.squeeze().cpu().numpy()
112
  cam = cv2.resize(cam, (224, 224))
113
  cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
 
114
  return cam
115
 
 
 
116
  # ViT Attention for Class-Specific
 
117
  def vit_attention_for_class(model, extractor, image, class_idx):
 
118
  inputs = extractor(images=image, return_tensors="pt")
119
  inputs['pixel_values'].requires_grad = True
120
  outputs = model(**inputs)
@@ -123,6 +145,7 @@ def vit_attention_for_class(model, extractor, image, class_idx):
123
  model.zero_grad()
124
  score.backward()
125
 
 
126
  if hasattr(outputs, 'attentions') and outputs.attentions is not None:
127
  attn = outputs.attentions[-1]
128
  attn = attn.mean(1)
@@ -134,7 +157,10 @@ def vit_attention_for_class(model, extractor, image, class_idx):
134
 
135
  return np.ones((14, 14))
136
 
 
 
137
  # Grad-CAM Helper for CNNs (Top Prediction)
 
138
  def get_gradcam(model, image_tensor):
139
  grad = None
140
  fmap = None
@@ -176,9 +202,13 @@ def get_gradcam(model, image_tensor):
176
  cam = cam.squeeze().cpu().numpy()
177
  cam = cv2.resize(cam, (224, 224))
178
  cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
 
179
  return cam
180
 
 
 
181
  # ViT Attention Rollout Helper
 
182
  def vit_attention_rollout(outputs):
183
  if not hasattr(outputs, 'attentions') or outputs.attentions is None:
184
  return np.ones((14, 14))
@@ -191,12 +221,18 @@ def vit_attention_rollout(outputs):
191
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
192
  return attn_map
193
 
 
 
194
  # Create Gradient Legend
 
195
  def create_gradient_legend():
 
196
  width, height = 400, 60
197
  gradient = np.zeros((height, width, 3), dtype=np.uint8)
198
 
 
199
  for i in range(width):
 
200
  value = int(255 * i / width)
201
  color_single = np.array([[[value]]], dtype=np.uint8)
202
  color_rgb = cv2.applyColorMap(color_single, cv2.COLORMAP_JET)
@@ -204,20 +240,22 @@ def create_gradient_legend():
204
 
205
  gradient = cv2.cvtColor(gradient, cv2.COLOR_BGR2RGB)
206
 
 
207
  from PIL import ImageDraw, ImageFont
208
  gradient_pil = Image.fromarray(gradient)
209
  draw = ImageDraw.Draw(gradient_pil)
210
 
 
211
  try:
212
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
213
  except:
214
  font = ImageFont.load_default()
215
 
 
216
  draw.text((10, 20), "Low Attention", fill=(255, 255, 255), font=font)
217
  draw.text((width - 120, 20), "High Attention", fill=(255, 255, 255), font=font)
218
 
219
  return gradient_pil
220
-
221
  def overlay_attention(pil_img, attention_map):
222
  heatmap = (attention_map * 255).astype(np.uint8)
223
  heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
@@ -227,7 +265,10 @@ def overlay_attention(pil_img, attention_map):
227
  blended = Image.blend(pil_img.convert("RGB"), heatmap_pil, alpha=0.4)
228
  return blended
229
 
 
 
230
  # Main Prediction Function
 
231
  def predict(image, model_name, noise_level):
232
  try:
233
  if image is None:
@@ -236,9 +277,7 @@ def predict(image, model_name, noise_level):
236
  if model_name is None:
237
  return {"Error": "Please select a model"}, None, None
238
 
239
- # Extract model name from dropdown (remove description)
240
- model_name = model_name.split(" - ")[0]
241
-
242
  if noise_level > 0:
243
  image = add_adversarial_noise(image, noise_level)
244
 
@@ -246,7 +285,8 @@ def predict(image, model_name, noise_level):
246
  transform = T.Compose([
247
  T.Resize((224, 224)),
248
  T.ToTensor(),
249
- T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
250
  ])
251
 
252
  if MODEL_CONFIGS[model_name]["type"] == "hf":
@@ -280,7 +320,10 @@ def predict(image, model_name, noise_level):
280
  print(error_msg)
281
  return {"Error": str(e)}, None, None
282
 
 
 
283
  # Class-Specific Attention
 
284
  def get_class_specific_attention(image, model_name, class_query):
285
  try:
286
  if image is None:
@@ -289,9 +332,7 @@ def get_class_specific_attention(image, model_name, class_query):
289
  if not class_query or class_query.strip() == "":
290
  return None, None, "Please enter a class name"
291
 
292
- # Extract model name from dropdown (remove description)
293
- model_name = model_name.split(" - ")[0]
294
-
295
  class_query_lower = class_query.lower().strip()
296
  matching_idx = None
297
  matched_label = None
@@ -299,6 +340,7 @@ def get_class_specific_attention(image, model_name, class_query):
299
  model, extractor = load_model(model_name)
300
 
301
  if MODEL_CONFIGS[model_name]["type"] == "hf":
 
302
  for idx, label in model.config.id2label.items():
303
  if class_query_lower in label.lower():
304
  matching_idx = idx
@@ -308,9 +350,11 @@ def get_class_specific_attention(image, model_name, class_query):
308
  if matching_idx is None:
309
  return None, None, f"Class '{class_query}' not found in model labels. Try a different class name or check sample classes."
310
 
 
311
  att_map = vit_attention_for_class(model, extractor, image, matching_idx)
312
 
313
  else:
 
314
  for idx, label in enumerate(IMAGENET_LABELS):
315
  if class_query_lower in label.lower():
316
  matching_idx = idx
@@ -320,10 +364,12 @@ def get_class_specific_attention(image, model_name, class_query):
320
  if matching_idx is None:
321
  return None, None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check sample classes."
322
 
 
323
  transform = T.Compose([
324
  T.Resize((224, 224)),
325
  T.ToTensor(),
326
- T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
327
  ])
328
  x = transform(image).unsqueeze(0)
329
  x.requires_grad = True
@@ -339,7 +385,10 @@ def get_class_specific_attention(image, model_name, class_query):
339
  print(error_trace)
340
  return None, None, f"Error generating attention map: {str(e)}"
341
 
 
 
342
  # Sample Classes
 
343
  SAMPLE_CLASSES = [
344
  "cat", "dog", "tiger", "lion", "elephant",
345
  "car", "truck", "airplane", "ship", "train",
@@ -348,122 +397,91 @@ SAMPLE_CLASSES = [
348
  "person", "bicycle", "building", "tree", "flower"
349
  ]
350
 
351
- # Improved Gradio UI
352
- with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="gray", font=["Inter", "sans-serif"])) as demo:
353
- gr.Markdown("""
354
- # 🧠 Advanced Image Classification Studio
355
- Explore state-of-the-art image classification with multiple models, adversarial testing, and attention visualization.
356
- """, elem_classes=["header-text"])
357
-
358
- with gr.Tabs() as tabs:
359
- with gr.TabItem("🔍 Predict & Analyze"):
360
- with gr.Row(variant="panel"):
361
- with gr.Column(scale=1, min_width=300):
362
- gr.Markdown("### 📷 Input")
363
- input_image = gr.Image(type="pil", label="Upload Image", height=300, interactive=True)
364
- model_dropdown = gr.Dropdown(
365
- choices=[f"{name} - {MODEL_CONFIGS[name]['desc']}" for name in MODEL_CONFIGS.keys()],
366
- label="Select Model",
367
- value="DeiT-Tiny - Lightweight Vision Transformer",
368
- interactive=True,
369
- info="Choose from various architectures (Transformers, CNNs, Hybrids)"
370
- )
371
- with gr.Group():
372
- gr.Markdown("### 🎭 Adversarial Testing")
373
- noise_slider = gr.Slider(
374
- minimum=0, maximum=0.3, value=0, step=0.01,
375
- label="Noise Level (ε)",
376
- info="Add noise to test model robustness",
377
- interactive=True
378
- )
379
- run_button = gr.Button("🚀 Run Prediction", variant="primary")
380
-
381
- with gr.Column(scale=2):
382
- gr.Markdown("### 📊 Results")
383
- output_label = gr.Label(num_top_classes=5, label="Top 5 Predictions", show_label=True)
384
- with gr.Row():
385
- output_image = gr.Image(label="Attention Map (Top Prediction)", height=300)
386
- processed_image = gr.Image(label="Processed Image (with noise)", height=300, visible=False)
387
-
388
- with gr.TabItem("🎨 Class-Specific Attention"):
389
- gr.Markdown("### Visualize Model Attention for Specific Classes")
390
- with gr.Row(variant="panel"):
391
- with gr.Column(scale=1, min_width=300):
392
- class_input = gr.Textbox(
393
- label="Enter Class Name",
394
- placeholder="e.g., cat, dog, car, pizza...",
395
- info="Type any ImageNet class name",
396
- interactive=True
397
- )
398
- class_button = gr.Button("🎯 Generate Attention Map", variant="primary")
399
- with gr.Accordion("💡 Sample Classes", open=False):
400
- sample_buttons = gr.CheckboxGroup(
401
- choices=SAMPLE_CLASSES,
402
- label="Select or click to auto-fill",
403
- interactive=True
404
- )
405
-
406
- with gr.Column(scale=2):
407
- class_output_image = gr.Image(label="Class-Specific Attention Map", height=300)
408
- gradient_legend = gr.Image(label="Attention Scale", interactive=False)
409
- class_status = gr.Textbox(label="Status", interactive=False, lines=2)
410
-
411
- with gr.TabItem("ℹ️ About Models"):
412
- gr.Markdown("""
413
- ### Available Models
414
- Explore different architectures and their strengths:
415
- """)
416
- for model_name, config in MODEL_CONFIGS.items():
417
- with gr.Accordion(f"{model_name}", open=False):
418
- gr.Markdown(f"- **Type**: {config['type'].upper()}")
419
- gr.Markdown(f"- **Description**: {config['desc']}")
420
- gr.Markdown(f"- **Model ID**: {config['id']}")
421
 
422
  gr.Markdown("""
423
- ---
424
- ### 💡 How to Use
425
- - **Predict & Analyze**: Upload an image, select a model, adjust noise level, and run prediction to see top classes and attention maps.
426
- - **Class-Specific Attention**: Enter a class name or select from samples to visualize where the model focuses for that class.
427
- - **Adversarial Testing**: Use the noise slider to test model robustness against perturbations.
428
- - **Model Info**: Check the 'About Models' tab for details on available architectures.
429
- """, elem_classes=["footer-text"])
430
-
431
- # Event Handlers
432
- def update_class_input(selected_classes):
433
- return selected_classes[0] if selected_classes else ""
434
-
435
  run_button.click(
436
- fn=predict,
437
  inputs=[input_image, model_dropdown, noise_slider],
438
- outputs=[output_label, output_image, processed_image],
439
- show_progress=True
440
  )
441
-
 
442
  sample_buttons.change(
443
- fn=update_class_input,
444
  inputs=[sample_buttons],
445
  outputs=[class_input]
446
  )
447
-
 
448
  class_button.click(
449
- fn=get_class_specific_attention,
450
  inputs=[input_image, model_dropdown, class_input],
451
- outputs=[class_output_image, gradient_legend, class_status],
452
- show_progress=True
453
  )
454
 
455
- # Add custom CSS for improved styling
456
- gr.HTML("""
457
- <style>
458
- .header-text { font-size: 2rem; font-weight: bold; color: #1E3A8A; margin-bottom: 1rem; }
459
- .footer-text { font-size: 0.9rem; color: #4B5563; }
460
- .gr-button { transition: all 0.3s ease; }
461
- .gr-button:hover { transform: scale(1.05); }
462
- .gr-panel { border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
463
- .gr-image { border-radius: 8px; }
464
- .gr-accordion { margin-bottom: 1rem; }
465
- </style>
466
- """)
467
-
468
  if __name__ == "__main__":
469
  demo.launch()
 
11
  import json
12
  import cv2
13
 
14
+ # ---------------------------
15
  # Model Configs
16
+ # ---------------------------
17
  MODEL_CONFIGS = {
18
+ "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224"},
19
+ "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224"},
20
+ "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224"},
21
+ "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny"},
22
+ "ConvNeXt-Nano": {"type": "timm", "id": "convnext_nano"},
23
+ "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0"},
24
+ "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1"},
25
+ "ResNet-50": {"type": "timm", "id": "resnet50"},
26
+ "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},
27
+ "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224"},
28
+ "MobileViT-Small": {"type": "timm", "id": "mobilevit_s"},
29
+ "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small"},
30
+ "RegNetY-002": {"type": "timm", "id": "regnety_002"}
31
  }
32
 
33
+ # ---------------------------
34
  # ImageNet Labels
35
+ # ---------------------------
36
  IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
37
  with urllib.request.urlopen(IMAGENET_URL) as url:
38
  IMAGENET_LABELS = json.load(url)
39
 
40
+ # ---------------------------
41
  # Lazy Load
42
+ # ---------------------------
43
  loaded_models = {}
44
 
45
  def load_model(model_name):
 
51
  extractor = AutoFeatureExtractor.from_pretrained(config["id"])
52
  model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
53
  model.eval()
54
+ # Enable gradients for class-specific attention
55
  for param in model.parameters():
56
  param.requires_grad = True
57
  elif config["type"] == "timm":
58
  model = timm.create_model(config["id"], pretrained=True)
59
  model.eval()
60
+ # Enable gradients for class-specific attention
61
  for param in model.parameters():
62
  param.requires_grad = True
63
  extractor = None
64
  elif config["type"] == "efficientnet":
65
  model = EfficientNet.from_pretrained(config["id"])
66
  model.eval()
67
+ # Enable gradients for class-specific attention
68
  for param in model.parameters():
69
  param.requires_grad = True
70
  extractor = None
 
72
  loaded_models[model_name] = (model, extractor)
73
  return model, extractor
74
 
75
+
76
+ # ---------------------------
77
  # Adversarial Noise
78
+ # ---------------------------
79
  def add_adversarial_noise(image, epsilon):
80
+ """Add random noise to image"""
81
  img_array = np.array(image).astype(np.float32) / 255.0
82
  noise = np.random.randn(*img_array.shape) * epsilon
83
  noisy_img = np.clip(img_array + noise, 0, 1)
84
  return Image.fromarray((noisy_img * 255).astype(np.uint8))
85
 
86
+
87
+ # ---------------------------
88
  # Grad-CAM for Class-Specific Attention
89
+ # ---------------------------
90
  def get_gradcam_for_class(model, image_tensor, class_idx):
91
  grad = None
92
  fmap = None
 
99
  nonlocal grad
100
  grad = grad_out[0].detach()
101
 
102
+ # Find last conv layer
103
  last_conv = None
104
  for name, module in reversed(list(model.named_modules())):
105
  if isinstance(module, torch.nn.Conv2d):
 
128
  cam = cam.squeeze().cpu().numpy()
129
  cam = cv2.resize(cam, (224, 224))
130
  cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
131
+
132
  return cam
133
 
134
+
135
+ # ---------------------------
136
  # ViT Attention for Class-Specific
137
+ # ---------------------------
138
  def vit_attention_for_class(model, extractor, image, class_idx):
139
+ """Get attention map for specific class in ViT"""
140
  inputs = extractor(images=image, return_tensors="pt")
141
  inputs['pixel_values'].requires_grad = True
142
  outputs = model(**inputs)
 
145
  model.zero_grad()
146
  score.backward()
147
 
148
+ # Use last layer attention
149
  if hasattr(outputs, 'attentions') and outputs.attentions is not None:
150
  attn = outputs.attentions[-1]
151
  attn = attn.mean(1)
 
157
 
158
  return np.ones((14, 14))
159
 
160
+
161
+ # ---------------------------
162
  # Grad-CAM Helper for CNNs (Top Prediction)
163
+ # ---------------------------
164
  def get_gradcam(model, image_tensor):
165
  grad = None
166
  fmap = None
 
202
  cam = cam.squeeze().cpu().numpy()
203
  cam = cv2.resize(cam, (224, 224))
204
  cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
205
+
206
  return cam
207
 
208
+
209
+ # ---------------------------
210
  # ViT Attention Rollout Helper
211
+ # ---------------------------
212
  def vit_attention_rollout(outputs):
213
  if not hasattr(outputs, 'attentions') or outputs.attentions is None:
214
  return np.ones((14, 14))
 
221
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
222
  return attn_map
223
 
224
+
225
+ # ---------------------------
226
  # Create Gradient Legend
227
+ # ---------------------------
228
  def create_gradient_legend():
229
+ """Create a gradient legend image showing attention scale"""
230
  width, height = 400, 60
231
  gradient = np.zeros((height, width, 3), dtype=np.uint8)
232
 
233
+ # Create gradient from blue to red (matching COLORMAP_JET)
234
  for i in range(width):
235
+ # OpenCV's COLORMAP_JET: blue (low) -> cyan -> green -> yellow -> red (high)
236
  value = int(255 * i / width)
237
  color_single = np.array([[[value]]], dtype=np.uint8)
238
  color_rgb = cv2.applyColorMap(color_single, cv2.COLORMAP_JET)
 
240
 
241
  gradient = cv2.cvtColor(gradient, cv2.COLOR_BGR2RGB)
242
 
243
+ # Convert to PIL and add text
244
  from PIL import ImageDraw, ImageFont
245
  gradient_pil = Image.fromarray(gradient)
246
  draw = ImageDraw.Draw(gradient_pil)
247
 
248
+ # Use default font
249
  try:
250
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
251
  except:
252
  font = ImageFont.load_default()
253
 
254
+ # Add text labels
255
  draw.text((10, 20), "Low Attention", fill=(255, 255, 255), font=font)
256
  draw.text((width - 120, 20), "High Attention", fill=(255, 255, 255), font=font)
257
 
258
  return gradient_pil
 
259
  def overlay_attention(pil_img, attention_map):
260
  heatmap = (attention_map * 255).astype(np.uint8)
261
  heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
 
265
  blended = Image.blend(pil_img.convert("RGB"), heatmap_pil, alpha=0.4)
266
  return blended
267
 
268
+
269
+ # ---------------------------
270
  # Main Prediction Function
271
+ # ---------------------------
272
  def predict(image, model_name, noise_level):
273
  try:
274
  if image is None:
 
277
  if model_name is None:
278
  return {"Error": "Please select a model"}, None, None
279
 
280
+ # Apply adversarial noise if requested
 
 
281
  if noise_level > 0:
282
  image = add_adversarial_noise(image, noise_level)
283
 
 
285
  transform = T.Compose([
286
  T.Resize((224, 224)),
287
  T.ToTensor(),
288
+ T.Normalize(mean=[0.485, 0.456, 0.406],
289
+ std=[0.229, 0.224, 0.225])
290
  ])
291
 
292
  if MODEL_CONFIGS[model_name]["type"] == "hf":
 
320
  print(error_msg)
321
  return {"Error": str(e)}, None, None
322
 
323
+
324
+ # ---------------------------
325
  # Class-Specific Attention
326
+ # ---------------------------
327
  def get_class_specific_attention(image, model_name, class_query):
328
  try:
329
  if image is None:
 
332
  if not class_query or class_query.strip() == "":
333
  return None, None, "Please enter a class name"
334
 
335
+ # Find matching class
 
 
336
  class_query_lower = class_query.lower().strip()
337
  matching_idx = None
338
  matched_label = None
 
340
  model, extractor = load_model(model_name)
341
 
342
  if MODEL_CONFIGS[model_name]["type"] == "hf":
343
+ # Search in HF model labels
344
  for idx, label in model.config.id2label.items():
345
  if class_query_lower in label.lower():
346
  matching_idx = idx
 
350
  if matching_idx is None:
351
  return None, None, f"Class '{class_query}' not found in model labels. Try a different class name or check sample classes."
352
 
353
+ # Get attention for this class
354
  att_map = vit_attention_for_class(model, extractor, image, matching_idx)
355
 
356
  else:
357
+ # Search in ImageNet labels
358
  for idx, label in enumerate(IMAGENET_LABELS):
359
  if class_query_lower in label.lower():
360
  matching_idx = idx
 
364
  if matching_idx is None:
365
  return None, None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check sample classes."
366
 
367
+ # Get Grad-CAM for this class
368
  transform = T.Compose([
369
  T.Resize((224, 224)),
370
  T.ToTensor(),
371
+ T.Normalize(mean=[0.485, 0.456, 0.406],
372
+ std=[0.229, 0.224, 0.225])
373
  ])
374
  x = transform(image).unsqueeze(0)
375
  x.requires_grad = True
 
385
  print(error_trace)
386
  return None, None, f"Error generating attention map: {str(e)}"
387
 
388
+
389
+ # ---------------------------
390
  # Sample Classes
391
+ # ---------------------------
392
  SAMPLE_CLASSES = [
393
  "cat", "dog", "tiger", "lion", "elephant",
394
  "car", "truck", "airplane", "ship", "train",
 
397
  "person", "bicycle", "building", "tree", "flower"
398
  ]
399
 
400
 
401
+ # ---------------------------
402
+ # Gradio UI
403
+ # ---------------------------
404
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
405
+ gr.Markdown("# 🧠 Enhanced Multi-Model Image Classifier")
406
+ gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 13+ Models")
407
+
408
+ with gr.Row():
409
+ with gr.Column(scale=1):
410
+ input_image = gr.Image(type="pil", label="📸 Upload Image")
411
+ model_dropdown = gr.Dropdown(
412
+ choices=list(MODEL_CONFIGS.keys()),
413
+ label="🤖 Select Model",
414
+ value="DeiT-Tiny"
415
+ )
416
+
417
+ gr.Markdown("### 🎭 Adversarial Noise")
418
+ noise_slider = gr.Slider(
419
+ minimum=0,
420
+ maximum=0.3,
421
+ value=0,
422
+ step=0.01,
423
+ label="Noise Level (ε)",
424
+ info="Add random noise to test model robustness"
425
+ )
426
+
427
+ run_button = gr.Button("🚀 Run Model", variant="primary")
428
+
429
+ with gr.Column(scale=2):
430
+ output_label = gr.Label(num_top_classes=5, label="🎯 Top 5 Predictions")
431
+ output_image = gr.Image(label="🔍 Attention Map (Top Prediction)")
432
+ processed_image = gr.Image(label="🖼️ Processed Image (with noise if applied)", visible=False)
433
+
434
+ gr.Markdown("---")
435
+ gr.Markdown("### 🎨 Class-Specific Attention Visualization")
436
+ gr.Markdown("*Type any class name to see where the model looks for that specific object*")
437
+
438
+ with gr.Row():
439
+ with gr.Column(scale=1):
440
+ class_input = gr.Textbox(
441
+ label="🔍 Enter Class Name",
442
+ placeholder="e.g., cat, dog, car, pizza...",
443
+ info="Type any ImageNet class name"
444
+ )
445
+ class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")
446
+ gr.Markdown("**💡 Sample classes to try:**")
447
+ sample_buttons = gr.Radio(
448
+ choices=SAMPLE_CLASSES,
449
+ label="Click to auto-fill",
450
+ interactive=True
451
+ )
452
+
453
+ with gr.Column(scale=2):
454
+ class_output_image = gr.Image(label="🔍 Class-Specific Attention Map")
455
+ class_status = gr.Textbox(label="Status", interactive=False)
456
+
457
+ gr.Markdown("---")
458
  gr.Markdown("""
459
+ ### 💡 Tips:
460
+ - **Adversarial Noise**: Adjust the slider to add random noise and see how robust the model is
461
+ - **Class-Specific Attention**: Type any ImageNet class to visualize what the model looks for (e.g., "tiger", "sports car", "pizza")
462
+ - **Model Variety**: Try different architectures (ViT, CNN, Hybrid) to compare their behavior
463
+ """)
464
+
465
+ # Event handlers
 
466
  run_button.click(
467
+ predict,
468
  inputs=[input_image, model_dropdown, noise_slider],
469
+ outputs=[output_label, output_image, processed_image]
 
470
  )
471
+
472
+ # When user selects a sample class, update the text input
473
  sample_buttons.change(
474
+ lambda x: x,
475
  inputs=[sample_buttons],
476
  outputs=[class_input]
477
  )
478
+
479
+ # Generate attention map
480
  class_button.click(
481
+ get_class_specific_attention,
482
  inputs=[input_image, model_dropdown, class_input],
483
+ outputs=[class_output_image, gradient_legend, class_status]
 
484
  )
485
 
486
  if __name__ == "__main__":
487
  demo.launch()