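# Gradio Space that visualizes DINOv2 patch features via a 2-step PCA.
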
import torch
import numpy as np
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from torchvision import transforms as T
from sklearn.preprocessing import MinMaxScaler
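
# Pick GPU when available; the DINOv2 ViT-B/14 backbone is loaded once from torch hub at startup.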
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14') | |
| dino.eval() | |
| dino.to(device) | |
| pca = PCA(n_components=3) | |
| scaler = MinMaxScaler(clip=True) | |
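
# Helper: show an image array as a Plotly figure with axis tick labels hidden.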
def plot_img(img_array: np.ndarray) -> go.Figure:
    fig = px.imshow(img_array)
    fig.update_layout(
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False),
    )
    return fig
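
# Core logic: DINOv2 patch features -> 2-step PCA -> RGB visualization figure.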
def app_fn(
    img: np.ndarray,
    threshold: float,
    object_larger_than_bg: bool,
) -> go.Figure:
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
    patch_h = 40
    patch_w = 40

    # ViT-B/14 uses 14x14 pixel patches, so a (14 * patch_h, 14 * patch_w) input
    # produces a patch_h x patch_w grid of patch tokens.
    transform = T.Compose([
        T.Resize((14 * patch_h, 14 * patch_w)),
        T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    img = torch.from_numpy(img).type(torch.float).permute(2, 0, 1) / 255
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        out = dino.forward_features(img_tensor)

    # Keep only the patch tokens, dropping the leading [CLS] token.
    features = out["x_prenorm"][:, 1:, :]
    features = features.squeeze(0).cpu().numpy()

    # Step 1: PCA to 3 components, scaled to [0, 1]; thresholding the first
    # component separates foreground from background.
    pca_features = pca.fit_transform(features)
    pca_features = scaler.fit_transform(pca_features)

    if object_larger_than_bg:
        pca_features_bg = pca_features[:, 0] > threshold
    else:
        pca_features_bg = pca_features[:, 0] < threshold
    pca_features_fg = ~pca_features_bg

    # Step 2: re-fit PCA on the foreground patches only, so the 3 components
    # map cleanly onto RGB channels.
    pca_features_fg_seg = pca.fit_transform(features[pca_features_fg])
    pca_features_fg_seg = scaler.fit_transform(pca_features_fg_seg)

    # Paint background patches black and foreground patches with their PCA colors.
    pca_features_rgb = np.zeros((patch_h * patch_w, 3))
    pca_features_rgb[pca_features_bg] = 0
    pca_features_rgb[pca_features_fg] = pca_features_fg_seg
    pca_features_rgb = pca_features_rgb.reshape(patch_h, patch_w, 3)

    return plot_img(pca_features_rgb)
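
# Gradio UI: image input, background threshold slider, and a checkbox that flips
# which side of the threshold is treated as background.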
if __name__ == "__main__":
    title = "🦖 DINOv2 Features Visualization 🦖"

    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            """
            ### This app visualizes the features extracted by the [DINOv2](https://arxiv.org/pdf/2304.07193.pdf) model. \
            To create the visualizations we use a 2-step PCA. \
            In the first step we reduce the features to 3 dimensions and threshold the first component \
            to segment the background from the foreground. Then we run a second PCA on the foreground features \
            so the foreground objects can be visualized in RGB.

            [Paper](https://arxiv.org/pdf/2304.07193.pdf)
            [Github](https://github.com/facebookresearch/dinov2)

            Created by: [Eduardo Pacheco](https://github.com/EduardoPach)
            """
        )
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05, label="Threshold")
            object_larger_than_bg = gr.Checkbox(label="Object Larger than Background", value=False)
        btn = gr.Button("Visualize")
        with gr.Row():
            img = gr.Image()
            fig_pca = gr.Plot(label="PCA Features")

        btn.click(fn=app_fn, inputs=[img, threshold, object_larger_than_bg], outputs=[fig_pca])

        examples = gr.Examples(
            examples=[
                ["assets/neca-the-cat.jpeg", 0.6, True],
                ["assets/dog.png", 0.7, False],
            ],
            inputs=[img, threshold, object_larger_than_bg],
            outputs=[fig_pca],
            fn=app_fn,
            cache_examples=True,
        )

    demo.queue(max_size=5).launch()