lorebianchi98 committed
Commit 593b176 · 1 Parent(s): bd34a5b

First commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,237 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from transformers import AutoModel
+ import os
+ import torchvision.transforms.functional as F
+ from src.plot import plot_qualitative
+ from PIL import Image
+ from io import BytesIO
+ import base64
+ from pathlib import Path
+
+ # --- Setup ---
+ os.environ["GRADIO_TEMP_DIR"] = "tmp"
+ os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # --- Load Models ---
+ model_B = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTB", trust_remote_code=True).to(device).eval()
+ model_L = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTL", trust_remote_code=True).to(device).eval()
+ MODELS = {"ViT-B": model_B, "ViT-L": model_L}
+
+ # --- Example Setup ---
+ EXAMPLE_IMAGES_DIR = Path("examples").resolve()
+ example_images = sorted([str(p) for p in EXAMPLE_IMAGES_DIR.glob("*.png")])
+
+ DEFAULT_CLASSES = {
+     "0_pikachu.png": "pikachu,traffic_sign,forest,road,cap",
+     "1_jurassic.png": "dinosaur,smoke,vegetation,person",
+     "2_falcon.png": "millenium_falcon,space"
+ }
+
+ DEFAULT_BG_THRESH = 0.55
+ DEFAULT_BG_CLEAN = False
+
+
+ # --- Inference Function ---
+ def talk2dino_infer(input_image, class_text, selected_model="ViT-B",
+                     apply_pamr=True, with_background=False, bg_thresh=0.55, apply_bg_clean=False):
+     if input_image is None:
+         raise gr.Error("No image detected. Please select or upload an image first.")
+
+     model = MODELS[selected_model]
+     text = [t.strip() for t in class_text.replace("_", " ").split(",") if t.strip()]
+     if len(text) == 0:
+         raise gr.Error("Please provide at least one class name before generating segmentation.")
+
+     img = F.to_tensor(input_image).unsqueeze(0).float().to(device) * 255.0
+
+     # Generate color palette
+     palette = [
+         [255, 0, 0],
+         [255, 255, 0],
+         [0, 255, 0],
+         [0, 255, 255],
+         [0, 0, 255],
+         [128, 128, 128]
+     ]
+     if len(text) > len(palette):
+         for _ in range(len(text) - len(palette)):
+             palette.append([np.random.randint(0, 255) for _ in range(3)])
+
+     if with_background:
+         palette.insert(0, [0, 0, 0])
+         model.with_bg_clean = apply_bg_clean
+
+     with torch.no_grad():
+         text_emb = model.build_dataset_class_tokens("sub_imagenet_template", text)
+         text_emb = model.build_text_embedding(text_emb)
+         mask, _ = model.generate_masks(img, img_metas=None, text_emb=text_emb,
+                                        classnames=text, apply_pamr=apply_pamr)
+         if with_background:
+             background = torch.ones_like(mask[:, :1]) * bg_thresh
+             mask = torch.cat([background, mask], dim=1)
+         mask = mask.argmax(dim=1)
+
+     if with_background:
+         text = ["background"] + text
+
+     img_out = plot_qualitative(
+         img.cpu()[0].permute(1, 2, 0).int().numpy(),
+         mask.cpu()[0].numpy(),
+         palette,
+         texts=text
+     )
+     return img_out
+
+
+ # --- Gradio Interface ---
+ with gr.Blocks(title="Talk2DINO Demo") as demo:
+
+     # Overview Section
+     overview_img = Image.open("assets/overview.png").convert("RGB")
+     overview_img = overview_img.resize((int(overview_img.width * 0.7), int(overview_img.height * 0.7)))
+     buffered = BytesIO()
+     overview_img.save(buffered, format="PNG")
+     img_str = base64.b64encode(buffered.getvalue()).decode()
+
+     gr.Markdown(f"""
+ # 🦖 Talk2DINO Demo
+
+
+ ![Overview](data:image/png;base64,{img_str})
+
+ <div style="font-size: x-large; white-space: nowrap; display: flex; align-items: center; gap: 10px;">
+ <a href="https://lorebianchi98.github.io/Talk2DINO/" target="_blank">Project page</a>
+ <span>|</span>
+ <a href="http://arxiv.org/abs/2411.19331" target="_blank">
+ <img src="https://img.shields.io/badge/arXiv-2411.19331-b31b1b.svg" style="height:28px; vertical-align:middle;">
+ </a>
+ <span>|</span>
+ <a href="https://huggingface.co/papers/2411.19331" target="_blank">
+ <img src="https://img.shields.io/badge/HuggingFace-Paper-yellow.svg" style="height:28px; vertical-align:middle;">
+ </a>
+ </div>
+
+ ---
+
+ This demo allows you to **perform open-vocabulary semantic segmentation** on images using Talk2DINO.
+
+ **How to use:**
+ 1. Upload an image or select one from the example gallery.
+ 2. Enter a comma-separated list of class names you want to segment (e.g., `pikachu, forest, road`).
+ 3. Adjust optional parameters:
+    - **Model**: choose between ViT-B and ViT-L
+    - **Apply PAMR**: refine masks after initial prediction
+    - **Include Background**: visualize background areas
+    - **Background Threshold**: threshold for background intensity
+    - **Apply Background Cleaning**: remove background noise when enabled
+ 4. Click **Generate Segmentation** to see the segmentation overlay.
+ """)
+
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(type="pil", label="Input Image", value=None)
+             if example_images:
+                 example_gallery = gr.Gallery(
+                     value=example_images,
+                     label="Or select from example images",
+                     show_label=True,
+                     columns=3,
+                     object_fit="contain",
+                     height="auto"
+                 )
+
+         with gr.Column():
+             model_selector = gr.Dropdown(
+                 label="Select Model",
+                 choices=["ViT-B", "ViT-L"],
+                 value="ViT-B"
+             )
+             class_text = gr.Textbox(
+                 label="Comma-separated Classes",
+                 value="",
+                 placeholder="e.g. pikachu, road, tree"
+             )
+             apply_pamr = gr.Checkbox(label="Apply PAMR", value=True)
+             with_background = gr.Checkbox(label="Include Background", value=False)
+             bg_thresh = gr.Slider(
+                 label="Background Threshold",
+                 minimum=0.0,
+                 maximum=1.0,
+                 value=DEFAULT_BG_THRESH,
+                 step=0.01,
+                 interactive=False
+             )
+             apply_bg_clean = gr.Checkbox(
+                 label="Apply Background Cleaning",
+                 value=False,
+                 interactive=False
+             )
+
+     generate_button = gr.Button("🚀 Generate Segmentation", interactive=False)
+     output_image = gr.Image(type="numpy", label="Segmentation Overlay")
+
+     # --- Background Option Toggle ---
+     def toggle_bg_options(with_bg):
+         if with_bg:
+             return gr.update(interactive=True, value=DEFAULT_BG_THRESH), gr.update(interactive=True, value=DEFAULT_BG_CLEAN)
+         else:
+             return gr.update(interactive=False, value=DEFAULT_BG_THRESH), gr.update(interactive=False, value=DEFAULT_BG_CLEAN)
+
+     with_background.change(
+         fn=toggle_bg_options,
+         inputs=[with_background],
+         outputs=[bg_thresh, apply_bg_clean]
+     )
+
+     # --- Enable Button Only When Classes Exist ---
+     def enable_generate_button(text):
+         return gr.update(interactive=bool(text.strip()))
+
+     class_text.change(fn=enable_generate_button, inputs=[class_text], outputs=[generate_button])
+
+     # --- Example Image Loader ---
+     def load_example_image(evt: gr.SelectData):
+         selected = evt.value["image"]
+         if isinstance(selected, str):
+             img = Image.open(selected).convert("RGB")
+             filename = Path(selected).name
+         elif isinstance(selected, dict):
+             img = Image.open(selected["path"]).convert("RGB")
+             filename = Path(selected["path"]).name
+         else:
+             img = Image.fromarray(selected)
+             filename = None
+         class_val = DEFAULT_CLASSES.get(filename, "")
+         return img, class_val, gr.update(interactive=bool(class_val.strip()))
+
+     if example_images:
+         example_gallery.select(
+             fn=load_example_image,
+             inputs=[],
+             outputs=[input_image, class_text, generate_button]
+         )
+
+     # --- User Upload Reset ---
+     def on_upload_image(img):
+         if img is None:
+             return None, "", gr.update(interactive=False)
+         return img, "", gr.update(interactive=False)
+
+     input_image.upload(
+         fn=on_upload_image,
+         inputs=[input_image],
+         outputs=[input_image, class_text, generate_button]
+     )
+
+     # --- Generate Segmentation ---
+     generate_button.click(
+         talk2dino_infer,
+         inputs=[input_image, class_text, model_selector, apply_pamr, with_background, bg_thresh, apply_bg_clean],
+         outputs=output_image
+     )
+
+ demo.launch(server_port=7870, share=False)
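
For reference, here is a minimal sketch (not part of the committed files) of how the same checkpoint could be driven outside the Gradio UI. It reuses only the calls that app.py above already makes (`build_dataset_class_tokens`, `build_text_embedding`, `generate_masks`); the example image path and class list are illustrative.

```python
# Hedged sketch: run Talk2DINO programmatically, mirroring the calls in app.py.
import torch
import torchvision.transforms.functional as F
from PIL import Image
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("lorebianchi98/Talk2DINO-ViTB",
                                  trust_remote_code=True).to(device).eval()

classes = ["pikachu", "forest", "road"]                      # illustrative class names
image = Image.open("examples/0_pikachu.png").convert("RGB")  # any RGB image works
img = F.to_tensor(image).unsqueeze(0).float().to(device) * 255.0

with torch.no_grad():
    text_emb = model.build_dataset_class_tokens("sub_imagenet_template", classes)
    text_emb = model.build_text_embedding(text_emb)
    masks, _ = model.generate_masks(img, img_metas=None, text_emb=text_emb,
                                    classnames=classes, apply_pamr=True)
    seg = masks.argmax(dim=1)[0].cpu().numpy()  # HxW map of indices into `classes`
```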
assets/overview.png ADDED

Git LFS Details

  • SHA256: fcefc8c68cf95a966f769852ea51e7efa7ea2398b21936cacaa2eb5c6fff0358
  • Pointer size: 130 Bytes
  • Size of remote file: 89.5 kB
examples/0_pikachu.png ADDED

Git LFS Details

  • SHA256: 7a5efcbce11e4a293ebb743c8857c0654c6bce0b89beb59f6ca71d64311c4106
  • Pointer size: 131 Bytes
  • Size of remote file: 377 kB
examples/1_jurassic.png ADDED

Git LFS Details

  • SHA256: 804a011b7b5e312dda9a6a57ccb32947d6f74413b6311170dd96fcf10b792705
  • Pointer size: 131 Bytes
  • Size of remote file: 364 kB
examples/2_falcon.png ADDED

Git LFS Details

  • SHA256: 80a818fbce8acd2bda1e570dc6c0775d2100d0a227c768c5d7ff83275870709c
  • Pointer size: 131 Bytes
  • Size of remote file: 297 kB
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ git+https://github.com/openai/CLIP.git
+ matplotlib
+ opencv-python
+ pyyaml
+ requests
+ scikit-image
+ tqdm
+ omegaconf
+ einops
+ timm
+ transformers
+ webdataset
+ numpy==1.24.1
+ jaxtyping
+ rich
+ scikit-learn
+ safetensors==0.4.3
+ gradio
+ torch
+ torchvision
src/plot.py ADDED
@@ -0,0 +1,57 @@
+ import numpy as np
+ from matplotlib import pyplot as plt
+ from matplotlib.patches import Rectangle
+
+ def plot_qualitative(image, sim, palette, texts, alpha=0.6, legend_height=0.1):
+     """
+     image: HxWx3 uint8 image
+     sim: HxW segmentation mask with integer class IDs
+     palette: list of [R,G,B] colors
+     texts: list of class names corresponding to IDs
+     alpha: transparency for overlay
+     legend_height: fraction of figure height reserved for legend
+     """
+
+     qualitative_plot = np.zeros((sim.shape[0], sim.shape[1], 3), dtype=np.uint8)
+     for j in np.unique(sim):
+         qualitative_plot[sim == j] = np.array(palette[j])
+
+     # Normalize images for alpha blending
+     img_float = image.astype(np.float32) / 255.0
+     overlay_float = qualitative_plot.astype(np.float32) / 255.0
+
+     # Figure with space for legend
+     fig_height = img_float.shape[0] / 100
+     fig_width = img_float.shape[1] / 100
+     fig = plt.figure(figsize=(fig_width, fig_height + legend_height * fig_height), dpi=100)
+
+     # Main image axis
+     ax_img = fig.add_axes([0, legend_height, 1, 1 - legend_height])
+     ax_img.imshow(img_float)
+     ax_img.imshow(overlay_float, alpha=alpha)
+     ax_img.axis("off")
+
+     # Legend axis
+     ax_legend = fig.add_axes([0, 0, 1, legend_height])
+     ax_legend.axis("off")
+
+     # Draw legend rectangles
+     unique_classes = np.unique(sim)
+     num_classes = len(unique_classes)
+     for idx, cls in enumerate(unique_classes):
+         color = np.array(palette[cls]) / 255.0
+         # Rectangle: (x, y), width, height
+         rect_width = 1 / num_classes * 0.8
+         rect = Rectangle((idx / num_classes, 0.1), rect_width, 0.6, facecolor=color)
+         ax_legend.add_patch(rect)
+         # Add text label centered on rectangle
+         ax_legend.text(idx / num_classes + rect_width / 2, 0.8, texts[cls],
+                        ha='center', va='bottom', fontsize=10)
+
+     # Extract as NumPy array
+     fig.canvas.draw()
+     buf = np.asarray(fig.canvas.renderer.buffer_rgba())
+     img_array = (buf[:, :, :3]).copy()  # drop alpha
+
+     plt.close(fig)
+     return img_array
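
As a sanity check, plot_qualitative can be exercised on synthetic inputs; the image, mask, palette, and labels below are made up purely for illustration and are not part of the commit.

```python
# Hypothetical smoke test for src/plot.py (synthetic data, not from the repo).
import numpy as np
from src.plot import plot_qualitative

image = np.random.randint(0, 256, (120, 160, 3), dtype=np.uint8)  # HxWx3 uint8 image
mask = np.zeros((120, 160), dtype=np.int64)                       # per-pixel class IDs
mask[:, 80:] = 1                                                   # right half -> class 1
palette = [[255, 0, 0], [0, 255, 0]]                               # one RGB triple per class
overlay = plot_qualitative(image, mask, palette, texts=["left half", "right half"])
print(overlay.shape)  # blended overlay plus legend strip, as an RGB array
```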