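# Gradio demo for NVIDIA RADIO: visualizes a PCA projection of RADIO's backbone
# features and the masks SAM produces when RADIO's "sam" adaptor replaces the
# SAM ViT-H image encoder.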
import os

import requests

# Disable JIT
os.environ["PYTORCH_JIT"] = "0"

from einops import rearrange
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoModel, CLIPImageProcessor

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from segment_anything.modeling.image_encoder import ImageEncoderViT

class RADIOVenc(nn.Module):
    def __init__(self, radio: nn.Module, img_enc: ImageEncoderViT, img_size: int = 1024):
        super().__init__()
        self.radio = radio
        self.neck = img_enc.neck
        self.img_size = img_size
        self.dtype = radio.input_conditioner.dtype

    def forward(self, x: torch.Tensor):
        h, w = x.shape[-2:]

        if self.dtype is not None:
            x = x.to(dtype=self.dtype)

        with torch.autocast('cuda', dtype=torch.bfloat16, enabled=self.dtype is None):
            output = self.radio(x)

        features = output["sam"].features

        rows = h // 16
        cols = w // 16
        features = rearrange(features, 'b (h w) c -> b c h w', h=rows, w=cols)
        features = self.neck(features)

        return features
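
# Note: RADIO's "sam" adaptor emits one token per 16x16 patch, so for SAM's default
# 1024x1024 input the rearranged grid is 64x64; passing it through SAM's original neck
# yields the (B, 256, 64, 64) embedding the rest of the SAM pipeline expects.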

def download_file(url, save_path):
    # Check if the file already exists
    if os.path.exists(save_path):
        print(f"File already exists at {save_path}. Skipping download.")
        return

    print(f"Downloading from {url}")

    # Send a GET request to the URL
    response = requests.get(url, stream=True)

    # Check if the request was successful
    if response.status_code == 200:
        # Open the file in binary write mode
        with open(save_path, 'wb') as file:
            # Iterate over the response content in chunks
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    file.write(chunk)
        print(f"File downloaded successfully and saved as {save_path}")
    else:
        print(f"Failed to download file. HTTP Status Code: {response.status_code}")

hf_repo = "nvidia/RADIO-L"
image_processor = CLIPImageProcessor.from_pretrained(hf_repo)

model_version = "radio_v2.5-l"  # RADIOv2.5-L model (ViT-L/16)
model = torch.hub.load(
    'NVlabs/RADIO',
    'radio_model',
    version=model_version,
    progress=True,
    skip_validation=True,
    adaptor_names='sam')
model.eval()

local_sam_checkpoint_path = "sam_vit_h_4b8939.pth"
download_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", local_sam_checkpoint_path)
sam = sam_model_registry["vit_h"](checkpoint=local_sam_checkpoint_path)

model._patch_size = 16
sam.image_encoder = RADIOVenc(model, sam.image_encoder, img_size=1024)
conditioner = model.make_preprocessor_external()
sam.pixel_mean = conditioner.norm_mean * 255
sam.pixel_std = conditioner.norm_std * 255
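
# make_preprocessor_external() detaches RADIO's input normalization and returns it as a
# standalone conditioner; re-expressing its mean/std in 0-255 pixel space lets SAM's own
# preprocessing (subtract pixel_mean, divide by pixel_std) apply the normalization RADIO expects.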

def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    # features: (N, C)
    # m: a hyperparameter controlling how many median absolute deviations a value may lie
    #    from the median before it is treated as an outlier
    assert len(features.shape) == 2, "features should be (N, C)"
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    colors = features @ reduction_mat
    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0]).bool()
    d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    try:
        rins = colors[fg_mask][s[:, 0] < m, 0]
        gins = colors[fg_mask][s[:, 1] < m, 1]
        bins = colors[fg_mask][s[:, 2] < m, 2]
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    except Exception:
        # fall back to the unfiltered projections if any channel has no inliers
        rins = colors
        gins = colors
        bins = colors
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
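
# The function above is a robust PCA-to-RGB mapping: features are projected onto their top 3
# principal components, samples farther than `m` median absolute deviations from the per-channel
# median are discarded, and the surviving min/max define the color normalization range.
# Example (hypothetical shapes): for a (4096, 1024) token-feature matrix, reduction_mat is
# (1024, 3) and rgb_min / rgb_max are 3-vectors consumed by get_pca_map below.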

def get_pca_map(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
):
    """
    feature_map: (1, h, w, C) is the feature map of a single image.
    """
    if feature_map.shape[0] != 1:
        # make it (1, h, w, C)
        feature_map = feature_map[None]
    if pca_stats is None:
        reduct_mat, color_min, color_max = get_robust_pca(
            feature_map.reshape(-1, feature_map.shape[-1])
        )
    else:
        reduct_mat, color_min, color_max = pca_stats
    pca_color = feature_map @ reduct_mat
    pca_color = (pca_color - color_min) / (color_max - color_min)
    pca_color = pca_color.clamp(0, 1)
    pca_color = F.interpolate(
        pca_color.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    pca_color = pca_color.cpu().numpy().squeeze(0)
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color
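
# A minimal usage sketch (hypothetical shapes):
#   feats = torch.randn(1, 64, 64, 1024)       # (1, h, w, C) feature map
#   rgb = get_pca_map(feats, (1024, 1024))      # -> (1024, 1024, 3) array in [0, 1]
# Passing pca_stats from a previous call reuses the same projection and color range,
# which keeps the coloring consistent across multiple images.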

def pad_image_to_multiple_of(image, multiple=16):
    # Calculate the new dimensions to make them multiples
    width, height = image.size
    new_width = (width + multiple - 1) // multiple * multiple
    new_height = (height + multiple - 1) // multiple * multiple

    # Calculate the padding needed on each side
    pad_width = new_width - width
    pad_height = new_height - height
    left = pad_width // 2
    right = pad_width - left
    top = pad_height // 2
    bottom = pad_height - top

    # Apply the padding
    padded_image = ImageOps.expand(image, (left, top, right, bottom), fill='black')
    return padded_image
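
# Padding (rather than resizing) preserves the aspect ratio; infer_radio pads to a multiple
# of 256, which is divisible by the 16-pixel patch size, so the token grid stays integral.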

def center_crop_resize(image, size=(1024, 1024)):
    # Get dimensions
    width, height = image.size

    # Determine the center crop box
    if width > height:
        new_width = height
        new_height = height
        left = (width - new_width) / 2
        top = 0
        right = (width + new_width) / 2
        bottom = height
    else:
        new_width = width
        new_height = width
        left = 0
        top = (height - new_height) / 2
        right = width
        bottom = (height + new_height) / 2

    # Crop the image to a square
    image = image.crop((left, top, right, bottom))

    # Resize the cropped image to the target size
    image = image.resize(size, Image.LANCZOS)
    return image

def visualize_anns(orig_image: np.ndarray, anns):
    if len(anns) == 0:
        return orig_image

    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    kernel = torch.ones(1, 1, 5, 5, dtype=torch.float32)

    # RGBA
    mask = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4), dtype=np.float32)
    mask[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])

        tm = torch.as_tensor(m).reshape(1, 1, *m.shape).float()
        cvtm = F.conv2d(tm, kernel, padding=2)
        border_mask = (cvtm < 25).flatten(0, 2).numpy()

        mask[m] = color_mask
        mask[m & border_mask, 3] *= 1.0 / 0.35

    color, alpha = mask[..., :3], mask[..., -1:]
    orig_image = orig_image.astype(np.float32) / 255
    overlay = alpha * color + (1 - alpha) * orig_image
    overlay = (overlay * 255).astype(np.uint8)
    return overlay
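
# Border highlighting: convolving each binary mask with a 5x5 box filter gives 25 at interior
# pixels and less near the mask edge, so `cvtm < 25` selects a thin border that is drawn at
# full opacity while the mask interior stays at alpha 0.35.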

@spaces.GPU  # assumption: this Space runs on ZeroGPU, where GPU access must be requested; the decorator is a no-op elsewhere
def infer_radio(image):
    """Run RADIO on the input image; return the PCA feature visualization, the SAM mask overlay, and the feature shape."""
    model.cuda()
    conditioner.cuda()
    sam.cuda()
    sam_generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")

    # PCA feature visualization
    padded_image = pad_image_to_multiple_of(image, multiple=256)
    width, height = padded_image.size
    pixel_values = image_processor(images=padded_image, return_tensors='pt').pixel_values
    pixel_values = pixel_values.to(torch.bfloat16).cuda()
    pixel_values = conditioner(pixel_values)
    _, features = model(pixel_values)["backbone"]
    num_rows = height // model.patch_size
    num_cols = width // model.patch_size
    features = features.detach()
    features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float()
    pca_viz = get_pca_map(features, (height, width), interpolation='bilinear')

    # SAM mask visualization
    resized_image = center_crop_resize(image)
    image_array = np.array(image)
    print("image size", image_array.shape)
    #image_array = np.transpose(image_array, (2, 0, 1))
    masks = sam_generator.generate(image_array)
    overlay = visualize_anns(image_array, masks)

    return pca_viz, overlay, f"{features.shape}"

title = """RADIO: Reduce All Domains Into One"""

description = """
# RADIO

[AM-RADIO](https://github.com/NVlabs/RADIO) is a framework for distilling large vision foundation models into a single one.
RADIO, a new vision foundation model, excels across visual domains and serves as a superior drop-in replacement for vision backbones.
By integrating CLIP variants, DINOv2, and SAM through distillation, it preserves distinctive features such as text grounding and segmentation correspondence.
It outperforms its teachers on ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear probing segmentation (+3.8%), improves vision-language models (LLaVa 1.5, up to 1.5%), scales to any resolution, and supports non-square images.

# Instructions

Paste an image into the input box or pick one from the gallery of examples, then click the "Submit" button.
The RADIO backbone features are projected onto 3 channels with PCA and displayed as an RGB image.
The SAM features are passed through the SAM mask decoder and the resulting masks are shown as an overlay on top of the input image.
"""

inputs = [
    gr.Image(type="pil")
]

outputs = [
    gr.Image(label="PCA Feature Visualization"),
    gr.Image(label="SAM Masks"),
    gr.Textbox(label="Feature Shape"),
]

# Create the Gradio interface
demo = gr.Interface(
    fn=infer_radio,
    inputs=inputs,
    examples="./samples/",
    outputs=outputs,
    title=title,
    description=description,
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch()