import sys, os

import torch

TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)

# Install detectron2 that matches the above pytorch version
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
os.system(f'pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/{CUDA_VERSION}/torch{TORCH_VERSION}/index.html')
os.system("pip install jinja2")
os.system("pip install git+https://github.com/cocodataset/panopticapi.git")
# Imports
import gradio as gr
import detectron2
from detectron2.utils.logger import setup_logger
import numpy as np
import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision import datasets, transforms
from einops import rearrange
from PIL import Image
import imutils
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from tqdm import tqdm
import random
from functools import partial
import time

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
from detectron2.projects.deeplab import add_deeplab_config

coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic")

# Import Mask2Former
from mask2former import add_maskformer2_config

# DPT dependencies for depth pseudo labeling
from dpt.models import DPTDepthModel

from multimae.input_adapters import PatchedInputAdapter, SemSegInputAdapter
from multimae.output_adapters import SpatialOutputAdapter
from multimae.multimae import pretrain_multimae_base
from utils.data_constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

torch.set_grad_enabled(False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')
# Initialize COCO Mask2Former
cfg = get_cfg()
cfg.MODEL.DEVICE = 'cpu'
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
cfg.merge_from_file("mask2former/configs/coco/panoptic-segmentation/swin/maskformer2_swin_small_bs16_50ep.yaml")
cfg.MODEL.WEIGHTS = 'https://dl.fbaipublicfiles.com/maskformer/mask2former/coco/panoptic/maskformer2_swin_small_bs16_50ep/model_final_a407fd.pkl'
cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = True
cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = True
semseg_model = DefaultPredictor(cfg)


def predict_semseg(img):
    return semseg_model(255 * img.permute(1, 2, 0).numpy())['sem_seg'].argmax(0)


def plot_semseg(img, semseg, ax):
    v = Visualizer(img.permute(1, 2, 0), coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
    semantic_result = v.draw_sem_seg(semseg.cpu()).get_image()
    ax.imshow(semantic_result)
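# Illustrative usage of the two helpers above (a sketch, not executed by the demo; it assumes
# a local file 'some_image.jpg' and mirrors how `inference` below calls them):
#
#   img = TF.to_tensor(Image.open('some_image.jpg'))   # (3, H, W) float tensor in [0, 1]
#   semseg = predict_semseg(img)                        # (H, W) tensor of COCO class indices
#   fig, ax = plt.subplots()
#   plot_semseg(img, semseg, ax)                        # overlays the predicted classes on the RGB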
# Initialize Omnidata depth model
os.system("wget https://datasets.epfl.ch/vilab/iccv21/weights/omnidata_rgb2depth_dpt_hybrid.pth -P pretrained_models")
omnidata_ckpt = torch.load('./pretrained_models/omnidata_rgb2depth_dpt_hybrid.pth', map_location='cpu')
depth_model = DPTDepthModel()
depth_model.load_state_dict(omnidata_ckpt)
depth_model = depth_model.to(device).eval()


def predict_depth(img):
    depth_model_input = (img.unsqueeze(0) - 0.5) / 0.5
    return depth_model(depth_model_input.to(device))
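# `predict_depth` rescales the [0, 1] RGB tensor to [-1, 1] (mean 0.5, std 0.5) before the DPT
# forward pass. Downstream in `inference`, the returned tensor is treated as a single-channel
# depth map and unsqueezed to a (1, 1, H, W) batch before being fed to MultiMAE.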
# MultiMAE model setup
DOMAIN_CONF = {
    'rgb': {
        'input_adapter': partial(PatchedInputAdapter, num_channels=3, stride_level=1),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=3, stride_level=1),
    },
    'depth': {
        'input_adapter': partial(PatchedInputAdapter, num_channels=1, stride_level=1),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=1, stride_level=1),
    },
    'semseg': {
        'input_adapter': partial(SemSegInputAdapter, num_classes=133,
                                 dim_class_emb=64, interpolate_class_emb=False, stride_level=4),
        'output_adapter': partial(SpatialOutputAdapter, num_channels=133, stride_level=4),
    },
}
DOMAINS = ['rgb', 'depth', 'semseg']

input_adapters = {
    domain: dinfo['input_adapter'](
        patch_size_full=16,
    )
    for domain, dinfo in DOMAIN_CONF.items()
}
output_adapters = {
    domain: dinfo['output_adapter'](
        patch_size_full=16,
        dim_tokens=256,
        use_task_queries=True,
        depth=2,
        context_tasks=DOMAINS,
        task=domain
    )
    for domain, dinfo in DOMAIN_CONF.items()
}

multimae = pretrain_multimae_base(
    input_adapters=input_adapters,
    output_adapters=output_adapters,
)
CKPT_URL = 'https://github.com/EPFL-VILAB/MultiMAE/releases/download/pretrained-weights/multimae-b_98_rgb+-depth-semseg_1600e_multivit-afff3f8c.pth'
ckpt = torch.hub.load_state_dict_from_url(CKPT_URL, map_location='cpu')
multimae.load_state_dict(ckpt['model'], strict=False)
multimae = multimae.to(device).eval()
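# Token bookkeeping behind the slider math in `inference` below (derived from the adapter
# configuration above, so treat the exact interpretation as an assumption): at the 224x224 train
# resolution, RGB and depth are patchified with 16x16 patches at stride level 1 (14 x 14 = 196
# tokens each), and the semantic map is fed at stride level 4 (a 56x56 map split into 4x4
# patches, again 196 tokens), so the three modalities together contribute 3 * 196 = 588 tokens.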
# Plotting
def get_masked_image(img, mask, image_size=224, patch_size=16, mask_value=0.0):
    img_token = rearrange(
        img.detach().cpu(),
        'b c (nh ph) (nw pw) -> b (nh nw) (c ph pw)',
        ph=patch_size, pw=patch_size, nh=image_size//patch_size, nw=image_size//patch_size
    )
    img_token[mask.detach().cpu() != 0] = mask_value
    img = rearrange(
        img_token,
        'b (nh nw) (c ph pw) -> b c (nh ph) (nw pw)',
        ph=patch_size, pw=patch_size, nh=image_size//patch_size, nw=image_size//patch_size
    )
    return img
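# What `get_masked_image` does, as a toy example (illustrative values, not executed here):
# the image is split into (image_size // patch_size) ** 2 patch tokens, every token whose mask
# entry is non-zero is overwritten with `mask_value`, and the tokens are folded back into an image.
#
#   img = torch.rand(1, 3, 224, 224)
#   mask = torch.zeros(1, 196).long()
#   mask[:, :98] = 1                                        # hide the first half of the patches
#   masked = get_masked_image(img, mask, mask_value=1.0)    # still shape (1, 3, 224, 224)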
def denormalize(img, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    return TF.normalize(
        img.clone(),
        mean=[-m/s for m, s in zip(mean, std)],
        std=[1/s for s in std]
    )
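# `denormalize` inverts the ImageNet normalization: if y = (x - mean) / std, then applying
# TF.normalize with mean' = -mean/std and std' = 1/std gives (y - mean') / std' = y*std + mean = x.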
def plot_semseg_gt(input_dict, ax=None, image_size=224):
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    instance_mode = ColorMode.IMAGE
    img_viz = 255 * denormalize(input_dict['rgb'].detach().cpu())[0].permute(1, 2, 0)
    semseg = F.interpolate(
        input_dict['semseg'].unsqueeze(0).cpu().float(), size=image_size, mode='nearest'
    ).long()[0, 0]
    visualizer = Visualizer(img_viz, metadata, instance_mode=instance_mode, scale=1)
    visualizer.draw_sem_seg(semseg)
    if ax is not None:
        ax.imshow(visualizer.get_output().get_image())
    else:
        return visualizer.get_output().get_image()


def plot_semseg_gt_masked(input_dict, mask, ax=None, mask_value=1.0, image_size=224):
    img = plot_semseg_gt(input_dict, image_size=image_size)
    img = torch.LongTensor(img).permute(2, 0, 1).unsqueeze(0)
    masked_img = get_masked_image(img.float() / 255.0, mask, image_size=image_size, patch_size=16, mask_value=mask_value)
    masked_img = masked_img[0].permute(1, 2, 0)
    if ax is not None:
        ax.imshow(masked_img)
    else:
        return masked_img
def get_pred_with_input(gt, pred, mask, image_size=224, patch_size=16):
    gt_token = rearrange(
        gt.detach().cpu(),
        'b c (nh ph) (nw pw) -> b (nh nw) (c ph pw)',
        ph=patch_size, pw=patch_size, nh=image_size//patch_size, nw=image_size//patch_size
    )
    pred_token = rearrange(
        pred.detach().cpu(),
        'b c (nh ph) (nw pw) -> b (nh nw) (c ph pw)',
        ph=patch_size, pw=patch_size, nh=image_size//patch_size, nw=image_size//patch_size
    )
    pred_token[mask.detach().cpu() == 0] = gt_token[mask.detach().cpu() == 0]
    img = rearrange(
        pred_token,
        'b (nh nw) (c ph pw) -> b c (nh ph) (nw pw)',
        ph=patch_size, pw=patch_size, nh=image_size//patch_size, nw=image_size//patch_size
    )
    return img
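# `get_pred_with_input` composites predictions with the input at patch granularity: patches the
# encoder actually saw (mask == 0) are copied from the input, while masked-out patches (mask == 1)
# keep the MultiMAE prediction. This is what the middle column of the output figure shows.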
def plot_semseg_pred_masked(rgb, semseg_preds, semseg_gt, mask, ax=None, image_size=224):
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    instance_mode = ColorMode.IMAGE
    img_viz = 255 * denormalize(rgb.detach().cpu())[0].permute(1, 2, 0)
    semseg = get_pred_with_input(
        semseg_gt.unsqueeze(1),
        semseg_preds.argmax(1).unsqueeze(1),
        mask,
        image_size=image_size//4,
        patch_size=4
    )
    semseg = F.interpolate(semseg.float(), size=image_size, mode='nearest')[0, 0].long()
    visualizer = Visualizer(img_viz, metadata, instance_mode=instance_mode, scale=1)
    visualizer.draw_sem_seg(semseg)
    if ax is not None:
        ax.imshow(visualizer.get_output().get_image())
    else:
        return visualizer.get_output().get_image()
def plot_predictions(input_dict, preds, masks, image_size=224):
    masked_rgb = get_masked_image(
        denormalize(input_dict['rgb']),
        masks['rgb'],
        image_size=image_size,
        mask_value=1.0
    )[0].permute(1, 2, 0).detach().cpu()
    masked_depth = get_masked_image(
        input_dict['depth'],
        masks['depth'],
        image_size=image_size,
        mask_value=np.nan
    )[0, 0].detach().cpu()

    pred_rgb = denormalize(preds['rgb'])[0].permute(1, 2, 0).clamp(0, 1)
    pred_depth = preds['depth'][0, 0].detach().cpu()

    pred_rgb2 = get_pred_with_input(
        denormalize(input_dict['rgb']),
        denormalize(preds['rgb']).clamp(0, 1),
        masks['rgb'],
        image_size=image_size
    )[0].permute(1, 2, 0).detach().cpu()
    pred_depth2 = get_pred_with_input(
        input_dict['depth'],
        preds['depth'],
        masks['depth'],
        image_size=image_size
    )[0, 0].detach().cpu()

    fig = plt.figure(figsize=(10, 10))
    grid = ImageGrid(fig, 111, nrows_ncols=(3, 3), axes_pad=0)

    grid[0].imshow(masked_rgb)
    grid[1].imshow(pred_rgb2)
    grid[2].imshow(denormalize(input_dict['rgb'])[0].permute(1, 2, 0).detach().cpu())

    grid[3].imshow(masked_depth)
    grid[4].imshow(pred_depth2)
    grid[5].imshow(input_dict['depth'][0, 0].detach().cpu())

    plot_semseg_gt_masked(input_dict, masks['semseg'], grid[6], mask_value=1.0, image_size=image_size)
    plot_semseg_pred_masked(input_dict['rgb'], preds['semseg'], input_dict['semseg'], masks['semseg'], grid[7], image_size=image_size)
    plot_semseg_gt(input_dict, grid[8], image_size=image_size)

    for ax in grid:
        ax.set_xticks([])
        ax.set_yticks([])

    fontsize = 16
    grid[0].set_title('Masked inputs', fontsize=fontsize)
    grid[1].set_title('MultiMAE predictions', fontsize=fontsize)
    grid[2].set_title('Original Reference', fontsize=fontsize)
    grid[0].set_ylabel('RGB', fontsize=fontsize)
    grid[3].set_ylabel('Depth', fontsize=fontsize)
    grid[6].set_ylabel('Semantic', fontsize=fontsize)

    plt.savefig('./output.png', dpi=300, bbox_inches='tight')
    plt.close()
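# The saved figure is a 3x3 grid: rows are RGB / depth / semantic segmentation, and columns are
# the masked inputs, the MultiMAE reconstructions (with visible patches pasted back in via
# get_pred_with_input), and the original pseudo-labeled reference.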
def inference(img, num_tokens, manual_mode, num_rgb, num_depth, num_semseg, seed):
    num_tokens = int(588 * num_tokens / 100.0)
    num_rgb = int(196 * num_rgb / 100.0)
    num_depth = int(196 * num_depth / 100.0)
    num_semseg = int(196 * num_semseg / 100.0)

    im = Image.open(img)

    # Center crop and resize RGB
    image_size = 224  # Train resolution
    img = TF.center_crop(TF.to_tensor(im), min(im.size))
    img = TF.resize(img, image_size, interpolation=TF.InterpolationMode.BICUBIC)

    # Predict depth and semseg
    depth = predict_depth(img)
    semseg = predict_semseg(img)

    # Pre-process RGB, depth and semseg to the MultiMAE input format
    input_dict = {}

    # Normalize RGB
    input_dict['rgb'] = TF.normalize(img, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD).unsqueeze(0)

    # Normalize depth robustly
    trunc_depth = torch.sort(depth.flatten())[0]
    trunc_depth = trunc_depth[int(0.1 * trunc_depth.shape[0]): int(0.9 * trunc_depth.shape[0])]
    depth = (depth - trunc_depth.mean()[None, None, None]) / torch.sqrt(trunc_depth.var()[None, None, None] + 1e-6)
    input_dict['depth'] = depth.unsqueeze(0)

    # Downsample semantic segmentation
    stride = 4
    semseg = TF.resize(semseg.unsqueeze(0), (semseg.shape[0] // stride, semseg.shape[1] // stride), interpolation=TF.InterpolationMode.NEAREST)
    input_dict['semseg'] = semseg

    # Move inputs to the chosen device (GPU if available, else CPU)
    input_dict = {k: v.to(device) for k, v in input_dict.items()}

    if not manual_mode:
        # Randomly sample masks
        torch.manual_seed(int(time.time()))  # Random mode is random
        preds, masks = multimae.forward(
            input_dict,
            mask_inputs=True,  # True if forward pass should sample random masks
            num_encoded_tokens=num_tokens,
            alphas=1.0
        )
    else:
        # Randomly sample masks using the specified number of tokens per modality
        torch.manual_seed(int(seed))  # change seed to resample new mask
        task_masks = {domain: torch.ones(1, 196).long().to(device) for domain in DOMAINS}
        selected_rgb_idxs = torch.randperm(196)[:num_rgb]
        selected_depth_idxs = torch.randperm(196)[:num_depth]
        selected_semseg_idxs = torch.randperm(196)[:num_semseg]
        task_masks['rgb'][:, selected_rgb_idxs] = 0
        task_masks['depth'][:, selected_depth_idxs] = 0
        task_masks['semseg'][:, selected_semseg_idxs] = 0

        preds, masks = multimae.forward(
            input_dict,
            mask_inputs=True,
            task_masks=task_masks
        )

    preds = {domain: pred.detach().cpu() for domain, pred in preds.items()}
    masks = {domain: mask.detach().cpu() for domain, mask in masks.items()}

    plot_predictions(input_dict, preds, masks)
    return 'output.png'
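# Illustrative local call, outside the Gradio UI (a sketch, not executed by the demo; it assumes
# a local file 'some_image.jpg'):
#
#   out_path = inference('some_image.jpg', num_tokens=15, manual_mode=False,
#                        num_rgb=15, num_depth=15, num_semseg=15, seed=0)
#   # out_path == 'output.png', the 3x3 visualization written by plot_predictions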
| title = "MultiMAE" | |
| description = "Gradio demo for MultiMAE: Multi-modal Multi-task Masked Autoencoders. \ | |
| Upload your own images or try one of the examples below to explore the multi-modal masked reconstruction of a pre-trained MultiMAE model. \ | |
| Uploaded images are pseudo labeled using a DPT trained on Omnidata depth, and a Mask2Former trained on COCO. \ | |
| Choose the percentage of visible tokens using the sliders below and see how MultiMAE reconstructs the modalities!" | |
| article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.01678' \ | |
| target='_blank'>MultiMAE: Multi-modal Multi-task Masked Autoencoders</a> | \ | |
| <a href='https://github.com/EPFL-VILAB/MultiMAE' target='_blank'>Github Repo</a></p>" | |
| css = '.output-image{height: 713px !important}' | |
# Example images
# os.system("wget https://i.imgur.com/c9ObJdK.jpg")
# os.system("wget https://i.imgur.com/KTKgYKi.jpg")
# os.system("wget https://i.imgur.com/lWYuRI7.jpg")
examples = [
    ['c9ObJdK.jpg', 15, False, 15, 15, 15, 0],
    ['KTKgYKi.jpg', 15, False, 15, 15, 15, 0],
    ['lWYuRI7.jpg', 15, False, 15, 15, 15, 0],
]
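# Note: the example entries above expect c9ObJdK.jpg, KTKgYKi.jpg and lWYuRI7.jpg to be present in
# the working directory (e.g. bundled with the Space); the commented-out wget lines show one way
# to fetch them.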
gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Image(label='RGB input image', type='filepath'),
        gr.inputs.Slider(label='Percentage of input tokens', default=15, step=0.1, minimum=0, maximum=100),
        gr.inputs.Checkbox(label='Manual mode: Check this to manually set the number of input tokens per modality using the sliders below', default=False),
        gr.inputs.Slider(label='Percentage of RGB input tokens (for manual mode only)', default=15, step=0.1, minimum=0, maximum=100),
        gr.inputs.Slider(label='Percentage of depth input tokens (for manual mode only)', default=15, step=0.1, minimum=0, maximum=100),
        gr.inputs.Slider(label='Percentage of semantic input tokens (for manual mode only)', default=15, step=0.1, minimum=0, maximum=100),
        gr.inputs.Number(label='Random seed: Change this to sample different masks (for manual mode only)', default=0),
    ],
    outputs=[
        gr.outputs.Image(label='MultiMAE predictions', type='filepath')
    ],
    css=css,
    title=title,
    description=description,
    article=article,
    examples=examples
).launch(enable_queue=True)