Godreign committed on
Commit
aa41700
·
verified ·
1 Parent(s): 376d711

confidence value added

Browse files
Files changed (1) hide show
  1. app.py +10 -16
app.py CHANGED
@@ -69,7 +69,6 @@ def load_model(model_name):
69
  loaded_models[model_name] = (model, extractor)
70
  return model, extractor
71
 
72
-
73
  # ---------------------------
74
  # Adversarial Noise
75
  # ---------------------------
@@ -79,7 +78,6 @@ def add_adversarial_noise(image, epsilon):
79
  noisy_img = np.clip(img_array + noise, 0, 1)
80
  return Image.fromarray((noisy_img * 255).astype(np.uint8))
81
 
82
-
83
  # ---------------------------
84
  # Grad-CAM for Class-Specific Attention
85
  # ---------------------------
@@ -126,7 +124,6 @@ def get_gradcam_for_class(model, image_tensor, class_idx):
126
 
127
  return cam
128
 
129
-
130
  # ---------------------------
131
  # ViT Attention for Class-Specific
132
  # ---------------------------
@@ -146,10 +143,10 @@ def vit_attention_for_class(model, extractor, image, class_idx):
146
  attn_map = attn.reshape(1, 14, 14)
147
  attn_map = attn_map.squeeze().detach().cpu().numpy()
148
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
149
- return attn_map
 
150
 
151
- return np.ones((14, 14))
152
-
153
 
154
  # ---------------------------
155
  # Grad-CAM Helper for CNNs
@@ -198,14 +195,12 @@ def get_gradcam(model, image_tensor):
198
 
199
  return cam
200
 
201
-
202
  # ---------------------------
203
  # ViT Attention Rollout
204
  # ---------------------------
205
  def vit_attention_rollout(outputs):
206
  if not hasattr(outputs, 'attentions') or outputs.attentions is None:
207
  return np.ones((14, 14))
208
-
209
  attn = outputs.attentions[-1]
210
  attn = attn.mean(1)
211
  attn = attn[:, 0, 1:]
@@ -214,7 +209,6 @@ def vit_attention_rollout(outputs):
214
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
215
  return attn_map
216
 
217
-
218
  # ---------------------------
219
  # Overlay Attention on Image
220
  # ---------------------------
@@ -227,7 +221,6 @@ def overlay_attention(pil_img, attention_map):
227
  blended = Image.blend(pil_img.convert("RGB"), heatmap_pil, alpha=0.4)
228
  return blended
229
 
230
-
231
  # ---------------------------
232
  # Main Prediction Function
233
  # ---------------------------
@@ -280,9 +273,8 @@ def predict(image, model_name, noise_level):
280
  print(f"Error: {traceback.format_exc()}")
281
  return {"Error": str(e)}, None, None
282
 
283
-
284
  # ---------------------------
285
- # Class-Specific Attention
286
  # ---------------------------
287
  def get_class_specific_attention(image, model_name, class_query):
288
  try:
@@ -295,6 +287,7 @@ def get_class_specific_attention(image, model_name, class_query):
295
  class_query_lower = class_query.lower().strip()
296
  matching_idx = None
297
  matched_label = None
 
298
 
299
  model, extractor = load_model(model_name)
300
 
@@ -308,7 +301,7 @@ def get_class_specific_attention(image, model_name, class_query):
308
  if matching_idx is None:
309
  return None, f"Class '{class_query}' not found in model labels."
310
 
311
- att_map = vit_attention_for_class(model, extractor, image, matching_idx)
312
 
313
  else:
314
  for idx, label in enumerate(IMAGENET_LABELS):
@@ -329,16 +322,18 @@ def get_class_specific_attention(image, model_name, class_query):
329
  x = transform(image).unsqueeze(0)
330
  x.requires_grad = True
331
  att_map = get_gradcam_for_class(model, x, matching_idx)
 
 
 
332
 
333
  overlay = overlay_attention(image, att_map)
334
- return overlay, f"✓ Attention map generated for class: '{matched_label}' (Index: {matching_idx})"
335
 
336
  except Exception as e:
337
  import traceback
338
  print(traceback.format_exc())
339
  return None, f"Error generating attention map: {str(e)}"
340
 
341
-
342
  # ---------------------------
343
  # Sample Classes
344
  # ---------------------------
@@ -350,7 +345,6 @@ SAMPLE_CLASSES = [
350
  "person", "bicycle", "building", "tree", "flower"
351
  ]
352
 
353
-
354
  # ---------------------------
355
  # Gradio UI
356
  # ---------------------------
 
69
  loaded_models[model_name] = (model, extractor)
70
  return model, extractor
71
 
 
72
  # ---------------------------
73
  # Adversarial Noise
74
  # ---------------------------
 
78
  noisy_img = np.clip(img_array + noise, 0, 1)
79
  return Image.fromarray((noisy_img * 255).astype(np.uint8))
80
 
 
81
  # ---------------------------
82
  # Grad-CAM for Class-Specific Attention
83
  # ---------------------------
 
124
 
125
  return cam
126
 
 
127
  # ---------------------------
128
  # ViT Attention for Class-Specific
129
  # ---------------------------
 
143
  attn_map = attn.reshape(1, 14, 14)
144
  attn_map = attn_map.squeeze().detach().cpu().numpy()
145
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
146
+ prob = F.softmax(outputs.logits, dim=-1)[0, class_idx].item()
147
+ return attn_map, prob
148
 
149
+ return np.ones((14, 14)), 0.0
 
150
 
151
  # ---------------------------
152
  # Grad-CAM Helper for CNNs
 
195
 
196
  return cam
197
 
 
198
  # ---------------------------
199
  # ViT Attention Rollout
200
  # ---------------------------
201
  def vit_attention_rollout(outputs):
202
  if not hasattr(outputs, 'attentions') or outputs.attentions is None:
203
  return np.ones((14, 14))
 
204
  attn = outputs.attentions[-1]
205
  attn = attn.mean(1)
206
  attn = attn[:, 0, 1:]
 
209
  attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
210
  return attn_map
211
 
 
212
  # ---------------------------
213
  # Overlay Attention on Image
214
  # ---------------------------
 
221
  blended = Image.blend(pil_img.convert("RGB"), heatmap_pil, alpha=0.4)
222
  return blended
223
 
 
224
  # ---------------------------
225
  # Main Prediction Function
226
  # ---------------------------
 
273
  print(f"Error: {traceback.format_exc()}")
274
  return {"Error": str(e)}, None, None
275
 
 
276
  # ---------------------------
277
+ # Class-Specific Attention with Confidence
278
  # ---------------------------
279
  def get_class_specific_attention(image, model_name, class_query):
280
  try:
 
287
  class_query_lower = class_query.lower().strip()
288
  matching_idx = None
289
  matched_label = None
290
+ confidence = 0.0
291
 
292
  model, extractor = load_model(model_name)
293
 
 
301
  if matching_idx is None:
302
  return None, f"Class '{class_query}' not found in model labels."
303
 
304
+ att_map, confidence = vit_attention_for_class(model, extractor, image, matching_idx)
305
 
306
  else:
307
  for idx, label in enumerate(IMAGENET_LABELS):
 
322
  x = transform(image).unsqueeze(0)
323
  x.requires_grad = True
324
  att_map = get_gradcam_for_class(model, x, matching_idx)
325
+ with torch.no_grad():
326
+ outputs = model(x)
327
+ confidence = F.softmax(outputs, dim=-1)[0, matching_idx].item()
328
 
329
  overlay = overlay_attention(image, att_map)
330
+ return overlay, f"✓ Attention map generated for class: '{matched_label}' (Index: {matching_idx}, Confidence: {confidence:.2f})"
331
 
332
  except Exception as e:
333
  import traceback
334
  print(traceback.format_exc())
335
  return None, f"Error generating attention map: {str(e)}"
336
 
 
337
  # ---------------------------
338
  # Sample Classes
339
  # ---------------------------
 
345
  "person", "bicycle", "building", "tree", "flower"
346
  ]
347
 
 
348
  # ---------------------------
349
  # Gradio UI
350
  # ---------------------------