import gc
from pathlib import Path

import gradio as gr
import matplotlib.cm as cm
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoImageProcessor, AutoModel

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_MAP = {
    "DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
    "DINOv3 ViT-L/16 Web imgs (1.7B)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    # "DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
}  # uncomment the 7B entry above if running locally with 60+ GB of VRAM on deck

DEFAULT_NAME = list(MODEL_MAP.keys())[0]
MAX_IMAGE_DIM = 720  # Maximum dimension for the longer side

# Global model state
processor = None
model = None

def cleanup_memory():
    """Aggressive memory cleanup for model switching"""
    global processor, model
    if model is not None:
        del model
        model = None
    if processor is not None:
        del processor
        processor = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def compute_dynamic_size(height, width, max_dim: int = 720, patch_size: int = 16):
    """
    Compute new dimensions preserving aspect ratio with a max_dim constraint.
    Ensures dimensions are divisible by patch_size for clean patch extraction.
    """
    # Determine scaling factor from the longer side
    if height > width:
        scale = min(1.0, max_dim / height)
    else:
        scale = min(1.0, max_dim / width)
    # Compute new dimensions
    new_height = int(height * scale)
    new_width = int(width * scale)
    # Round down to the nearest multiple of patch_size for clean patches
    new_height = (new_height // patch_size) * patch_size
    new_width = (new_width // patch_size) * patch_size
    return new_height, new_width

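# A quick worked example of the sizing logic (hypothetical numbers): a 1000x1440
# (H x W) image with max_dim=720 and patch_size=16 gets scale 0.5 -> 500x720,
# which rounds down to 496x720, i.e. compute_dynamic_size(1000, 1440) == (496, 720),
# giving a 31x45 grid of 16-pixel patches.
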
def load_model(name):
    """Load the selected model and processor, releasing any previously loaded one"""
    global processor, model
    cleanup_memory()
    model_id = MODEL_MAP[name]
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModel.from_pretrained(
        model_id,
        torch_dtype="auto",
    ).eval()
    param_count = sum(p.numel() for p in model.parameters()) / 1e9
    return f"Loaded: {name} | {param_count:.2f}B params"


# Initialize default model
load_model(DEFAULT_NAME)

def preprocess_image(img):
    """
    Custom preprocessing that respects aspect ratio & uses dynamic sizing.
    DINOv3's 2D axial RoPE handles variable sizes, so there is no need to force 224x224.
    """
    # Convert to RGB if needed
    if img.mode != "RGB":
        img = img.convert("RGB")
    # Compute dynamic size
    orig_h, orig_w = img.height, img.width
    patch_size = model.config.patch_size
    new_h, new_w = compute_dynamic_size(orig_h, orig_w, MAX_IMAGE_DIM, patch_size)
    # Resize image
    img_resized = img.resize((new_w, new_h), Image.Resampling.BICUBIC)
    # Convert to array and scale to [0, 1]
    img_array = np.array(img_resized).astype(np.float32) / 255.0
    # Apply ImageNet normalization (from the processor config when available)
    mean = (
        processor.image_mean
        if hasattr(processor, "image_mean")
        else [0.485, 0.456, 0.406]
    )
    std = (
        processor.image_std
        if hasattr(processor, "image_std")
        else [0.229, 0.224, 0.225]
    )
    img_array = (img_array - mean) / std
    # Convert to tensor with shape [1, C, H, W]
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
    return img_tensor, new_h, new_w

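# For the same hypothetical 1000x1440 input as above, preprocess_image would return
# a normalized float tensor of shape [1, 3, 496, 720] along with new_h=496, new_w=720.
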
@spaces.GPU  # request a ZeroGPU allocation for the duration of this call
def _extract_grid(img):
    """Extract feature grid from image - now with dynamic sizing!"""
    global model
    with torch.inference_mode():
        # Move model to the GPU (when available) for this call
        model = model.to(DEVICE)
        # Preprocess with dynamic sizing
        pv, img_h, img_w = preprocess_image(img)
        pv = pv.to(model.device, dtype=model.dtype)
        # Run inference - the model handles variable sizes perfectly!
        out = model(pixel_values=pv)
        last = out.last_hidden_state[0].to(torch.float32)
        # Extract features
        num_reg = getattr(model.config, "num_register_tokens", 0)
        p = model.config.patch_size
        # Calculate grid dimensions based on actual image size
        gh, gw = img_h // p, img_w // p
        feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()
        # Move model back to CPU before function exits
        model = model.cpu()
        torch.cuda.empty_cache()
        return feats, gh, gw, img_h, img_w

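# Note on the slicing above: last_hidden_state is laid out as
# [CLS token, register tokens, patch tokens], so dropping the first
# 1 + num_reg rows leaves exactly one feature vector per patch
# (31 * 45 = 1395 vectors for the hypothetical 496x720 input above).
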
def _overlay(orig, heat01, alpha=0.55, box=None):
    """Create heatmap overlay"""
    H, W = orig.height, orig.width
    heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
    # Use the turbo colormap - better suited to satellite imagery
    rgba = (cm.get_cmap("turbo")(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
    ov = Image.fromarray(rgba, "RGBA")
    ov.putalpha(int(alpha * 255))
    base = orig.copy().convert("RGBA")
    out = Image.alpha_composite(base, ov)
    if box:
        from PIL import ImageDraw

        draw = ImageDraw.Draw(out, "RGBA")
        # Enhanced box visualization: white outline with a thin dark halo
        draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
        draw.rectangle(
            (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
            outline=(0, 0, 0, 200),
            width=1,
        )
    return out

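# Example call (hypothetical values): _overlay(base, heat01, alpha=0.55, box=(320, 160, 352, 192))
# tints the image with the turbo-colored heatmap at 55% opacity and outlines the clicked
# patch (box given in original-image pixel coordinates) in white with a thin black halo.
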
def prepare(img):
    """Prepare image and extract features with dynamic sizing"""
    if img is None:
        return None
    base = ImageOps.exif_transpose(img.convert("RGB"))
    feats, gh, gw, img_h, img_w = _extract_grid(base)
    return {
        "orig": base,
        "feats": feats,
        "gh": gh,
        "gw": gw,
        "processed_h": img_h,
        "processed_w": img_w,
    }

def click(state, opacity, img_value, evt: gr.SelectData):
    """Handle click events for similarity visualization with progress feedback"""
    # Immediate feedback in the resolution_info box
    if img_value is not None:
        yield img_value, state, "Computing similarity..."
    # If state wasn't prepared (e.g., Example selection), build it now
    if state is None and img_value is not None:
        state = prepare(img_value)
    if not state or evt.index is None:
        # Just show whatever is currently in the image component
        yield img_value, state, ""
        return
    base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]
    x, y = evt.index
    px_x, px_y = base.width / gw, base.height / gh
    i = min(int(x // px_x), gw - 1)
    j = min(int(y // px_y), gh - 1)
    d = feats.shape[-1]
    grid = F.normalize(feats.reshape(-1, d), dim=1)
    v = F.normalize(feats[j, i].reshape(1, d), dim=1)
    sims = (grid @ v.T).reshape(gh, gw).numpy()
    smin, smax = float(sims.min()), float(sims.max())
    heat01 = (sims - smin) / (smax - smin + 1e-12)
    box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
    overlay = _overlay(base, heat01, alpha=opacity, box=box)
    # Report the processed resolution and the selected patch
    info_text = f"Processing at: {state['processed_w']}×{state['processed_h']} ({gw}×{gh} patches) | Patch [{i},{j}] • Range: {smin:.3f}-{smax:.3f}"
    yield overlay, state, info_text

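# How the heatmap above is computed: both the full patch grid and the clicked patch's
# feature are L2-normalized, so grid @ v.T yields cosine similarities in [-1, 1];
# these are then min-max rescaled to [0, 1] per click before colormapping in _overlay.
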
def reset():
    """Reset the interface"""
    return None, None, ""

with gr.Blocks(
    theme=gr.themes.Citrus(),
    css="""
    .container {max-width: 1200px; margin: auto;}
    .header {text-align: center; padding: 20px;}
    .info-box {
        background: rgba(0,0,0,0.03);
        border-radius: 8px;
        padding: 12px;
        margin: 10px 0;
        border-left: 4px solid #2563eb;
    }
    """,
) as demo:
    gr.HTML(
        """
        <div class="header">
            <h1>🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity</h1>
            <p style="font-size: 1.1em; color: #666;">
                Click any region to visualize feature similarities across the image
            </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_MAP.keys()),
                value=DEFAULT_NAME,
                label="Model Selection",
                info="Select a model (size/pretraining dataset)",
            )
            status = gr.Textbox(
                label="Model Status",
                value=f"Loaded: {DEFAULT_NAME}",
                interactive=False,
                lines=1,
            )
            resolution_info = gr.Textbox(
                label="Info & Status",
                value="",
                interactive=False,
                lines=1,
            )
            opacity = gr.Slider(
                0.0,
                1.0,
                0.55,
                step=0.05,
                label="Heatmap Opacity",
                info="Balance between image and similarity map",
            )
            with gr.Row():
                reset_btn = gr.Button("Reset", variant="secondary", scale=1)
                clear_btn = gr.ClearButton(value="Clear All", scale=1)
        with gr.Column(scale=2):
            img = gr.Image(
                type="pil",
                label="Interactive Canvas (Click to explore)",
                interactive=True,
                height=600,
                show_download_button=True,
                show_share_button=False,
            )
    state = gr.State()

    model_choice.change(
        load_model, inputs=model_choice, outputs=status, show_progress="full"
    )
    img.upload(prepare, inputs=img, outputs=state)
    img.select(
        click,
        inputs=[state, opacity, img],
        outputs=[img, state, resolution_info],
        show_progress="hidden",  # Hide default overlay, use resolution_info for feedback
    )
    reset_btn.click(reset, outputs=[img, state, resolution_info])
    clear_btn.add([img, state, resolution_info])

    # Examples from pwd
    example_files = [
        f.name
        for f in Path.cwd().iterdir()
        if f.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"]
    ]
    if example_files:
        gr.Examples(
            examples=[[f] for f in example_files],
            inputs=img,
            fn=prepare,
            outputs=[state],
            label="Example Images",
            examples_per_page=4,
            cache_examples=False,
        )

    gr.Markdown(
        f"""
        ---
        <div style="text-align: center; color: #666; font-size: 0.9em;">
            Satellite-pretrained models are intended for geographic patterns, land-use classification, structural analysis, etc.
            Try comparing similarity maps for the same image created by the model pretrained on sat493m vs. the one pretrained on lvd1689m (<i>general web</i>).
            <br><br>
            <b>Dynamic Resolution:</b> Images are processed at up to {MAX_IMAGE_DIM}px (longer side) while preserving aspect ratio.
            DINOv3's 2D axial RoPE embeddings handle variable sizes.
            <br>
        </div>
        """
    )

if __name__ == "__main__":
    demo.launch(share=False, debug=True)