Spaces:
Sleeping
Sleeping
| import os | |
| import matplotlib.pyplot as plt | |
| import matplotlib.cm as cm | |
| import numpy as np | |
| import gradio as gr | |
| from transformers import AutoModel, AutoImageProcessor | |
| from PIL import Image | |
| import torch | |
# Ask the Hugging Face Hub client to stay in online mode.
# NOTE(review): this assignment runs after `transformers` has already been
# imported above; if the Hub client reads the variable at import time it has
# no effect here -- confirm, and move it above the imports if so.
os.environ["HF_HUB_OFFLINE"] = "0"

# Module-level cache for the currently loaded model and its image processor.
# Every slot starts out as None and is filled in by `load_model`.
state = dict.fromkeys(("model_type", "model", "processor", "repo_id"))
def similarity_heatmap(image):
    """
    Compute cosine similarity between the CLS token and the patch tokens.

    The model and processor are read from the module-level ``state`` dict. A
    single forward pass is run without gradients and the per-patch similarity
    to the first (CLS) token is laid out on the patch grid.

    Args:
        image: Image to analyse; anything the loaded processor accepts via
            ``images=``. Objects exposing a ``mode`` other than ``"RGB"`` are
            converted with ``.convert("RGB")`` first.

    Returns:
        A ``(H_patch, W_patch)`` float32 NumPy array of cosine similarities.

    Raises:
        ValueError: If the patch grid is empty or the model returns fewer
            tokens than one CLS token plus the patch grid.
    """
    model, processor = state["model"], state["processor"]

    # Uploads may be RGBA / grayscale / palette images; presumably the
    # processor expects three channels, so normalise the mode up front.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(model.device)  # (1, 3, H, W)

    # patch_size is usually an int (e.g. 16) but may be a (height, width) pair.
    patch_size = model.config.patch_size
    if isinstance(patch_size, (tuple, list)):
        patch_h, patch_w = patch_size
    else:
        patch_h = patch_w = patch_size

    # Patch grid of the *processed* image (needed to reshape the token axis).
    H_patch = pixel_values.shape[2] // patch_h
    W_patch = pixel_values.shape[3] // patch_w
    num_patches = H_patch * W_patch
    if num_patches == 0:
        raise ValueError("Processed image is smaller than a single patch")

    with torch.no_grad():
        outputs = model(pixel_values)
    last_hidden_state = outputs.last_hidden_state  # (1, seq_len, hidden_dim)

    seq_len = last_hidden_state.shape[1]
    if seq_len < num_patches + 1:
        raise ValueError(
            f"Model returned {seq_len} tokens; expected at least "
            f"{num_patches + 1} (CLS + {H_patch}x{W_patch} patches)"
        )

    cls_token = last_hidden_state[:, 0, :]  # (1, hidden_dim)
    # Take the trailing num_patches tokens instead of `1:` so models with extra
    # special tokens (distillation / register tokens) still reshape cleanly.
    # Assumes such tokens sit between CLS and the patch tokens.
    patch_tokens = last_hidden_state[:, -num_patches:, :]  # (1, num_patches, hidden_dim)

    # normalize() clamps the norm away from zero, avoiding NaNs on zero vectors.
    cls_norm = torch.nn.functional.normalize(cls_token, dim=-1)
    patch_norm = torch.nn.functional.normalize(patch_tokens, dim=-1)

    cos_sim = torch.einsum("bd,bpd->bp", cls_norm, patch_norm)  # (1, num_patches)
    cos_sim = cos_sim.reshape((H_patch, W_patch))

    # Move off the accelerator and up-cast before handing over to NumPy;
    # np.array() on a CUDA (or bfloat16) tensor raises.
    return cos_sim.detach().float().cpu().numpy()
def overlay_cosine_grid_on_image(cos_grid: np.ndarray, image: Image.Image, alpha=0.5, colormap="viridis"):
    """
    Colour a cosine-similarity grid and blend it over the original image.

    Args:
        cos_grid: (H_patch, W_patch) numpy array of cosine similarities.
        image: PIL.Image the heatmap is stretched over.
        alpha: Blending factor passed to ``Image.blend`` (0 keeps only the
            image, 1 keeps only the heatmap).
        colormap: Matplotlib colormap name.

    Returns:
        An RGBA ``PIL.Image`` with the same size as ``image``.
    """
    # Normalize cosine values to [0, 1] for the colormap; the epsilon keeps a
    # constant grid from dividing by zero.
    lowest = cos_grid.min()
    highest = cos_grid.max()
    norm_grid = (cos_grid - lowest) / (highest - lowest + 1e-8)

    # pyplot.get_cmap is used instead of matplotlib.cm.get_cmap, which newer
    # Matplotlib releases no longer provide.
    cmap = plt.get_cmap(colormap)
    heatmap_rgba = cmap(norm_grid)  # shape: (H_patch, W_patch, 4), floats in [0, 1]

    # Drop the alpha channel and convert to 8-bit RGB.
    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_rgb)

    # Stretch the patch-resolution heatmap to the original image size.
    heatmap_resized = heatmap_img.resize(image.size, resample=Image.BILINEAR)

    # Image.blend needs matching mode and size, hence both sides go to RGBA.
    blended = Image.blend(image.convert("RGBA"), heatmap_resized.convert("RGBA"), alpha=alpha)
    return blended
def _load_model_and_processor(repo_id, revision, **extra_kwargs):
    """Fetch the model and its image processor with one shared set of options."""
    # SECURITY: trust_remote_code=True lets the Hub repository run its own
    # Python code in this process, and repo_id is typed in by the user.
    # Left enabled to keep custom architectures loadable -- review before
    # exposing this app to untrusted users.
    # NOTE(review): newer transformers releases spell `use_auth_token` as
    # `token` -- confirm against the pinned version before changing it.
    common_kwargs = dict(
        revision=revision,
        trust_remote_code=True,
        use_auth_token=False,  # Explicitly no auth for public models
        **extra_kwargs,
    )
    model = AutoModel.from_pretrained(repo_id, **common_kwargs)
    processor = AutoImageProcessor.from_pretrained(repo_id, **common_kwargs)
    return model, processor


def load_model(repo_id: str, revision: str = None):
    """
    Load a Hugging Face model + processor from the Hub into ``state``.

    Works with any public repo_id. The first attempt uses the default cache;
    if it raises, a second attempt is made with ``/tmp/model_cache`` as the
    cache directory. The model is moved to CUDA when available and put in
    eval mode.

    Args:
        repo_id: Hub repository id, e.g. ``google/vit-base-patch16-224``.
            Surrounding whitespace is ignored.
        revision: Optional branch, tag, or commit hash. ``None``, an empty
            string, or whitespace all mean "use the default revision".

    Returns:
        A human-readable status string (success or error description). This
        function does not raise; failures are reported through the string.
    """
    # Clean up inputs. An empty Gradio textbox arrives as "", which must
    # become None rather than being forwarded as a literal revision name.
    repo_id = (repo_id or "").strip()
    if not repo_id:
        return "Please enter a model repo ID"
    revision = (revision or "").strip() or None

    try:
        try:
            # First try without cache_dir to avoid permission issues
            model, processor = _load_model_and_processor(repo_id, revision)
        except Exception:
            # If that fails, retry with an explicit cache directory under
            # /tmp (better permissions) and make sure downloading is allowed.
            model, processor = _load_model_and_processor(
                repo_id,
                revision,
                cache_dir="/tmp/model_cache",
                local_files_only=False,
            )

        # Move to appropriate device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()

        # Validate it's a Vision Transformer
        if not hasattr(model.config, 'patch_size'):
            return f"Model '{repo_id}' doesn't appear to be a Vision Transformer (no patch_size in config)"

        # Update global state only once everything above succeeded
        state["model"] = model
        state["processor"] = processor
        state["repo_id"] = repo_id
        state["model_type"] = "custom"

        patch_size = model.config.patch_size
        return f"Successfully loaded ViT model '{repo_id}' (patch size: {patch_size}) on {device}"
    except Exception as e:
        # Top-level UI boundary: map the failure onto a friendly message.
        error_str = str(e).lower()
        if "repository not found" in error_str or "404" in error_str:
            return f"Repository '{repo_id}' not found. Please check the repo ID."
        if "connection" in error_str or "network" in error_str or "offline" in error_str:
            return f"Network error: {str(e)}"
        if "permission" in error_str or "forbidden" in error_str:
            return "Permission denied. This might be a private repository."
        return f"Error loading model: {str(e)}"
def display_image(image: Image.Image):
    """
    Echo the uploaded image back unchanged.

    Wired to the upload component's change event so the preview mirrors
    whatever is currently uploaded.

    Args:
        image: The uploaded PIL image (whatever the input component passes in).

    Returns:
        The same ``image`` object, untouched.
    """
    return image
def visualize_cosine_heatmap(image: Image.Image):
    """
    Generate and overlay the cosine similarity heatmap on the input image.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing is uploaded.

    Returns:
        The blended heatmap image, or ``None`` when no model is loaded, no
        image was supplied, or heatmap generation failed (the error is
        printed in that case).
    """
    # Nothing to compute without both a loaded model and an image; bail out
    # quietly instead of tripping the error path below.
    if state["model"] is None or image is None:
        return None

    try:
        cos_grid = similarity_heatmap(image)
        return overlay_cosine_grid_on_image(cos_grid, image)
    except Exception as e:
        # UI boundary: report the failure and leave the output empty rather
        # than crashing the event handler.
        print(f"Error generating heatmap: {e}")
        return None
# Gradio UI
# NOTE(review): the container nesting below is reconstructed from a source
# whose indentation was lost -- confirm the layout against the running app.
with gr.Blocks(title="ViT CLS Visualizer") as demo:
    # --- Header and usage hints ---
    gr.Markdown("# ViT CLS-Visualizer")
    gr.Markdown(
        "Enter the Hugging Face model repo ID (must be public), upload an image, "
        "and visualize the cosine similarity between the CLS token and patches."
    )
    gr.Markdown("### Popular Vision Transformer models to try:")
    gr.Markdown(
        "\n".join(
            [
                "- `google/vit-base-patch16-224`",
                "- `facebook/deit-base-distilled-patch16-224`",
                "- `microsoft/dit-base`",
            ]
        )
    )

    # --- Model selection ---
    with gr.Row():
        repo_input = gr.Textbox(
            label="Hugging Face Model Repo ID",
            placeholder="e.g. google/vit-base-patch16-224",
            value="google/vit-base-patch16-224",
        )
        revision_input = gr.Textbox(
            label="Revision (optional)",
            placeholder="branch, tag, or commit hash",
        )
    load_btn = gr.Button("Load Model", variant="primary")
    load_status = gr.Textbox(label="Model Status", interactive=False)

    # --- Image upload (left) and heatmap result (right) ---
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            image_output = gr.Image(label="Uploaded Image")
        with gr.Column():
            compute_btn = gr.Button("Compute Heatmap", variant="primary")
            heatmap_output = gr.Image(label="Cosine Similarity Heatmap")

    # --- Events ---
    # Load the requested checkpoint and show the resulting status message.
    load_btn.click(load_model, [repo_input, revision_input], load_status)
    # Mirror the upload into the preview component.
    image_input.change(display_image, image_input, image_output)
    # Run the model on the uploaded image and show the blended heatmap.
    compute_btn.click(visualize_cosine_heatmap, image_input, heatmap_output)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()