# Hugging Face Spaces app (page banner captured with the source: "Spaces: Running").
import functools
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from timm.data import create_transform
from timm.models import create_model
class AttentionExtractor:
    """Record the output of every ``*.attn_drop`` submodule of a model.

    A forward hook is registered on each submodule whose qualified name ends
    with ``.attn_drop``. After a forward pass, :meth:`get_attention_maps`
    maps each such module name to the (detached) output it produced. For
    timm attention blocks this is presumably the attention-probability
    matrix that feeds dropout — TODO confirm per architecture.
    """

    def __init__(self, model: torch.nn.Module):
        self.model = model
        # Filled in module-registration order, i.e. the order of named_modules().
        self.attention_maps = OrderedDict()
        # Handles from register_forward_hook, kept so the hooks can be removed.
        self._hook_handles = []
        self._register_hooks()

    def _make_hook(self, name: str):
        """Return a forward hook that stores its module's output under *name*."""
        def hook_fn(module, inputs, output):
            # Tuple outputs: element 1 is taken as the attention probabilities.
            captured = output[1] if isinstance(output, tuple) else output
            if isinstance(captured, torch.Tensor):
                # Don't keep autograd history alive via the stored map.
                captured = captured.detach()
            self.attention_maps[name] = captured
        return hook_fn

    def _register_hooks(self):
        """Attach a recording forward hook to every ``*.attn_drop`` submodule."""
        for name, module in self.model.named_modules():
            if name.lower().endswith('.attn_drop'):
                print('hooking', name)
                # Bind the name in a closure instead of writing a custom
                # attribute onto the model's own modules.
                handle = module.register_forward_hook(self._make_hook(name))
                self._hook_handles.append(handle)

    def remove_hooks(self) -> None:
        """Detach every hook this extractor registered on ``self.model``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def get_attention_maps(self) -> OrderedDict:
        """Return the live name -> captured-output mapping (not a copy)."""
        return self.attention_maps
def get_attention_models(pattern: str = '*vit*', pretrained: bool = True) -> List[str]:
    """Return the timm model names offered in the model dropdown.

    Args:
        pattern: fnmatch-style filter on model names. The default keeps the
            original "name contains 'vit'" selection.
        pretrained: only list models that ship pretrained weights.
            ``load_model`` always calls ``create_model(..., pretrained=True)``,
            so an architecture without weights fails as soon as it is selected.

    Returns:
        Matching model names. With ``pretrained=True`` these include the
        weight tag (e.g. ``arch.tag``), which ``create_model`` accepts.
    """
    # NOTE(review): the substring match also catches architectures such as
    # mobilevit/levit/maxvit whose attention layout may not fit
    # visualize_attention's prefix-token + square-grid assumption — consider
    # narrowing the pattern.
    return timm.list_models(pattern, pretrained=pretrained)
@functools.lru_cache(maxsize=2)
def load_model(model_name: str) -> Tuple[torch.nn.Module, AttentionExtractor]:
    """Load (or reuse) a pretrained timm model with attention hooks attached.

    Results are memoised per model name. Previously every Gradio request
    re-created the model and reloaded its weights. The cache is bounded so that
    switching between several large ViTs does not keep them all in memory.
    NOTE(review): the cached extractor's dict is shared between calls for the
    same model. This assumes requests for one model are not processed
    concurrently — confirm the Gradio concurrency setting.

    Args:
        model_name: a name accepted by ``timm.create_model``.

    Returns:
        ``(model, extractor)``: the model in eval mode, and an
        ``AttentionExtractor`` hooked onto its ``*.attn_drop`` modules.
    """
    # Disable fused attention before the model is built. Presumably timm's
    # fused path never runs attn_drop on an explicit attention matrix, so the
    # hooks would capture nothing — verify against timm.layers.
    timm.layers.set_fused_attn(False)
    model = create_model(model_name, pretrained=True)
    model.eval()
    extractor = AttentionExtractor(model)
    return model, extractor
def process_image(image: Image.Image, model: torch.nn.Module, extractor: AttentionExtractor) -> Dict[str, torch.Tensor]:
    """Run *image* through *model* and return the maps captured by *extractor*.

    Preprocessing follows the model's own pretrained data config (input size,
    interpolation, mean/std, crop_pct, crop_mode), resolved by timm, rather
    than indexing ``pretrained_cfg`` keys by hand.

    Args:
        image: input PIL image in any mode; it is converted to RGB first.
        model: timm model whose attention modules *extractor* has hooked.
        extractor: ``AttentionExtractor`` attached to *model*.

    Returns:
        The extractor's layer-name -> captured-tensor mapping for this image.
    """
    data_config = timm.data.resolve_model_data_config(model)
    transform = create_transform(**data_config, is_training=False)
    # mean/std are per RGB channel; RGBA, greyscale or palette uploads would
    # otherwise produce a tensor with the wrong number of channels.
    tensor = transform(image.convert('RGB')).unsqueeze(0)
    # Start from an empty mapping so only this pass's layers are returned.
    extractor.attention_maps.clear()
    with torch.no_grad():
        model(tensor)
    return extractor.get_attention_maps()
def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    """Blend a solid *color* into *image*, weighted per pixel by *mask*.

    For a mask value ``m`` at a pixel, the result is
    ``image * (1 - alpha * m) + alpha * m * color * 255``, with *color* given
    on a 0-1 scale. The blend is cast to uint8, so fractional parts are
    truncated.
    """
    # Broadcast the 2-D (H, W) mask across three colour channels.
    weight = np.repeat(np.expand_dims(mask, axis=2), 3, axis=2)
    weight = alpha * weight
    tint = np.array(color)[np.newaxis, np.newaxis, :]
    blended = image * (1 - weight) + weight * tint * 255
    return blended.astype(np.uint8)
def _attention_grid(attn_map: torch.Tensor, num_prefix_tokens: int = 1) -> Optional[torch.Tensor]:
    """Reduce one layer's captured attention to a square (side, side) patch grid.

    Returns None when the tensor is not 4-D, or when its patch tokens do not
    form a perfect square.
    """
    # assumes (batch, heads, queries, keys), as for timm ViT attention — TODO confirm
    if attn_map.dim() != 4:
        return None
    if num_prefix_tokens > 0:
        # Attention paid by query 0 (the class token in ViTs) to every patch
        # token (prefix tokens dropped), averaged over heads.
        patch_attn = attn_map[0, :, 0, num_prefix_tokens:].mean(0)
    else:
        # No prefix token to read from: average over heads and all queries.
        patch_attn = attn_map[0].mean(0).mean(0)
    num_patches = patch_attn.shape[0]
    side = int(round(float(np.sqrt(num_patches))))
    if num_patches == 0 or side * side != num_patches:
        return None
    return patch_attn.reshape(side, side)


def _heatmap(grid: torch.Tensor, height: int, width: int) -> np.ndarray:
    """Bilinearly upsample *grid* to (height, width) and rescale it to [0, 1]."""
    upsampled = F.interpolate(grid[None, None].float(), size=(height, width),
                              mode='bilinear', align_corners=False)
    heat = upsampled[0, 0].cpu().numpy()
    lo, hi = heat.min(), heat.max()
    if hi > lo:
        return (heat - lo) / (hi - lo)
    # A perfectly flat map would divide by zero and turn into NaNs.
    return np.zeros_like(heat)


def _render_overlay(image_np: np.ndarray, heatmap: np.ndarray, layer_name: str) -> Image.Image:
    """Plot the image next to its red attention overlay; return the figure as RGB."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    try:
        ax1.imshow(image_np)
        ax1.set_title("Original Image")
        ax1.axis('off')
        ax2.imshow(apply_mask(image_np, heatmap, color=(1, 0, 0)))  # Red mask
        ax2.set_title(f'Attention Map for {layer_name}')
        ax2.axis('off')
        fig.tight_layout()
        fig.canvas.draw()
        # Use buffer_rgba() because tostring_rgb() is deprecated (Matplotlib 3.8).
        # The array also gets the true rendered pixel shape.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert('RGB')
    finally:
        # Release the figure even if rendering fails, so pyplot doesn't leak figures.
        plt.close(fig)


def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image]:
    """Render one overlay image per hooked attention layer of *model_name*.

    Args:
        image: uploaded PIL image (any mode; converted to RGB).
        model_name: timm model name chosen in the dropdown.

    Returns:
        One side-by-side figure per layer whose attention could be mapped to a
        square patch grid.

    Raises:
        gr.Error: if no image or model was given, or no layer could be drawn.

    NOTE(review): the model sees the resized/cropped input from its data
    config, but the grid is stretched over the full uploaded image. The
    overlay may therefore be misaligned for non-square images or crop_pct < 1.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not model_name:
        raise gr.Error("Please select a model.")
    # RGBA/greyscale/palette uploads would not match apply_mask's RGB blend.
    image = image.convert('RGB')
    model, extractor = load_model(model_name)
    attention_maps = process_image(image, model, extractor)

    image_np = np.array(image)
    height, width = image_np.shape[:2]
    # Number of class/register tokens ahead of the patches; a single class
    # token (the original assumption) when the model doesn't expose this.
    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)

    visualizations = []
    for layer_name, attn_map in attention_maps.items():
        print(f"Attention map shape for {layer_name}: {attn_map.shape}")
        grid = _attention_grid(attn_map, num_prefix_tokens)
        if grid is None:
            print(f"Skipping {layer_name}: not a square patch-attention map")
            continue
        heat = _heatmap(grid, height, width)
        visualizations.append(_render_overlay(image_np, heat, layer_name))
    if not visualizations:
        raise gr.Error("No attention maps could be visualized for this model.")
    return visualizations
# Build the Gradio interface at module level so `iface` is importable;
# start the server only when this file runs as a script, so importing the
# module has no server side effect.
iface = gr.Interface(
    fn=visualize_attention,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=get_attention_models(), label="Select Model"),
    ],
    outputs=gr.Gallery(label="Attention Maps"),
    title="Attention Map Visualizer for timm Models",
    description="Upload an image and select a timm model to visualize its attention maps.",
)

if __name__ == "__main__":
    iface.launch(debug=True)