import gc
from pathlib import Path

import gradio as gr
import matplotlib
import matplotlib.cm as cm
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageOps
from transformers import AutoImageProcessor, AutoModel

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_MAP = {
    "DINOv3 ViT-L/16 Satellite (493M)": "facebook/dinov3-vitl16-pretrain-sat493m",
    "DINOv3 ViT-L/16 Web imgs (1.7B)": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    # "DINOv3 ViT-7B/16 Satellite": "facebook/dinov3-vit7b16-pretrain-sat493m",
}
# uncomment the 7b ^ if running locally and have 60+ GB vram on deck

DEFAULT_NAME = list(MODEL_MAP.keys())[0]
MAX_IMAGE_DIM = 720  # Maximum dimension for longer side

# Fallback normalization (ImageNet statistics) used when the processor does
# not expose its own mean/std.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Global model state
processor = None
model = None


def cleanup_memory():
    """Aggressive memory cleanup for model switching.

    Drops the global processor/model references, runs the garbage collector
    and, when CUDA is available, releases cached GPU memory.
    """
    global processor, model
    if model is not None:
        del model
        model = None
    if processor is not None:
        del processor
        processor = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def compute_dynamic_size(height, width, max_dim: int = 720, patch_size: int = 16):
    """
    Compute new dimensions preserving aspect ratio with max_dim constraint.
    Ensures dimensions are divisible by patch_size for clean patch extraction.

    Images are only downscaled (the scale factor is capped at 1.0). Each side
    is then rounded *down* to a multiple of ``patch_size``, but never below a
    single patch, so tiny or very thin images still yield at least a 1x1 patch
    grid instead of a zero-sized resize.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        max_dim: Upper bound for the longer side after scaling.
        patch_size: Model patch size; both outputs are multiples of it.

    Returns:
        ``(new_height, new_width)``, both positive multiples of ``patch_size``.
    """
    # Scale so that the longer side fits within max_dim (never upscale).
    scale = min(1.0, max_dim / max(height, width))

    new_height = int(height * scale)
    new_width = int(width * scale)

    # Round down to a multiple of patch_size, clamped to at least one patch so
    # the resize/reshape downstream never sees a zero-sized dimension.
    new_height = max(patch_size, (new_height // patch_size) * patch_size)
    new_width = max(patch_size, (new_width // patch_size) * patch_size)

    return new_height, new_width


def load_model(name):
    """Load the processor and model for ``name`` (a MODEL_MAP key).

    The previous model is released first. The model is loaded with the
    checkpoint's own dtype (``torch_dtype="auto"``) and put in eval mode.

    Returns:
        A status string with the model name and its parameter count.
    """
    global processor, model

    cleanup_memory()

    model_id = MODEL_MAP[name]
    processor = AutoImageProcessor.from_pretrained(model_id)

    model = AutoModel.from_pretrained(
        model_id,
        torch_dtype="auto",
    ).eval()

    param_count = sum(p.numel() for p in model.parameters()) / 1e9
    return f"Loaded: {name} | {param_count:.2f}B params"


# Initialize default model
load_model(DEFAULT_NAME)


def preprocess_image(img):
    """
    Custom preprocessing that respects aspect ratio & uses dynamic sizing.
    DINOv3's 2D axial RoPE handles variable sizes, no need to force 224x224

    Args:
        img: A PIL image (any mode; converted to RGB).

    Returns:
        ``(pixel_values, new_h, new_w)`` where ``pixel_values`` is a float32
        tensor of shape ``[1, 3, new_h, new_w]`` normalized with the
        processor's mean/std (ImageNet defaults if the processor has none).
    """
    # Convert to RGB if needed
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Compute dynamic size
    orig_h, orig_w = img.height, img.width
    patch_size = model.config.patch_size
    new_h, new_w = compute_dynamic_size(orig_h, orig_w, MAX_IMAGE_DIM, patch_size)

    # Resize image
    img_resized = img.resize((new_w, new_h), Image.Resampling.BICUBIC)

    # Convert to float32 in [0, 1]
    img_array = np.asarray(img_resized, dtype=np.float32) / 255.0

    # Normalization params from the processor config; fall back to ImageNet
    # statistics when the attribute is missing *or* set to None.
    mean = getattr(processor, "image_mean", None)
    std = getattr(processor, "image_std", None)
    mean = np.asarray(_IMAGENET_MEAN if mean is None else mean, dtype=np.float32)
    std = np.asarray(_IMAGENET_STD if std is None else std, dtype=np.float32)
    img_array = (img_array - mean) / std

    # Convert to tensor with correct shape: [1, C, H, W]
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()

    return img_tensor, new_h, new_w


@spaces.GPU(duration=60)
def _extract_grid(img):
    """Extract the patch-feature grid from an image, using dynamic sizing.

    The model is moved to the accelerator for the duration of the call and is
    always moved back to CPU afterwards, even if inference raises.

    Returns:
        ``(feats, gh, gw, img_h, img_w)``. ``feats`` is a CPU float32 tensor
        of shape ``[gh, gw, hidden_dim]`` (register/CLS tokens removed), and
        ``img_h``/``img_w`` are the processed (resized) image dimensions.
    """
    global model

    # Resolve the device at call time: inside a @spaces.GPU call CUDA is
    # available, while a local CPU-only run falls back to "cpu".
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with torch.inference_mode():
        model = model.to(device)
        try:
            # Preprocess with dynamic sizing
            pv, img_h, img_w = preprocess_image(img)
            # Match the model's dtype too: with torch_dtype="auto" the weights
            # may be half precision while preprocessing yields float32.
            pv = pv.to(device=model.device, dtype=model.dtype)

            # Run inference - RoPE lets the model handle variable sizes
            out = model(pixel_values=pv)
            last = out.last_hidden_state[0].to(torch.float32)

            # Token layout: [CLS, register tokens..., patch tokens...]
            num_reg = getattr(model.config, "num_register_tokens", 0)
            p = model.config.patch_size

            # Calculate grid dimensions based on actual processed image size
            gh, gw = img_h // p, img_w // p
            feats = last[1 + num_reg :, :].reshape(gh, gw, -1).cpu()
        finally:
            # Always hand the accelerator back, even on failure.
            model = model.cpu()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return feats, gh, gw, img_h, img_w


def _get_cmap(name):
    """Return a matplotlib colormap by name across matplotlib versions.

    ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9; the
    ``matplotlib.colormaps`` registry (3.5+) is the replacement.
    """
    try:
        return matplotlib.colormaps[name]
    except AttributeError:  # matplotlib < 3.5
        return cm.get_cmap(name)


def _overlay(orig, heat01, alpha=0.55, box=None):
    """Create a heatmap overlay on top of ``orig``.

    Args:
        orig: Base PIL image.
        heat01: 2D array of values in [0, 1], resized to the image size.
        alpha: Heatmap opacity in [0, 1].
        box: Optional ``(x0, y0, x1, y1)`` rectangle to outline.

    Returns:
        An RGBA PIL image.
    """
    H, W = orig.height, orig.width
    heat = Image.fromarray((heat01 * 255).astype(np.uint8)).resize((W, H))
    # Use turbo colormap - better for satellite imagery
    rgba = (_get_cmap("turbo")(np.asarray(heat) / 255.0) * 255).astype(np.uint8)
    # A uint8 (H, W, 4) array is inferred as RGBA.
    ov = Image.fromarray(rgba)
    ov.putalpha(int(alpha * 255))
    base = orig.copy().convert("RGBA")
    out = Image.alpha_composite(base, ov)

    if box:
        draw = ImageDraw.Draw(out, "RGBA")
        # Enhanced box visualization: white outline plus a thin dark halo
        draw.rectangle(box, outline=(255, 255, 255, 255), width=3)
        draw.rectangle(
            (box[0] - 1, box[1] - 1, box[2] + 1, box[3] + 1),
            outline=(0, 0, 0, 200),
            width=1,
        )

    return out


def prepare(img):
    """Prepare an image and extract its features with dynamic sizing.

    Returns:
        ``None`` for no image, otherwise a state dict with the oriented RGB
        image (``orig``), the feature grid and its/processed dimensions.
    """
    if img is None:
        return None

    # Apply the EXIF orientation first, then normalize the mode.
    base = ImageOps.exif_transpose(img).convert("RGB")
    feats, gh, gw, img_h, img_w = _extract_grid(base)

    return {
        "orig": base,
        "feats": feats,
        "gh": gh,
        "gw": gw,
        "processed_h": img_h,
        "processed_w": img_w,
    }


def refresh_state(state):
    """Recompute cached features for the current image with the active model.

    Called after a model switch so clicks compare patches using the newly
    loaded backbone instead of features cached from the previous one.
    """
    if not state:
        return None
    return prepare(state["orig"])


def click(state, opacity, img_value, evt: gr.SelectData):
    """Handle click events for similarity visualization with progress feedback.

    Yields ``(image, state, info_text)`` tuples: first a "computing" message,
    then the cosine-similarity overlay for the clicked patch.
    """
    # Immediate feedback in resolution_info box
    if img_value is not None:
        yield img_value, state, "Computing similarity..."

    # If state wasn't prepared (e.g., Example selection), build it now
    if state is None and img_value is not None:
        state = prepare(img_value)

    if not state or evt.index is None:
        # Just show whatever is currently in the image component
        yield img_value, state, ""
        return

    base, feats, gh, gw = state["orig"], state["feats"], state["gh"], state["gw"]

    x, y = evt.index
    px_x, px_y = base.width / gw, base.height / gh
    # Clamp to the valid grid on both ends.
    i = max(0, min(int(x // px_x), gw - 1))
    j = max(0, min(int(y // px_y), gh - 1))

    # Cosine similarity between the clicked patch and every patch
    d = feats.shape[-1]
    grid = F.normalize(feats.reshape(-1, d), dim=1)
    v = F.normalize(feats[j, i].reshape(1, d), dim=1)
    sims = (grid @ v.T).reshape(gh, gw).numpy()

    smin, smax = float(sims.min()), float(sims.max())
    heat01 = (sims - smin) / (smax - smin + 1e-12)

    box = (int(i * px_x), int(j * px_y), int((i + 1) * px_x), int((j + 1) * px_y))
    overlay = _overlay(base, heat01, alpha=opacity, box=box)

    # Resolution is width×height, so report the patch grid as cols×rows too.
    info_text = f"Processing at: {state['processed_w']}×{state['processed_h']} ({gw}×{gh} patches) | Patch [{i},{j}] • Range: {smin:.3f}-{smax:.3f}"

    yield overlay, state, info_text


def reset():
    """Reset the interface"""
    return None, None, ""


with gr.Blocks(
    theme=gr.themes.Citrus(),
    css="""
    .container {max-width: 1200px; margin: auto;}
    .header {text-align: center; padding: 20px;}
    .info-box {
        background: rgba(0,0,0,0.03);
        border-radius: 8px;
        padding: 12px;
        margin: 10px 0;
        border-left: 4px solid #2563eb;
    }
    """,
) as demo:
    gr.HTML(
        """

🛰️ DINOv3 Satellite Vision: Interactive Patch Similarity

Click any region to visualize feature similarities across the image

"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(MODEL_MAP.keys()),
                value=DEFAULT_NAME,
                label="Model Selection",
                info="Select a model (size/pretraining dataset)",
            )
            status = gr.Textbox(
                label="Model Status",
                value=f"Loaded: {DEFAULT_NAME}",
                interactive=False,
                lines=1,
            )
            resolution_info = gr.Textbox(
                label="Info & Status",
                value="",
                interactive=False,
                lines=1,
            )

            opacity = gr.Slider(
                0.0,
                1.0,
                0.55,
                step=0.05,
                label="Heatmap Opacity",
                info="Balance between image and similarity map",
            )

            with gr.Row():
                reset_btn = gr.Button("Reset", variant="secondary", scale=1)
                clear_btn = gr.ClearButton(value="Clear All", scale=1)

        with gr.Column(scale=2):
            img = gr.Image(
                type="pil",
                label="Interactive Canvas (Click to explore)",
                interactive=True,
                height=600,
                show_download_button=True,
                show_share_button=False,
            )

    state = gr.State()

    # After switching models, recompute features for the current image so
    # clicks reflect the newly selected backbone rather than stale features.
    model_choice.change(
        load_model, inputs=model_choice, outputs=status, show_progress="full"
    ).then(refresh_state, inputs=state, outputs=state)

    img.upload(prepare, inputs=img, outputs=state)

    img.select(
        click,
        inputs=[state, opacity, img],
        outputs=[img, state, resolution_info],
        show_progress="hidden",  # Hide default overlay, use resolution_info for feedback
    )

    reset_btn.click(reset, outputs=[img, state, resolution_info])
    clear_btn.add([img, state, resolution_info])

    # Examples from pwd (sorted for a stable order; files only)
    example_files = sorted(
        f.name
        for f in Path.cwd().iterdir()
        if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png", ".webp"]
    )
    if example_files:
        gr.Examples(
            examples=[[f] for f in example_files],
            inputs=img,
            fn=prepare,
            outputs=[state],
            label="Example Images",
            examples_per_page=4,
            cache_examples=False,
            # Without this, selecting an example never runs `prepare`, and a
            # previously uploaded image's features would be reused.
            run_on_click=True,
        )

    gr.Markdown(
        f"""
---
Satellite-pretrained models are intended for: geographic patterns, land use classification, structural analysis, etc. Try comparing similarity maps for the same image created by the model pretrained on sat493m vs. the one on lvd1689m (general web).

Dynamic Resolution: Images are processed at up to {MAX_IMAGE_DIM}px (longer side) while preserving aspect ratio. DINOv3's 2D axial RoPE embeddings handle variable sizes.
"""
    )

if __name__ == "__main__":
    demo.launch(share=False, debug=True)