Godreign committed · verified
Commit b7be9e3 · 1 Parent(s): 275c8e5

Update app.py

Files changed (1): app.py (+121 -144)
app.py CHANGED
@@ -10,36 +10,31 @@ import torchvision.transforms as T
 import urllib.request
 import json
 import cv2

-# ---------------------------
 # Model Configs
-# ---------------------------
 MODEL_CONFIGS = {
-    "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224"},
-    "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224"},
-    "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224"},
-    "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny"},
-    "ConvNeXt-Nano": {"type": "timm", "id": "convnext_nano"},
-    "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0"},
-    "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1"},
-    "ResNet-50": {"type": "timm", "id": "resnet50"},
-    "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100"},
-    "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224"},
-    "MobileViT-Small": {"type": "timm", "id": "mobilevit_s"},
-    "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small"},
-    "RegNetY-002": {"type": "timm", "id": "regnety_002"}
 }

-# ---------------------------
 # ImageNet Labels
-# ---------------------------
 IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
 with urllib.request.urlopen(IMAGENET_URL) as url:
     IMAGENET_LABELS = json.load(url)

-# ---------------------------
 # Lazy Load
-# ---------------------------
 loaded_models = {}

 def load_model(model_name):
@@ -51,20 +46,17 @@ def load_model(model_name):
         extractor = AutoFeatureExtractor.from_pretrained(config["id"])
         model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
         model.eval()
-        # Enable gradients for class-specific attention
         for param in model.parameters():
             param.requires_grad = True
     elif config["type"] == "timm":
         model = timm.create_model(config["id"], pretrained=True)
         model.eval()
-        # Enable gradients for class-specific attention
         for param in model.parameters():
             param.requires_grad = True
         extractor = None
     elif config["type"] == "efficientnet":
         model = EfficientNet.from_pretrained(config["id"])
         model.eval()
-        # Enable gradients for class-specific attention
         for param in model.parameters():
             param.requires_grad = True
         extractor = None
@@ -72,21 +64,14 @@ def load_model(model_name):
     loaded_models[model_name] = (model, extractor)
     return model, extractor

-
-# ---------------------------
 # Adversarial Noise
-# ---------------------------
 def add_adversarial_noise(image, epsilon):
-    """Add random noise to image"""
     img_array = np.array(image).astype(np.float32) / 255.0
     noise = np.random.randn(*img_array.shape) * epsilon
     noisy_img = np.clip(img_array + noise, 0, 1)
     return Image.fromarray((noisy_img * 255).astype(np.uint8))

-
-# ---------------------------
 # Grad-CAM for Class-Specific Attention
-# ---------------------------
 def get_gradcam_for_class(model, image_tensor, class_idx):
     grad = None
     fmap = None
@@ -99,7 +84,6 @@ def get_gradcam_for_class(model, image_tensor, class_idx):
         nonlocal grad
         grad = grad_out[0].detach()

-    # Find last conv layer
     last_conv = None
     for name, module in reversed(list(model.named_modules())):
         if isinstance(module, torch.nn.Conv2d):
@@ -128,15 +112,10 @@ def get_gradcam_for_class(model, image_tensor, class_idx):
     cam = cam.squeeze().cpu().numpy()
     cam = cv2.resize(cam, (224, 224))
     cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
-
     return cam

-
-# ---------------------------
 # ViT Attention for Class-Specific
-# ---------------------------
 def vit_attention_for_class(model, extractor, image, class_idx):
-    """Get attention map for specific class in ViT"""
     inputs = extractor(images=image, return_tensors="pt")
     inputs['pixel_values'].requires_grad = True
     outputs = model(**inputs)
@@ -145,7 +124,6 @@ def vit_attention_for_class(model, extractor, image, class_idx):
     model.zero_grad()
     score.backward()

-    # Use last layer attention
     if hasattr(outputs, 'attentions') and outputs.attentions is not None:
         attn = outputs.attentions[-1]
         attn = attn.mean(1)
@@ -157,10 +135,7 @@ def vit_attention_for_class(model, extractor, image, class_idx):

     return np.ones((14, 14))

-
-# ---------------------------
 # Grad-CAM Helper for CNNs (Top Prediction)
-# ---------------------------
 def get_gradcam(model, image_tensor):
     grad = None
     fmap = None
@@ -202,13 +177,9 @@ def get_gradcam(model, image_tensor):
     cam = cam.squeeze().cpu().numpy()
     cam = cv2.resize(cam, (224, 224))
     cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
-
     return cam

-
-# ---------------------------
 # ViT Attention Rollout Helper
-# ---------------------------
 def vit_attention_rollout(outputs):
     if not hasattr(outputs, 'attentions') or outputs.attentions is None:
         return np.ones((14, 14))
@@ -221,18 +192,12 @@ def vit_attention_rollout(outputs):
     attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
     return attn_map

-
-# ---------------------------
 # Create Gradient Legend
-# ---------------------------
 def create_gradient_legend():
-    """Create a gradient legend image showing attention scale"""
     width, height = 400, 60
     gradient = np.zeros((height, width, 3), dtype=np.uint8)

-    # Create gradient from blue to red (matching COLORMAP_JET)
     for i in range(width):
-        # OpenCV's COLORMAP_JET: blue (low) -> cyan -> green -> yellow -> red (high)
         value = int(255 * i / width)
         color_single = np.array([[[value]]], dtype=np.uint8)
         color_rgb = cv2.applyColorMap(color_single, cv2.COLORMAP_JET)
@@ -240,22 +205,20 @@

     gradient = cv2.cvtColor(gradient, cv2.COLOR_BGR2RGB)

-    # Convert to PIL and add text
     from PIL import ImageDraw, ImageFont
     gradient_pil = Image.fromarray(gradient)
     draw = ImageDraw.Draw(gradient_pil)

-    # Use default font
     try:
         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
     except:
         font = ImageFont.load_default()

-    # Add text labels
     draw.text((10, 20), "Low Attention", fill=(255, 255, 255), font=font)
     draw.text((width - 120, 20), "High Attention", fill=(255, 255, 255), font=font)

     return gradient_pil
 def overlay_attention(pil_img, attention_map):
     heatmap = (attention_map * 255).astype(np.uint8)
     heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
@@ -265,10 +228,7 @@ def overlay_attention(pil_img, attention_map):
     blended = Image.blend(pil_img.convert("RGB"), heatmap_pil, alpha=0.4)
     return blended

-
-# ---------------------------
 # Main Prediction Function
-# ---------------------------
 def predict(image, model_name, noise_level):
     try:
         if image is None:
@@ -277,7 +237,6 @@ def predict(image, model_name, noise_level):
         if model_name is None:
             return {"Error": "Please select a model"}, None, None

-        # Apply adversarial noise if requested
         if noise_level > 0:
             image = add_adversarial_noise(image, noise_level)
@@ -285,8 +244,7 @@
         transform = T.Compose([
             T.Resize((224, 224)),
             T.ToTensor(),
-            T.Normalize(mean=[0.485, 0.456, 0.406],
-                        std=[0.229, 0.224, 0.225])
         ])

         if MODEL_CONFIGS[model_name]["type"] == "hf":
@@ -320,10 +278,7 @@
         print(error_msg)
         return {"Error": str(e)}, None, None

-
-# ---------------------------
 # Class-Specific Attention
-# ---------------------------
 def get_class_specific_attention(image, model_name, class_query):
     try:
         if image is None:
@@ -332,7 +287,6 @@ def get_class_specific_attention(image, model_name, class_query):
         if not class_query or class_query.strip() == "":
            return None, None, "Please enter a class name"

-        # Find matching class
         class_query_lower = class_query.lower().strip()
         matching_idx = None
         matched_label = None
@@ -340,7 +294,6 @@ def get_class_specific_attention(image, model_name, class_query):
         model, extractor = load_model(model_name)

         if MODEL_CONFIGS[model_name]["type"] == "hf":
-            # Search in HF model labels
             for idx, label in model.config.id2label.items():
                 if class_query_lower in label.lower():
                     matching_idx = idx
@@ -350,11 +303,9 @@ def get_class_specific_attention(image, model_name, class_query):
             if matching_idx is None:
                 return None, None, f"Class '{class_query}' not found in model labels. Try a different class name or check sample classes."

-            # Get attention for this class
             att_map = vit_attention_for_class(model, extractor, image, matching_idx)

         else:
-            # Search in ImageNet labels
             for idx, label in enumerate(IMAGENET_LABELS):
                 if class_query_lower in label.lower():
                     matching_idx = idx
@@ -364,12 +315,10 @@ def get_class_specific_attention(image, model_name, class_query):
             if matching_idx is None:
                 return None, None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check sample classes."

-            # Get Grad-CAM for this class
             transform = T.Compose([
                 T.Resize((224, 224)),
                 T.ToTensor(),
-                T.Normalize(mean=[0.485, 0.456, 0.406],
-                            std=[0.229, 0.224, 0.225])
             ])
             x = transform(image).unsqueeze(0)
             x.requires_grad = True
@@ -385,10 +334,7 @@ def get_class_specific_attention(image, model_name, class_query):
         print(error_trace)
         return None, None, f"Error generating attention map: {str(e)}"

-
-# ---------------------------
 # Sample Classes
-# ---------------------------
 SAMPLE_CLASSES = [
     "cat", "dog", "tiger", "lion", "elephant",
     "car", "truck", "airplane", "ship", "train",
@@ -397,91 +343,122 @@ SAMPLE_CLASSES = [
     "person", "bicycle", "building", "tree", "flower"
 ]

-# ---------------------------
-# Gradio UI
-# ---------------------------
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🧠 Enhanced Multi-Model Image Classifier")
-    gr.Markdown("### Features: Adversarial Examples | Class-Specific Attention | 13+ Models")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            input_image = gr.Image(type="pil", label="📸 Upload Image")
-            model_dropdown = gr.Dropdown(
-                choices=list(MODEL_CONFIGS.keys()),
-                label="🤖 Select Model",
-                value="DeiT-Tiny"
-            )
-
-            gr.Markdown("### 🎭 Adversarial Noise")
-            noise_slider = gr.Slider(
-                minimum=0,
-                maximum=0.3,
-                value=0,
-                step=0.01,
-                label="Noise Level (ε)",
-                info="Add random noise to test model robustness"
-            )
-
-            run_button = gr.Button("🚀 Run Model", variant="primary")
-
-        with gr.Column(scale=2):
-            output_label = gr.Label(num_top_classes=5, label="🎯 Top 5 Predictions")
-            output_image = gr.Image(label="🔍 Attention Map (Top Prediction)")
-            processed_image = gr.Image(label="🖼️ Processed Image (with noise if applied)", visible=False)
-
-    gr.Markdown("---")
-    gr.Markdown("### 🎨 Class-Specific Attention Visualization")
-    gr.Markdown("*Type any class name to see where the model looks for that specific object*")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            class_input = gr.Textbox(
-                label="🔍 Enter Class Name",
-                placeholder="e.g., cat, dog, car, pizza...",
-                info="Type any ImageNet class name"
-            )
-            class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")
-            gr.Markdown("**💡 Sample classes to try:**")
-            sample_buttons = gr.Radio(
-                choices=SAMPLE_CLASSES,
-                label="Click to auto-fill",
-                interactive=True
-            )
-
-        with gr.Column(scale=2):
-            class_output_image = gr.Image(label="🔍 Class-Specific Attention Map")
-            class_status = gr.Textbox(label="Status", interactive=False)
-
-    gr.Markdown("---")
     gr.Markdown("""
-    ### 💡 Tips:
-    - **Adversarial Noise**: Adjust the slider to add random noise and see how robust the model is
-    - **Class-Specific Attention**: Type any ImageNet class to visualize what the model looks for (e.g., "tiger", "sports car", "pizza")
-    - **Model Variety**: Try different architectures (ViT, CNN, Hybrid) to compare their behavior
-    """)
-
-    # Event handlers
     run_button.click(
-        predict,
         inputs=[input_image, model_dropdown, noise_slider],
-        outputs=[output_label, output_image, processed_image]
     )
-
-    # When user selects a sample class, update the text input
     sample_buttons.change(
-        lambda x: x,
         inputs=[sample_buttons],
         outputs=[class_input]
    )
-
-    # Generate attention map
     class_button.click(
-        get_class_specific_attention,
         inputs=[input_image, model_dropdown, class_input],
-        outputs=[class_output_image, gradient_legend, class_status]
     )

 if __name__ == "__main__":
     demo.launch()
 
app.py (updated)
 import urllib.request
 import json
 import cv2
+import uuid

 # Model Configs
 MODEL_CONFIGS = {
+    "DeiT-Tiny": {"type": "hf", "id": "facebook/deit-tiny-patch16-224", "desc": "Lightweight Vision Transformer"},
+    "DeiT-Small": {"type": "hf", "id": "facebook/deit-small-patch16-224", "desc": "Small Vision Transformer"},
+    "ViT-Base": {"type": "hf", "id": "google/vit-base-patch16-224", "desc": "Base Vision Transformer"},
+    "ConvNeXt-Tiny": {"type": "timm", "id": "convnext_tiny", "desc": "Tiny ConvNeXt CNN"},
+    "ConvNeXt-Nano": {"type": "timm", "id": "convnext_nano", "desc": "Nano ConvNeXt CNN"},
+    "EfficientNet-B0": {"type": "efficientnet", "id": "efficientnet-b0", "desc": "EfficientNet B0"},
+    "EfficientNet-B1": {"type": "efficientnet", "id": "efficientnet-b1", "desc": "EfficientNet B1"},
+    "ResNet-50": {"type": "timm", "id": "resnet50", "desc": "Classic ResNet-50 CNN"},
+    "MobileNet-V2": {"type": "timm", "id": "mobilenetv2_100", "desc": "Lightweight MobileNet-V2"},
+    "MaxViT-Tiny": {"type": "timm", "id": "maxvit_tiny_tf_224", "desc": "Tiny MaxViT Hybrid"},
+    "MobileViT-Small": {"type": "timm", "id": "mobilevit_s", "desc": "Small MobileViT"},
+    "EdgeNeXt-Small": {"type": "timm", "id": "edgenext_small", "desc": "Small EdgeNeXt"},
+    "RegNetY-002": {"type": "timm", "id": "regnety_002", "desc": "RegNetY-002 CNN"}
 }
 
 
32
  # ImageNet Labels
 
33
  IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
34
  with urllib.request.urlopen(IMAGENET_URL) as url:
35
  IMAGENET_LABELS = json.load(url)
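The label list is fetched over the network at import time, so the Space cannot start offline. A defensive variant would time-bound the request and fall back to placeholder labels; the timeout and the fallback below are illustrative assumptions, not part of this commit:

```python
# Sketch only: hardening the label fetch above. The timeout value and the
# numeric fallback labels are assumptions, not part of this commit.
import json
import urllib.request

IMAGENET_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
try:
    with urllib.request.urlopen(IMAGENET_URL, timeout=10) as url:
        IMAGENET_LABELS = json.load(url)
except OSError:
    # Network failure at startup: degrade to index-only labels.
    IMAGENET_LABELS = [f"class_{i}" for i in range(1000)]
```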

 # Lazy Load
 loaded_models = {}

 def load_model(model_name):

         extractor = AutoFeatureExtractor.from_pretrained(config["id"])
         model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
         model.eval()
         for param in model.parameters():
             param.requires_grad = True
     elif config["type"] == "timm":
         model = timm.create_model(config["id"], pretrained=True)
         model.eval()
         for param in model.parameters():
             param.requires_grad = True
         extractor = None
     elif config["type"] == "efficientnet":
         model = EfficientNet.from_pretrained(config["id"])
         model.eval()
         for param in model.parameters():
             param.requires_grad = True
         extractor = None

     loaded_models[model_name] = (model, extractor)
     return model, extractor
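Since load_model memoizes into loaded_models, repeated selections of the same model are served from the cache rather than re-downloaded. A quick check (hypothetical REPL usage of the function above):

```python
# Hypothetical usage of load_model's cache: the second call must not
# re-download, because loaded_models memoizes the (model, extractor) pair.
model_a, extractor_a = load_model("DeiT-Tiny")
model_b, extractor_b = load_model("DeiT-Tiny")
assert model_a is model_b            # same object, served from loaded_models
assert "DeiT-Tiny" in loaded_models
```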

 # Adversarial Noise
 def add_adversarial_noise(image, epsilon):
     img_array = np.array(image).astype(np.float32) / 255.0
     noise = np.random.randn(*img_array.shape) * epsilon
     noisy_img = np.clip(img_array + noise, 0, 1)
     return Image.fromarray((noisy_img * 255).astype(np.uint8))
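Because epsilon scales zero-mean Gaussian noise in [0, 1] pixel space, sweeping the UI slider's range makes the degradation directly visible. A sketch; the input path is a placeholder, not a file shipped with the Space:

```python
# Sketch: sweep the slider's 0..0.3 range over one image. "sample.jpg"
# is a placeholder path.
from PIL import Image

image = Image.open("sample.jpg").convert("RGB")
for eps in (0.0, 0.1, 0.2, 0.3):            # matches the slider's range
    noisy = add_adversarial_noise(image, eps)
    noisy.save(f"noisy_eps_{eps:.2f}.png")   # inspect how structure degrades
```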

 # Grad-CAM for Class-Specific Attention
 def get_gradcam_for_class(model, image_tensor, class_idx):
     grad = None
     fmap = None

         nonlocal grad
         grad = grad_out[0].detach()

     last_conv = None
     for name, module in reversed(list(model.named_modules())):
         if isinstance(module, torch.nn.Conv2d):

     cam = cam.squeeze().cpu().numpy()
     cam = cv2.resize(cam, (224, 224))
     cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
     return cam
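The hunks above only show fragments of get_gradcam_for_class (the hook state and the final normalization); the elided middle follows the standard Grad-CAM recipe: hook the last Conv2d, backpropagate the target class score, pool the gradients into channel weights. A self-contained sketch of that recipe, not the committed body:

```python
import torch
import torch.nn.functional as F

def gradcam_sketch(model, image_tensor, class_idx):
    """Minimal Grad-CAM over the last Conv2d; a sketch of the recipe the
    committed function follows, not a drop-in replacement for it."""
    store = {}
    last_conv = None
    for _, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            last_conv = module                       # ends on the last conv layer
    h1 = last_conv.register_forward_hook(
        lambda m, inp, out: store.__setitem__("fmap", out))
    h2 = last_conv.register_full_backward_hook(
        lambda m, gin, gout: store.__setitem__("grad", gout[0].detach()))
    logits = model(image_tensor)
    model.zero_grad()
    logits[0, class_idx].backward()                  # class-specific gradients
    h1.remove(); h2.remove()
    weights = store["grad"].mean(dim=(2, 3), keepdim=True)  # GAP over H, W
    cam = F.relu((weights * store["fmap"]).sum(dim=1))      # weighted channel sum
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return cam.squeeze().detach().cpu().numpy()
```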

 # ViT Attention for Class-Specific
 def vit_attention_for_class(model, extractor, image, class_idx):
     inputs = extractor(images=image, return_tensors="pt")
     inputs['pixel_values'].requires_grad = True
     outputs = model(**inputs)

     model.zero_grad()
     score.backward()

     if hasattr(outputs, 'attentions') and outputs.attentions is not None:
         attn = outputs.attentions[-1]
         attn = attn.mean(1)

     return np.ones((14, 14))

 # Grad-CAM Helper for CNNs (Top Prediction)
 def get_gradcam(model, image_tensor):
     grad = None
     fmap = None

     cam = cam.squeeze().cpu().numpy()
     cam = cv2.resize(cam, (224, 224))
     cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
     return cam

 # ViT Attention Rollout Helper
 def vit_attention_rollout(outputs):
     if not hasattr(outputs, 'attentions') or outputs.attentions is None:
         return np.ones((14, 14))

     attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
     return attn_map
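Most of vit_attention_rollout's body is elided above; for orientation, the textbook rollout it is named after averages heads, adds the identity for the residual path, renormalizes, and multiplies across layers. A sketch under the file's 224-px / 14×14-patch assumption, not the committed body:

```python
import torch

def attention_rollout_sketch(attentions):
    """Textbook attention rollout: head-average, add identity for the
    residual connection, renormalize rows, multiply across layers."""
    tokens = attentions[0].size(-1)
    joint = torch.eye(tokens)
    for layer_attn in attentions:                    # each: (batch, heads, T, T)
        a = layer_attn[0].detach().mean(dim=0)       # first image, head-averaged
        a = a + torch.eye(tokens)                    # residual connection
        a = a / a.sum(dim=-1, keepdim=True)          # keep rows as distributions
        joint = a @ joint
    cls_attn = joint[0, 1:]                          # CLS row, minus the CLS column
    grid = cls_attn.reshape(14, 14)                  # assumes a 14x14 patch grid
    return ((grid - grid.min()) / (grid.max() - grid.min() + 1e-8)).numpy()
```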

 # Create Gradient Legend
 def create_gradient_legend():
     width, height = 400, 60
     gradient = np.zeros((height, width, 3), dtype=np.uint8)

     for i in range(width):
         value = int(255 * i / width)
         color_single = np.array([[[value]]], dtype=np.uint8)
         color_rgb = cv2.applyColorMap(color_single, cv2.COLORMAP_JET)

     gradient = cv2.cvtColor(gradient, cv2.COLOR_BGR2RGB)

     from PIL import ImageDraw, ImageFont
     gradient_pil = Image.fromarray(gradient)
     draw = ImageDraw.Draw(gradient_pil)

     try:
         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
     except:
         font = ImageFont.load_default()

     draw.text((10, 20), "Low Attention", fill=(255, 255, 255), font=font)
     draw.text((width - 120, 20), "High Attention", fill=(255, 255, 255), font=font)

     return gradient_pil
+
 def overlay_attention(pil_img, attention_map):
     heatmap = (attention_map * 255).astype(np.uint8)
     heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

     blended = Image.blend(pil_img.convert("RGB"), heatmap_pil, alpha=0.4)
     return blended
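With the helpers above in place, the CNN path can be exercised outside Gradio. A sketch: the file path, the pre-resize to 224×224 (so the blend sizes line up), and the requires_grad flag mirror what the elided code appears to do, but all three are assumptions:

```python
# Sketch: run the CNN path end to end without the UI, using the same
# ImageNet normalization that predict() applies.
import torchvision.transforms as T
from PIL import Image

image = Image.open("sample.jpg").convert("RGB").resize((224, 224))  # placeholder path
model, _ = load_model("ResNet-50")
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
x = transform(image).unsqueeze(0)
x.requires_grad_(True)                             # mirrors the class-specific path
cam = get_gradcam(model, x)                        # 224x224 map in [0, 1]
overlay_attention(image, cam).save("overlay.png")  # JET heatmap at alpha=0.4
create_gradient_legend().save("legend.png")        # matching color scale
```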

 # Main Prediction Function
 def predict(image, model_name, noise_level):
     try:
         if image is None:

         if model_name is None:
             return {"Error": "Please select a model"}, None, None

         if noise_level > 0:
             image = add_adversarial_noise(image, noise_level)

         transform = T.Compose([
             T.Resize((224, 224)),
             T.ToTensor(),
+            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])

         if MODEL_CONFIGS[model_name]["type"] == "hf":

         print(error_msg)
         return {"Error": str(e)}, None, None
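Only predict's guards, noise step, transform, and error path survive the elision above. Assuming the success path returns the label dict, attention overlay, and processed image that run_button wires up below (an inference, not shown in the hunks), a headless smoke test looks like:

```python
# Sketch: exercise predict() without the UI. The success-path return shape
# is inferred from the outputs wired to run_button below.
from PIL import Image

image = Image.open("sample.jpg").convert("RGB")   # placeholder path
preds, attention_vis, processed = predict(image, "ResNet-50", noise_level=0.1)
if "Error" not in preds:
    for label, prob in sorted(preds.items(), key=lambda kv: kv[1], reverse=True)[:5]:
        print(f"{label}: {prob:.3f}")
```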

 # Class-Specific Attention
 def get_class_specific_attention(image, model_name, class_query):
     try:
         if image is None:

         if not class_query or class_query.strip() == "":
             return None, None, "Please enter a class name"

         class_query_lower = class_query.lower().strip()
         matching_idx = None
         matched_label = None

         model, extractor = load_model(model_name)

         if MODEL_CONFIGS[model_name]["type"] == "hf":
             for idx, label in model.config.id2label.items():
                 if class_query_lower in label.lower():
                     matching_idx = idx

             if matching_idx is None:
                 return None, None, f"Class '{class_query}' not found in model labels. Try a different class name or check sample classes."

             att_map = vit_attention_for_class(model, extractor, image, matching_idx)

         else:
             for idx, label in enumerate(IMAGENET_LABELS):
                 if class_query_lower in label.lower():
                     matching_idx = idx

             if matching_idx is None:
                 return None, None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check sample classes."

             transform = T.Compose([
                 T.Resize((224, 224)),
                 T.ToTensor(),
+                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
             ])
             x = transform(image).unsqueeze(0)
             x.requires_grad = True

         print(error_trace)
         return None, None, f"Error generating attention map: {str(e)}"
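Factored out of the two branches above, the substring lookup is just a scan over a label list. A standalone sketch; the committed loops may differ in tie-breaking, since their tails are elided from the diff:

```python
# Standalone sketch of the substring lookup used in both branches above.
def find_label_index(query: str, labels) -> int | None:
    q = query.lower().strip()
    for idx, label in enumerate(labels):
        if q in label.lower():
            return idx            # first match wins in this sketch
    return None

# e.g. find_label_index("tiger", IMAGENET_LABELS) -> the index of "tiger"
```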

 # Sample Classes
 SAMPLE_CLASSES = [
     "cat", "dog", "tiger", "lion", "elephant",
     "car", "truck", "airplane", "ship", "train",

     "person", "bicycle", "building", "tree", "flower"
 ]

+# Improved Gradio UI
+with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="gray", font=["Inter", "sans-serif"])) as demo:
+    gr.Markdown("""
+    # 🧠 Advanced Image Classification Studio
+    Explore state-of-the-art image classification with multiple models, adversarial testing, and attention visualization.
+    """, elem_classes=["header-text"])
+
+    with gr.Tabs() as tabs:
+        with gr.TabItem("🔍 Predict & Analyze"):
+            with gr.Row(variant="panel"):
+                with gr.Column(scale=1, min_width=300):
+                    gr.Markdown("### 📷 Input")
+                    input_image = gr.Image(type="pil", label="Upload Image", height=300, interactive=True, tool="editor")
+                    model_dropdown = gr.Dropdown(
+                        choices=[f"{name} - {MODEL_CONFIGS[name]['desc']}" for name in MODEL_CONFIGS.keys()],
+                        label="Select Model",
+                        value="DeiT-Tiny - Lightweight Vision Transformer",
+                        interactive=True,
+                        info="Choose from various architectures (Transformers, CNNs, Hybrids)"
+                    )
+                    with gr.Group():
+                        gr.Markdown("### 🎭 Adversarial Testing")
+                        noise_slider = gr.Slider(
+                            minimum=0, maximum=0.3, value=0, step=0.01,
+                            label="Noise Level (ε)",
+                            info="Add noise to test model robustness",
+                            interactive=True
+                        )
+                    run_button = gr.Button("🚀 Run Prediction", variant="primary", scale=0)
+
+                with gr.Column(scale=2):
+                    gr.Markdown("### 📊 Results")
+                    output_label = gr.Label(num_top_classes=5, label="Top 5 Predictions", show_label=True)
+                    with gr.Row():
+                        output_image = gr.Image(label="Attention Map (Top Prediction)", height=300)
+                        processed_image = gr.Image(label="Processed Image (with noise)", height=300, visible=False)
+
+        with gr.TabItem("🎨 Class-Specific Attention"):
+            gr.Markdown("### Visualize Model Attention for Specific Classes")
+            with gr.Row(variant="panel"):
+                with gr.Column(scale=1, min_width=300):
+                    class_input = gr.Textbox(
+                        label="Enter Class Name",
+                        placeholder="e.g., cat, dog, car, pizza...",
+                        info="Type any ImageNet class name",
+                        interactive=True
+                    )
+                    class_button = gr.Button("🎯 Generate Attention Map", variant="primary")
+                    with gr.Accordion("💡 Sample Classes", open=False):
+                        sample_buttons = gr.CheckboxGroup(
+                            choices=SAMPLE_CLASSES,
+                            label="Select or click to auto-fill",
+                            interactive=True
+                        )
+
+                with gr.Column(scale=2):
+                    class_output_image = gr.Image(label="Class-Specific Attention Map", height=300)
+                    gradient_legend = gr.Image(label="Attention Scale", interactive=False)
+                    class_status = gr.Textbox(label="Status", interactive=False, lines=2)
+
+        with gr.TabItem("ℹ️ About Models"):
+            gr.Markdown("""
+            ### Available Models
+            Explore different architectures and their strengths:
+            """)
+            for model_name, config in MODEL_CONFIGS.items():
+                with gr.Accordion(f"{model_name}", open=False):
+                    gr.Markdown(f"- **Type**: {config['type'].upper()}")
+                    gr.Markdown(f"- **Description**: {config['desc']}")
+                    gr.Markdown(f"- **Model ID**: {config['id']}")
     gr.Markdown("""
+    ---
+    ### 💡 How to Use
+    - **Predict & Analyze**: Upload an image, select a model, adjust noise level, and run prediction to see top classes and attention maps.
+    - **Class-Specific Attention**: Enter a class name or select from samples to visualize where the model focuses for that class.
+    - **Adversarial Testing**: Use the noise slider to test model robustness against perturbations.
+    - **Model Info**: Check the 'About Models' tab for details on available architectures.
+    """, elem_classes=["footer-text"])
+
+    # Event Handlers
+    def update_class_input(selected_classes):
+        return selected_classes[0] if selected_classes else ""
+
     run_button.click(
+        fn=predict,
         inputs=[input_image, model_dropdown, noise_slider],
+        outputs=[output_label, output_image, processed_image],
+        show_progress=True
     )
+
     sample_buttons.change(
+        fn=update_class_input,
         inputs=[sample_buttons],
         outputs=[class_input]
     )
+
     class_button.click(
+        fn=get_class_specific_attention,
         inputs=[input_image, model_dropdown, class_input],
+        outputs=[class_output_image, gradient_legend, class_status],
+        show_progress=True
     )
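Note that the new dropdown choices are decorated with descriptions ("DeiT-Tiny - Lightweight Vision Transformer"), while predict and get_class_specific_attention still index MODEL_CONFIGS by bare key, so the handlers above presumably need a shim along these lines; no such helper appears in the hunks shown:

```python
# Sketch: strip the " - description" suffix the dropdown adds, so the
# selected value can index MODEL_CONFIGS again. Hypothetical helper.
def model_key(choice: str) -> str:
    return choice.split(" - ", 1)[0]

assert model_key("DeiT-Tiny - Lightweight Vision Transformer") == "DeiT-Tiny"
```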
 
+    # Add custom CSS for improved styling
+    gr.HTML("""
+    <style>
+    .header-text { font-size: 2rem; font-weight: bold; color: #1E3A8A; margin-bottom: 1rem; }
+    .footer-text { font-size: 0.9rem; color: #4B5563; }
+    .gr-button { transition: all 0.3s ease; }
+    .gr-button:hover { transform: scale(1.05); }
+    .gr-panel { border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
+    .gr-image { border-radius: 8px; }
+    .gr-accordion { margin-bottom: 1rem; }
+    </style>
+    """)
+
 if __name__ == "__main__":
     demo.launch()