Fix ZeroGPU compatibility
- README.md +1 -1
- app.py +80 -3
- requirements.txt +2 -3
- sam2/configs/sam2.1_hiera_b+.yaml +2 -2
- sam2/modeling/memory_attention.py +4 -2
- sam2/modeling/memory_encoder.py +7 -2
- sam2/modeling/sam2_base.py +5 -1
README.md CHANGED
@@ -8,5 +8,5 @@ sdk_version: 5.48.0
 app_file: app.py
 pinned: true
 license: bsd-3-clause
-short_description: Unified Object Referring and Segmentation
+short_description: An MLLM for Unified Object Referring and Segmentation
 ---
app.py CHANGED
@@ -49,11 +49,12 @@ function init() {
 }
 """
 
-model, processor = build_model(MODEL)
-device = next(model.parameters()).device
+model, processor = build_model(MODEL, attn_implementation='sdpa')
 
 sam2_transform = get_sam2_transform(model.config.sam2_image_size)
 
+device = torch.device('cuda')
+
 colors = sample_color()
 color_map = {f'Target {i + 1}': f'#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}' for i, c in enumerate(colors * 255)}
 color_map_light = {
@@ -100,6 +101,8 @@ def update_video(video, prompt_idx):
 
 @spaces.GPU
 def infer_seg(media, query, sample_frames=16, media_type=None):
+    global model
+
     if not media:
         gr.Warning('Please upload an image or a video.')
         return None, None, None
@@ -136,6 +139,8 @@ def infer_seg(media, query, sample_frames=16, media_type=None):
     data['frames'] = [sam2_transform(frames).to(model.sam2.dtype)]
     data['frame_size'] = [frames.shape[1:3]]
 
+    model = model.to(device)
+
     output_ids = model.generate(
         **data.to(device),
         do_sample=False,
@@ -182,6 +187,8 @@ infer_seg_video = partial(infer_seg, media_type='video')
 
 @spaces.GPU
 def infer_reg(blob, query, prompt_idx=1, video=None):
+    global model
+
     if blob['background'] is None:
         gr.Warning('Please upload an image or a video.')
         return
@@ -246,6 +253,8 @@ def infer_reg(blob, query, prompt_idx=1, video=None):
     data['frame_size'] = [frames.shape[1:3]]
     data['refer_mask'] = [refer_mask]
 
+    model = model.to(device)
+
     output_ids = model.generate(
         **data.to(device),
         do_sample=False,
@@ -274,7 +283,75 @@ def infer_reg(blob, query, prompt_idx=1, video=None):
 
 
 def build_demo():
-    with gr.Blocks(title=TITLE, js=JS) as demo:
+    apple_theme = gr.themes.Base(
+        primary_hue=gr.themes.colors.blue,
+        secondary_hue=gr.themes.colors.gray,
+        neutral_hue=gr.themes.colors.gray,
+        spacing_size=gr.themes.sizes.spacing_md,
+        radius_size=gr.themes.sizes.radius_md,
+        text_size=gr.themes.sizes.text_md,
+        font=["-apple-system", "BlinkMacSystemFont", "Segoe UI", "Helvetica Neue", "Arial", "sans-serif"],
+        font_mono=["SF Mono", "Monaco", "Inconsolata", "Roboto Mono", "monospace"]).set(
+        body_background_fill="white",
+        body_background_fill_dark="#000000",
+        block_background_fill="#ffffff",
+        block_background_fill_dark="#1c1c1e",
+        block_border_color="#d1d1d6",
+        block_border_color_dark="#38383a",
+        block_border_width="1px",
+        block_label_background_fill="transparent",
+        block_label_background_fill_dark="transparent",
+        block_label_text_color="#1d1d1f",
+        block_label_text_color_dark="#f5f5f7",
+        block_label_text_weight="600",
+        block_label_text_size="*text_sm",
+        block_title_text_weight="600",
+        block_title_text_color="#1d1d1f",
+        block_title_text_color_dark="#f5f5f7",
+        button_primary_background_fill="#007aff",
+        button_primary_background_fill_hover="#0051d5",
+        button_primary_background_fill_dark="#0a84ff",
+        button_primary_background_fill_hover_dark="#409cff",
+        button_primary_text_color="white",
+        button_primary_border_color="transparent",
+        button_secondary_background_fill="#f5f5f7",
+        button_secondary_background_fill_hover="#e8e8ed",
+        button_secondary_background_fill_dark="#2c2c2e",
+        button_secondary_background_fill_hover_dark="#3a3a3c",
+        button_secondary_text_color="#1d1d1f",
+        button_secondary_text_color_dark="#f5f5f7",
+        button_secondary_border_color="transparent",
+        button_cancel_background_fill="#ff3b30",
+        button_cancel_background_fill_hover="#ff453a",
+        button_cancel_text_color="white",
+        input_background_fill="#ffffff",
+        input_background_fill_dark="#1c1c1e",
+        input_border_color="#d1d1d6",
+        input_border_color_dark="#38383a",
+        input_border_color_focus="#007aff",
+        input_border_color_focus_dark="#0a84ff",
+        input_placeholder_color="#8e8e93",
+        input_placeholder_color_dark="#98989d",
+        slider_color="#007aff",
+        slider_color_dark="#0a84ff",
+        checkbox_background_color="#007aff",
+        checkbox_background_color_dark="#0a84ff",
+        checkbox_background_color_selected="#007aff",
+        checkbox_background_color_selected_dark="#0a84ff",
+        checkbox_border_color="#d1d1d6",
+        checkbox_border_color_dark="#38383a",
+        checkbox_border_color_selected="#007aff",
+        checkbox_border_color_selected_dark="#0a84ff",
+        panel_background_fill="#f5f5f7",
+        panel_background_fill_dark="#1c1c1e",
+        panel_border_color="#d1d1d6",
+        panel_border_color_dark="#38383a",
+        shadow_drop="0px 1px 3px 0px rgba(0,0,0,0.1)",
+        shadow_drop_lg="0px 10px 30px 0px rgba(0,0,0,0.15)",
+        loader_color="#007aff",
+        loader_color_dark="#0a84ff")
+
+    with gr.Blocks(title=TITLE, js=JS, theme=apple_theme) as demo:
         gr.HTML(HEADER)
 
         with gr.Tab('Image Segmentation'):
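For context on the app.py changes: on ZeroGPU Spaces the process starts without a GPU, so the model has to be built on CPU at import time (querying next(model.parameters()).device would not return a CUDA device), and CUDA only becomes usable inside functions decorated with @spaces.GPU. That is why the handlers declare `global model` and call `model = model.to(device)` lazily, and why `attn_implementation='sdpa'` switches to PyTorch's built-in scaled-dot-product attention instead of the still-commented-out flash_attn wheel. Below is a minimal sketch of the same pattern, with a hypothetical checkpoint name and run() handler standing in for this Space's build_model(MODEL) and inference functions:

    import spaces
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical checkpoint; built on CPU at import time (no CUDA calls allowed here).
    model = AutoModelForCausalLM.from_pretrained('my-org/my-model', torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained('my-org/my-model')
    device = torch.device('cuda')  # naming the device is fine, using it is not (yet)

    @spaces.GPU  # a GPU is attached only while this function runs
    def run(prompt: str) -> str:
        global model
        model = model.to(device)  # lazily move the weights onto the borrowed GPU
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=32)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
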
requirements.txt CHANGED
@@ -20,13 +20,12 @@ sentencepiece==0.2.0
 spaces==0.42.1
 tensordict==0.9.1
 termplotlib==0.3.9
+torch==2.7.1
+torchvision==0.22.1
 transformers==4.53.3
 triton==3.3.1
 wandb==0.21.0
 
-# torch==2.7.1+cu128
-# torchvision==0.22.1+cu128
-
 # https://github.com/Dao-AILab/flash-attention/pull/1751
 # flash_attn==2.8.2
 
sam2/configs/sam2.1_hiera_b+.yaml CHANGED
@@ -29,7 +29,7 @@ model:
     d_model: 256
     pos_enc_at_input: true
     layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      # _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
       activation: relu
       dim_feedforward: 2048
       dropout: 0.1
@@ -74,7 +74,7 @@ model:
     fuser:
       _target_: sam2.modeling.memory_encoder.Fuser
       layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
+        # _target_: sam2.modeling.memory_encoder.CXBlock
         dim: 256
        kernel_size: 7
         padding: 3
sam2/modeling/memory_attention.py CHANGED
@@ -11,7 +11,7 @@ from torch import nn, Tensor
 
 from sam2.modeling.sam.transformer import RoPEAttention
 
-from sam2.modeling.sam2_utils import get_activation_fn, get_clones
+from sam2.modeling.sam2_utils import get_activation_fn
 
 
 class MemoryAttentionLayer(nn.Module):
@@ -111,7 +111,9 @@ class MemoryAttention(nn.Module):
     ):
         super().__init__()
         self.d_model = d_model
-        self.layers = get_clones(layer, num_layers)
+        # NOTE: avoid using copy.deepcopy with zero3 or ZeroGPUs
+        self.layers = nn.ModuleList([MemoryAttentionLayer(**layer) for _ in range(num_layers)])
+        # self.layers = get_clones(layer, num_layers)
         self.num_layers = num_layers
         self.norm = nn.LayerNorm(d_model)
         self.pos_enc_at_input = pos_enc_at_input
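The removed get_clones helper deep-copies one prebuilt layer num_layers times; the NOTE in the patch flags copy.deepcopy as unsafe under DeepSpeed ZeRO-3 and ZeroGPU, where the source parameters may already be sharded or offloaded. The replacement treats the YAML `layer:` block as plain constructor kwargs (which is also why its `_target_:` line is commented out in the config above) and builds every MemoryAttentionLayer from scratch inside an nn.ModuleList. A toy sketch of the two approaches, with nn.Linear standing in for the real layer class:

    import copy
    import torch.nn as nn

    def get_clones(module: nn.Module, n: int) -> nn.ModuleList:
        # Original helper: deep-copy one prebuilt module n times.
        return nn.ModuleList(copy.deepcopy(module) for _ in range(n))

    # Patched pattern: keep only the constructor kwargs and build each layer
    # fresh, so no existing (possibly sharded) parameters are ever copied.
    layer_cfg = dict(in_features=256, out_features=256)  # stand-in for the YAML 'layer' block
    layers = nn.ModuleList(nn.Linear(**layer_cfg) for _ in range(4))

One behavioural difference: deep-copied clones start from the same random init, while freshly built layers are initialized independently; that should not matter here since the pretrained weights are loaded over them afterwards.
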
sam2/modeling/memory_encoder.py CHANGED
@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
+from sam2.modeling.sam2_utils import DropPath, LayerNorm2d
 
 
 class MaskDownSampler(nn.Module):
@@ -119,7 +119,9 @@ class Fuser(nn.Module):
     def __init__(self, layer, num_layers, dim=None, input_projection=False):
         super().__init__()
         self.proj = nn.Identity()
-        self.layers = get_clones(layer, num_layers)
+        # NOTE: avoid using copy.deepcopy with zero3 or ZeroGPUs
+        self.layers = nn.ModuleList([CXBlock(**layer) for _ in range(num_layers)])
+        # self.layers = get_clones(layer, num_layers)
 
         if input_projection:
             assert dim is not None
@@ -154,6 +156,9 @@ class MemoryEncoder(nn.Module):
         if out_dim != in_dim:
             self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
 
+        # save out_dim to avoid accessing model weights (breaks zero3)
+        self.out_dim = out_dim
+
     def forward(
         self,
         pix_feat: torch.Tensor,
sam2/modeling/sam2_base.py CHANGED
@@ -126,7 +126,11 @@ class SAM2Base(torch.nn.Module):
         self.mem_dim = self.hidden_dim
         if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
             # if there is compression of memories along channel dim
-            self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+            # NOTE: avoid directly accessing weights under zero3
+            self.mem_dim = self.memory_encoder.out_dim
+            if self.memory_encoder.out_proj.weight.shape[0] != 0:
+                assert self.mem_dim == self.memory_encoder.out_proj.weight.shape[0]
+            # self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
         self.num_maskmem = num_maskmem  # Number of memories accessible
         # Temporal encoding of the memories
         self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
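Finally, the sam2_base.py change works around ZeRO-3 parameter sharding: on ranks that do not own the out_proj weight, its local shape[0] can be 0, so deriving mem_dim from the weight is unreliable. The patch instead reads the plain out_dim attribute cached in MemoryEncoder.__init__ and only cross-checks it against the weight when the weight is actually materialized. A self-contained toy illustration of the idea (not the SAM 2 classes themselves):

    import torch.nn as nn

    class TinyEncoder(nn.Module):
        def __init__(self, in_dim: int = 256, out_dim: int = 64):
            super().__init__()
            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
            # Cache the width as a plain int: safe to read even when the Conv2d
            # weight has been sharded away (local shape[0] == 0) under ZeRO-3.
            self.out_dim = out_dim

    enc = TinyEncoder()
    mem_dim = enc.out_dim                    # metadata lookup, no weight access
    if enc.out_proj.weight.shape[0] != 0:    # weight is materialized on this rank
        assert mem_dim == enc.out_proj.weight.shape[0]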