changes in ui
Browse files
app.py
CHANGED
|
@@ -51,13 +51,22 @@ def load_model(model_name):
|
|
| 51 |
extractor = AutoFeatureExtractor.from_pretrained(config["id"])
|
| 52 |
model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
|
| 53 |
model.eval()
|
|
|
|
|
|
|
|
|
|
| 54 |
elif config["type"] == "timm":
|
| 55 |
model = timm.create_model(config["id"], pretrained=True)
|
| 56 |
model.eval()
|
|
|
|
|
|
|
|
|
|
| 57 |
extractor = None
|
| 58 |
elif config["type"] == "efficientnet":
|
| 59 |
model = EfficientNet.from_pretrained(config["id"])
|
| 60 |
model.eval()
|
|
|
|
|
|
|
|
|
|
| 61 |
extractor = None
|
| 62 |
|
| 63 |
loaded_models[model_name] = (model, extractor)
|
|
@@ -129,6 +138,7 @@ def get_gradcam_for_class(model, image_tensor, class_idx):
|
|
| 129 |
def vit_attention_for_class(model, extractor, image, class_idx):
|
| 130 |
"""Get attention map for specific class in ViT"""
|
| 131 |
inputs = extractor(images=image, return_tensors="pt")
|
|
|
|
| 132 |
outputs = model(**inputs)
|
| 133 |
|
| 134 |
score = outputs.logits[0, class_idx]
|
|
@@ -213,8 +223,39 @@ def vit_attention_rollout(outputs):
|
|
| 213 |
|
| 214 |
|
| 215 |
# ---------------------------
|
| 216 |
-
#
|
| 217 |
# ---------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
def overlay_attention(pil_img, attention_map):
|
| 219 |
heatmap = (attention_map * 255).astype(np.uint8)
|
| 220 |
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
|
@@ -286,10 +327,10 @@ def predict(image, model_name, noise_level):
|
|
| 286 |
def get_class_specific_attention(image, model_name, class_query):
|
| 287 |
try:
|
| 288 |
if image is None:
|
| 289 |
-
return None, "Please upload an image first"
|
| 290 |
|
| 291 |
if not class_query or class_query.strip() == "":
|
| 292 |
-
return None, "Please enter a class name"
|
| 293 |
|
| 294 |
# Find matching class
|
| 295 |
class_query_lower = class_query.lower().strip()
|
|
@@ -307,7 +348,7 @@ def get_class_specific_attention(image, model_name, class_query):
|
|
| 307 |
break
|
| 308 |
|
| 309 |
if matching_idx is None:
|
| 310 |
-
return None, f"Class '{class_query}' not found in model labels. Try a different class name or check
|
| 311 |
|
| 312 |
# Get attention for this class
|
| 313 |
att_map = vit_attention_for_class(model, extractor, image, matching_idx)
|
|
@@ -321,7 +362,7 @@ def get_class_specific_attention(image, model_name, class_query):
|
|
| 321 |
break
|
| 322 |
|
| 323 |
if matching_idx is None:
|
| 324 |
-
return None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check
|
| 325 |
|
| 326 |
# Get Grad-CAM for this class
|
| 327 |
transform = T.Compose([
|
|
@@ -335,13 +376,14 @@ def get_class_specific_attention(image, model_name, class_query):
|
|
| 335 |
att_map = get_gradcam_for_class(model, x, matching_idx)
|
| 336 |
|
| 337 |
overlay = overlay_attention(image, att_map)
|
| 338 |
-
|
|
|
|
| 339 |
|
| 340 |
except Exception as e:
|
| 341 |
import traceback
|
| 342 |
error_trace = traceback.format_exc()
|
| 343 |
print(error_trace)
|
| 344 |
-
return None, f"Error generating attention map: {str(e)}"
|
| 345 |
|
| 346 |
|
| 347 |
# ---------------------------
|
|
@@ -400,16 +442,17 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 400 |
placeholder="e.g., cat, dog, car, pizza...",
|
| 401 |
info="Type any ImageNet class name"
|
| 402 |
)
|
|
|
|
| 403 |
gr.Markdown("**π‘ Sample classes to try:**")
|
| 404 |
sample_buttons = gr.Radio(
|
| 405 |
choices=SAMPLE_CLASSES,
|
| 406 |
label="Click to auto-fill",
|
| 407 |
interactive=True
|
| 408 |
)
|
| 409 |
-
class_button = gr.Button("π― Generate Class-Specific Attention", variant="primary")
|
| 410 |
|
| 411 |
with gr.Column(scale=2):
|
| 412 |
class_output_image = gr.Image(label="π Class-Specific Attention Map")
|
|
|
|
| 413 |
class_status = gr.Textbox(label="Status", interactive=False)
|
| 414 |
|
| 415 |
gr.Markdown("---")
|
|
@@ -438,7 +481,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 438 |
class_button.click(
|
| 439 |
get_class_specific_attention,
|
| 440 |
inputs=[input_image, model_dropdown, class_input],
|
| 441 |
-
outputs=[class_output_image, class_status]
|
| 442 |
)
|
| 443 |
|
| 444 |
if __name__ == "__main__":
|
|
|
|
| 51 |
extractor = AutoFeatureExtractor.from_pretrained(config["id"])
|
| 52 |
model = AutoModelForImageClassification.from_pretrained(config["id"], output_attentions=True)
|
| 53 |
model.eval()
|
| 54 |
+
# Enable gradients for class-specific attention
|
| 55 |
+
for param in model.parameters():
|
| 56 |
+
param.requires_grad = True
|
| 57 |
elif config["type"] == "timm":
|
| 58 |
model = timm.create_model(config["id"], pretrained=True)
|
| 59 |
model.eval()
|
| 60 |
+
# Enable gradients for class-specific attention
|
| 61 |
+
for param in model.parameters():
|
| 62 |
+
param.requires_grad = True
|
| 63 |
extractor = None
|
| 64 |
elif config["type"] == "efficientnet":
|
| 65 |
model = EfficientNet.from_pretrained(config["id"])
|
| 66 |
model.eval()
|
| 67 |
+
# Enable gradients for class-specific attention
|
| 68 |
+
for param in model.parameters():
|
| 69 |
+
param.requires_grad = True
|
| 70 |
extractor = None
|
| 71 |
|
| 72 |
loaded_models[model_name] = (model, extractor)
|
|
|
|
| 138 |
def vit_attention_for_class(model, extractor, image, class_idx):
|
| 139 |
"""Get attention map for specific class in ViT"""
|
| 140 |
inputs = extractor(images=image, return_tensors="pt")
|
| 141 |
+
inputs['pixel_values'].requires_grad = True
|
| 142 |
outputs = model(**inputs)
|
| 143 |
|
| 144 |
score = outputs.logits[0, class_idx]
|
|
|
|
| 223 |
|
| 224 |
|
| 225 |
# ---------------------------
|
| 226 |
+
# Create Gradient Legend
|
| 227 |
# ---------------------------
|
| 228 |
+
def create_gradient_legend():
|
| 229 |
+
"""Create a gradient legend image showing attention scale"""
|
| 230 |
+
width, height = 400, 60
|
| 231 |
+
gradient = np.zeros((height, width, 3), dtype=np.uint8)
|
| 232 |
+
|
| 233 |
+
# Create gradient from blue to red (matching COLORMAP_JET)
|
| 234 |
+
for i in range(width):
|
| 235 |
+
# OpenCV's COLORMAP_JET: blue (low) -> cyan -> green -> yellow -> red (high)
|
| 236 |
+
value = int(255 * i / width)
|
| 237 |
+
color_single = np.array([[[value]]], dtype=np.uint8)
|
| 238 |
+
color_rgb = cv2.applyColorMap(color_single, cv2.COLORMAP_JET)
|
| 239 |
+
gradient[:, i] = color_rgb[0, 0]
|
| 240 |
+
|
| 241 |
+
gradient = cv2.cvtColor(gradient, cv2.COLOR_BGR2RGB)
|
| 242 |
+
|
| 243 |
+
# Convert to PIL and add text
|
| 244 |
+
from PIL import ImageDraw, ImageFont
|
| 245 |
+
gradient_pil = Image.fromarray(gradient)
|
| 246 |
+
draw = ImageDraw.Draw(gradient_pil)
|
| 247 |
+
|
| 248 |
+
# Use default font
|
| 249 |
+
try:
|
| 250 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
|
| 251 |
+
except:
|
| 252 |
+
font = ImageFont.load_default()
|
| 253 |
+
|
| 254 |
+
# Add text labels
|
| 255 |
+
draw.text((10, 20), "Low Attention", fill=(255, 255, 255), font=font)
|
| 256 |
+
draw.text((width - 120, 20), "High Attention", fill=(255, 255, 255), font=font)
|
| 257 |
+
|
| 258 |
+
return gradient_pil
|
| 259 |
def overlay_attention(pil_img, attention_map):
|
| 260 |
heatmap = (attention_map * 255).astype(np.uint8)
|
| 261 |
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
|
|
|
| 327 |
def get_class_specific_attention(image, model_name, class_query):
|
| 328 |
try:
|
| 329 |
if image is None:
|
| 330 |
+
return None, None, "Please upload an image first"
|
| 331 |
|
| 332 |
if not class_query or class_query.strip() == "":
|
| 333 |
+
return None, None, "Please enter a class name"
|
| 334 |
|
| 335 |
# Find matching class
|
| 336 |
class_query_lower = class_query.lower().strip()
|
|
|
|
| 348 |
break
|
| 349 |
|
| 350 |
if matching_idx is None:
|
| 351 |
+
return None, None, f"Class '{class_query}' not found in model labels. Try a different class name or check sample classes."
|
| 352 |
|
| 353 |
# Get attention for this class
|
| 354 |
att_map = vit_attention_for_class(model, extractor, image, matching_idx)
|
|
|
|
| 362 |
break
|
| 363 |
|
| 364 |
if matching_idx is None:
|
| 365 |
+
return None, None, f"Class '{class_query}' not found in ImageNet labels. Try a different class name or check sample classes."
|
| 366 |
|
| 367 |
# Get Grad-CAM for this class
|
| 368 |
transform = T.Compose([
|
|
|
|
| 376 |
att_map = get_gradcam_for_class(model, x, matching_idx)
|
| 377 |
|
| 378 |
overlay = overlay_attention(image, att_map)
|
| 379 |
+
legend = create_gradient_legend()
|
| 380 |
+
return overlay, legend, f"β Attention map generated for class: '{matched_label}' (Index: {matching_idx})"
|
| 381 |
|
| 382 |
except Exception as e:
|
| 383 |
import traceback
|
| 384 |
error_trace = traceback.format_exc()
|
| 385 |
print(error_trace)
|
| 386 |
+
return None, None, f"Error generating attention map: {str(e)}"
|
| 387 |
|
| 388 |
|
| 389 |
# ---------------------------
|
|
|
|
| 442 |
placeholder="e.g., cat, dog, car, pizza...",
|
| 443 |
info="Type any ImageNet class name"
|
| 444 |
)
|
| 445 |
+
class_button = gr.Button("π― Generate Class-Specific Attention", variant="primary")
|
| 446 |
gr.Markdown("**π‘ Sample classes to try:**")
|
| 447 |
sample_buttons = gr.Radio(
|
| 448 |
choices=SAMPLE_CLASSES,
|
| 449 |
label="Click to auto-fill",
|
| 450 |
interactive=True
|
| 451 |
)
|
|
|
|
| 452 |
|
| 453 |
with gr.Column(scale=2):
|
| 454 |
class_output_image = gr.Image(label="π Class-Specific Attention Map")
|
| 455 |
+
gradient_legend = gr.Image(label="π Attention Scale", show_label=True)
|
| 456 |
class_status = gr.Textbox(label="Status", interactive=False)
|
| 457 |
|
| 458 |
gr.Markdown("---")
|
|
|
|
| 481 |
class_button.click(
|
| 482 |
get_class_specific_attention,
|
| 483 |
inputs=[input_image, model_dropdown, class_input],
|
| 484 |
+
outputs=[class_output_image, gradient_legend, class_status]
|
| 485 |
)
|
| 486 |
|
| 487 |
if __name__ == "__main__":
|