Godreign committed on
Commit b9cbc71 · verified · 1 Parent(s): 3bdd51c

changes in ui

Files changed (1)
  1. app.py +52 -9
app.py CHANGED
@@ -51,13 +51,22 @@ def load_model(model_name):
         extractor = AutoFeatureExtractor.from_pretrained(config["id"])
         model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
         model.eval()
+        # Enable gradients for class-specific attention
+        for param in model.parameters():
+            param.requires_grad = True
     elif config["type"] == "timm":
         model = timm.create_model(config["id"], pretrained=True)
         model.eval()
+        # Enable gradients for class-specific attention
+        for param in model.parameters():
+            param.requires_grad = True
         extractor = None
     elif config["type"] == "efficientnet":
         model = EfficientNet.from_pretrained(config["id"])
         model.eval()
+        # Enable gradients for class-specific attention
+        for param in model.parameters():
+            param.requires_grad = True
         extractor = None

     loaded_models[model_name] = (model, extractor)
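For reference, the class-specific maps produced later in app.py are gradient-based, so gradients must be able to flow through these loaded models even though they stay in eval mode. A minimal Grad-CAM sketch of that pattern (illustrative only; `get_gradcam_for_class` in app.py may differ, and `target_layer` is an assumed handle to the last conv block):

import torch
import torch.nn.functional as F

def gradcam_sketch(model, target_layer, x, class_idx):
    # Capture the target layer's activations and the gradient of the class score w.r.t. them.
    feats, grads = {}, {}
    h1 = target_layer.register_forward_hook(lambda m, i, o: feats.update(a=o))
    h2 = target_layer.register_full_backward_hook(lambda m, gi, go: grads.update(a=go[0]))
    try:
        score = model(x)[0, class_idx]          # logit of the requested class
        model.zero_grad()
        score.backward()                        # only works if requires_grad=True on the weights
        w = grads["a"].mean(dim=(2, 3), keepdim=True)    # channel weights: pooled gradients
        cam = F.relu((w * feats["a"]).sum(dim=1))        # weighted sum of feature maps
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam[0].detach().cpu().numpy()    # H x W map in [0, 1]
    finally:
        h1.remove()
        h2.remove()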
@@ -129,6 +138,7 @@ def get_gradcam_for_class(model, image_tensor, class_idx):
 def vit_attention_for_class(model, extractor, image, class_idx):
     """Get attention map for specific class in ViT"""
     inputs = extractor(images=image, return_tensors="pt")
+    inputs['pixel_values'].requires_grad = True
     outputs = model(**inputs)

     score = outputs.logits[0, class_idx]
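The remainder of `vit_attention_for_class` is not part of this hunk; marking `pixel_values` as requiring grad is what lets a class-specific signal be pulled back to the input. One self-contained sketch of such a computation, a plain input-gradient (saliency) map, is shown below, but it is not necessarily what app.py does after the lines above:

import torch

def input_gradient_map(model, extractor, image, class_idx):
    inputs = extractor(images=image, return_tensors="pt")
    inputs["pixel_values"].requires_grad = True
    outputs = model(**inputs)
    outputs.logits[0, class_idx].backward()     # gradient of the chosen class score
    grad = inputs["pixel_values"].grad[0]       # (3, H, W) input gradient
    sal = grad.abs().max(dim=0).values          # per-pixel sensitivity
    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-8)
    return sal.detach().cpu().numpy()           # H x W map in [0, 1]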
@@ -213,8 +223,39 @@ def vit_attention_rollout(outputs):


 # ---------------------------
-# Overlay Helper
+# Create Gradient Legend
 # ---------------------------
+def create_gradient_legend():
+    """Create a gradient legend image showing attention scale"""
+    width, height = 400, 60
+    gradient = np.zeros((height, width, 3), dtype=np.uint8)
+
+    # Create gradient from blue to red (matching COLORMAP_JET)
+    for i in range(width):
+        # OpenCV's COLORMAP_JET: blue (low) -> cyan -> green -> yellow -> red (high)
+        value = int(255 * i / width)
+        color_single = np.array([[[value]]], dtype=np.uint8)
+        color_rgb = cv2.applyColorMap(color_single, cv2.COLORMAP_JET)
+        gradient[:, i] = color_rgb[0, 0]
+
+    gradient = cv2.cvtColor(gradient, cv2.COLOR_BGR2RGB)
+
+    # Convert to PIL and add text
+    from PIL import ImageDraw, ImageFont
+    gradient_pil = Image.fromarray(gradient)
+    draw = ImageDraw.Draw(gradient_pil)
+
+    # Use default font
+    try:
+        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
+    except:
+        font = ImageFont.load_default()
+
+    # Add text labels
+    draw.text((10, 20), "Low Attention", fill=(255, 255, 255), font=font)
+    draw.text((width - 120, 20), "High Attention", fill=(255, 255, 255), font=font)
+
+    return gradient_pil
 def overlay_attention(pil_img, attention_map):
     heatmap = (attention_map * 255).astype(np.uint8)
     heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
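`cv2.applyColorMap` works in BGR order, which is why the legend is converted with `COLOR_BGR2RGB` before being handed to PIL. A tiny standalone check of the JET convention that both the legend and the overlay rely on:

import cv2
import numpy as np

# Endpoints of COLORMAP_JET: value 0 maps to blue, value 255 maps to red (in BGR order).
lo = cv2.applyColorMap(np.uint8([[0]]), cv2.COLORMAP_JET)[0, 0]     # predominantly blue
hi = cv2.applyColorMap(np.uint8([[255]]), cv2.COLORMAP_JET)[0, 0]   # predominantly red
print("low  (B,G,R):", lo)
print("high (B,G,R):", hi)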
@@ -286,10 +327,10 @@ def predict(image, model_name, noise_level):
 def get_class_specific_attention(image, model_name, class_query):
     try:
         if image is None:
-            return None, "Please upload an image first"
+            return None, None, "Please upload an image first"

         if not class_query or class_query.strip() == "":
-            return None, "Please enter a class name"
+            return None, None, "Please enter a class name"

         # Find matching class
         class_query_lower = class_query.lower().strip()
@@ -307,7 +348,7 @@ def get_class_specific_attention(image, model_name, class_query):
                     break

             if matching_idx is None:
-                return None, f"Class '{class_query}' not found in model labels. Try a different class name or check suggestions."
+                return None, None, f"Class '{class_query}' not found in model labels. Try a different class name or check sample classes."

             # Get attention for this class
             att_map = vit_attention_for_class(model, extractor, image, matching_idx)
@@ -321,7 +362,7 @@ def get_class_specific_attention(image, model_name, class_query):
                     break

             if matching_idx is None:
-                return None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check suggestions."
+                return None, None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check sample classes."

             # Get Grad-CAM for this class
             transform = T.Compose([
@@ -335,13 +376,14 @@ def get_class_specific_attention(image, model_name, class_query):
             att_map = get_gradcam_for_class(model, x, matching_idx)

         overlay = overlay_attention(image, att_map)
-        return overlay, f"✓ Attention map generated for class: '{matched_label}'"
+        legend = create_gradient_legend()
+        return overlay, legend, f"✓ Attention map generated for class: '{matched_label}' (Index: {matching_idx})"

     except Exception as e:
         import traceback
         error_trace = traceback.format_exc()
         print(error_trace)
-        return None, f"Error generating attention map: {str(e)}"
+        return None, None, f"Error generating attention map: {str(e)}"


 # ---------------------------
@@ -400,16 +442,17 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                    placeholder="e.g., cat, dog, car, pizza...",
                    info="Type any ImageNet class name"
                )
+                class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")
                gr.Markdown("**💡 Sample classes to try:**")
                sample_buttons = gr.Radio(
                    choices=SAMPLE_CLASSES,
                    label="Click to auto-fill",
                    interactive=True
                )
-                class_button = gr.Button("🎯 Generate Class-Specific Attention", variant="primary")

            with gr.Column(scale=2):
                class_output_image = gr.Image(label="🔍 Class-Specific Attention Map")
+                gradient_legend = gr.Image(label="📊 Attention Scale", show_label=True)
                class_status = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("---")
@@ -438,7 +481,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
    class_button.click(
        get_class_specific_attention,
        inputs=[input_image, model_dropdown, class_input],
-        outputs=[class_output_image, class_status]
+        outputs=[class_output_image, gradient_legend, class_status]
    )

 if __name__ == "__main__":
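Because every return path of `get_class_specific_attention` now yields a 3-tuple (overlay, legend, status), the click handler's `outputs` list needs three components; Gradio maps the returned tuple onto them positionally. A schematic, self-contained reduction of that wiring (the stub function and component values are hypothetical, not the real app):

import gradio as gr

def stub(image, model_name, class_query):
    # Returning None for an Image output leaves that component empty.
    if image is None:
        return None, None, "Please upload an image first"
    return image, image, f"ok: {class_query}"

with gr.Blocks() as demo:
    inp = gr.Image(type="pil")
    model = gr.Dropdown(choices=["vit"], value="vit")
    query = gr.Textbox()
    btn = gr.Button("Run")
    out_img = gr.Image()
    out_legend = gr.Image()
    status = gr.Textbox()
    # Tuple position 0 -> out_img, 1 -> out_legend, 2 -> status.
    btn.click(stub, inputs=[inp, model, query], outputs=[out_img, out_legend, status])

if __name__ == "__main__":
    demo.launch()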
 