yeliudev committed
Commit 41e934b · Parent(s): f880dff

Fix ZeroGPU compatibility
README.md CHANGED
@@ -8,5 +8,5 @@ sdk_version: 5.48.0
 app_file: app.py
 pinned: true
 license: bsd-3-clause
-short_description: Unified Object Referring and Segmentation for Pixel-Level Visual Reasoning
+short_description: An MLLM for Unified Object Referring and Segmentation
 ---
app.py CHANGED
@@ -49,11 +49,12 @@ function init() {
 }
 """
 
-model, processor = build_model(MODEL)
-device = next(model.parameters()).device
+model, processor = build_model(MODEL, attn_implementation='sdpa')
 
 sam2_transform = get_sam2_transform(model.config.sam2_image_size)
 
+device = torch.device('cuda')
+
 colors = sample_color()
 color_map = {f'Target {i + 1}': f'#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}' for i, c in enumerate(colors * 255)}
 color_map_light = {
@@ -100,6 +101,8 @@ def update_video(video, prompt_idx):
 
 @spaces.GPU
 def infer_seg(media, query, sample_frames=16, media_type=None):
+    global model
+
     if not media:
         gr.Warning('Please upload an image or a video.')
         return None, None, None
@@ -136,6 +139,8 @@ def infer_seg(media, query, sample_frames=16, media_type=None):
     data['frames'] = [sam2_transform(frames).to(model.sam2.dtype)]
     data['frame_size'] = [frames.shape[1:3]]
 
+    model = model.to(device)
+
     output_ids = model.generate(
         **data.to(device),
         do_sample=False,
@@ -182,6 +187,8 @@ infer_seg_video = partial(infer_seg, media_type='video')
 
 @spaces.GPU
 def infer_reg(blob, query, prompt_idx=1, video=None):
+    global model
+
     if blob['background'] is None:
         gr.Warning('Please upload an image or a video.')
         return
@@ -246,6 +253,8 @@ def infer_reg(blob, query, prompt_idx=1, video=None):
     data['frame_size'] = [frames.shape[1:3]]
     data['refer_mask'] = [refer_mask]
 
+    model = model.to(device)
+
     output_ids = model.generate(
         **data.to(device),
         do_sample=False,
@@ -274,7 +283,75 @@ def infer_reg(blob, query, prompt_idx=1, video=None):
 
 
 def build_demo():
-    with gr.Blocks(title=TITLE, js=JS) as demo:
+    apple_theme = gr.themes.Base(
+        primary_hue=gr.themes.colors.blue,
+        secondary_hue=gr.themes.colors.gray,
+        neutral_hue=gr.themes.colors.gray,
+        spacing_size=gr.themes.sizes.spacing_md,
+        radius_size=gr.themes.sizes.radius_md,
+        text_size=gr.themes.sizes.text_md,
+        font=["-apple-system", "BlinkMacSystemFont", "Segoe UI", "Helvetica Neue", "Arial", "sans-serif"],
+        font_mono=["SF Mono", "Monaco", "Inconsolata", "Roboto Mono", "monospace"]).set(
+        body_background_fill="white",
+        body_background_fill_dark="#000000",
+        block_background_fill="#ffffff",
+        block_background_fill_dark="#1c1c1e",
+        block_border_color="#d1d1d6",
+        block_border_color_dark="#38383a",
+        block_border_width="1px",
+        block_label_background_fill="transparent",
+        block_label_background_fill_dark="transparent",
+        block_label_text_color="#1d1d1f",
+        block_label_text_color_dark="#f5f5f7",
+        block_label_text_weight="600",
+        block_label_text_size="*text_sm",
+        block_title_text_weight="600",
+        block_title_text_color="#1d1d1f",
+        block_title_text_color_dark="#f5f5f7",
+        button_primary_background_fill="#007aff",
+        button_primary_background_fill_hover="#0051d5",
+        button_primary_background_fill_dark="#0a84ff",
+        button_primary_background_fill_hover_dark="#409cff",
+        button_primary_text_color="white",
+        button_primary_border_color="transparent",
+        button_secondary_background_fill="#f5f5f7",
+        button_secondary_background_fill_hover="#e8e8ed",
+        button_secondary_background_fill_dark="#2c2c2e",
+        button_secondary_background_fill_hover_dark="#3a3a3c",
+        button_secondary_text_color="#1d1d1f",
+        button_secondary_text_color_dark="#f5f5f7",
+        button_secondary_border_color="transparent",
+        button_cancel_background_fill="#ff3b30",
+        button_cancel_background_fill_hover="#ff453a",
+        button_cancel_text_color="white",
+        input_background_fill="#ffffff",
+        input_background_fill_dark="#1c1c1e",
+        input_border_color="#d1d1d6",
+        input_border_color_dark="#38383a",
+        input_border_color_focus="#007aff",
+        input_border_color_focus_dark="#0a84ff",
+        input_placeholder_color="#8e8e93",
+        input_placeholder_color_dark="#98989d",
+        slider_color="#007aff",
+        slider_color_dark="#0a84ff",
+        checkbox_background_color="#007aff",
+        checkbox_background_color_dark="#0a84ff",
+        checkbox_background_color_selected="#007aff",
+        checkbox_background_color_selected_dark="#0a84ff",
+        checkbox_border_color="#d1d1d6",
+        checkbox_border_color_dark="#38383a",
+        checkbox_border_color_selected="#007aff",
+        checkbox_border_color_selected_dark="#0a84ff",
+        panel_background_fill="#f5f5f7",
+        panel_background_fill_dark="#1c1c1e",
+        panel_border_color="#d1d1d6",
+        panel_border_color_dark="#38383a",
+        shadow_drop="0px 1px 3px 0px rgba(0,0,0,0.1)",
+        shadow_drop_lg="0px 10px 30px 0px rgba(0,0,0,0.15)",
+        loader_color="#007aff",
+        loader_color_dark="#0a84ff")
+
+    with gr.Blocks(title=TITLE, js=JS, theme=apple_theme) as demo:
         gr.HTML(HEADER)
 
         with gr.Tab('Image Segmentation'):
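Note: the app.py changes above follow the usual ZeroGPU pattern. Module-level code runs on CPU at startup and there is no CUDA context outside GPU-decorated functions, which is why `device = next(model.parameters()).device` (which would report `cpu`) was replaced by a hard-coded `torch.device('cuda')`, and why each `@spaces.GPU` handler moves the model to the GPU lazily once it actually holds a GPU lease. Passing `attn_implementation='sdpa'` to `build_model` selects PyTorch's built-in scaled-dot-product attention, so the flash-attn wheel that requirements.txt comments out is not needed. A minimal sketch of the pattern, with a hypothetical `AutoModel` checkpoint and `run` handler standing in for this app's model and inference functions:

```python
import spaces
import torch
from transformers import AutoModel

# Module-level code executes on CPU; on ZeroGPU, CUDA is only usable
# inside functions decorated with @spaces.GPU.
model = AutoModel.from_pretrained('bert-base-uncased')  # hypothetical stand-in
device = torch.device('cuda')

@spaces.GPU
def run(inputs):
    global model
    model = model.to(device)  # lazy move; effectively a no-op after the first call
    with torch.inference_mode():
        return model(**inputs.to(device))
```

The `global model` declaration matters because the handler assigns the moved model back to the module-level name, so subsequent calls reuse the CUDA-resident copy.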
requirements.txt CHANGED
@@ -20,13 +20,12 @@ sentencepiece==0.2.0
 spaces==0.42.1
 tensordict==0.9.1
 termplotlib==0.3.9
+torch==2.7.1
+torchvision==0.22.1
 transformers==4.53.3
 triton==3.3.1
 wandb==0.21.0
 
-# torch==2.7.1+cu128
-# torchvision==0.22.1+cu128
-
 # https://github.com/Dao-AILab/flash-attention/pull/1751
 # flash_attn==2.8.2
 
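Note: the previously commented `+cu128` builds are local-version wheels hosted on PyTorch's own index, not on PyPI, so they cannot be resolved from a Space's plain requirements.txt. The new `torch==2.7.1` / `torchvision==0.22.1` pins are the standard PyPI wheels, which pull in their CUDA runtime through `nvidia-*` dependency packages (`pip install torch==2.7.1 torchvision==0.22.1` works with no extra index URL), which is presumably why the pins were uncommented with the `+cu128` suffixes dropped.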
sam2/configs/sam2.1_hiera_b+.yaml CHANGED
@@ -29,7 +29,7 @@ model:
     d_model: 256
     pos_enc_at_input: true
     layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      # _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
       activation: relu
       dim_feedforward: 2048
       dropout: 0.1
@@ -74,7 +74,7 @@ model:
     fuser:
       _target_: sam2.modeling.memory_encoder.Fuser
       layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
+        # _target_: sam2.modeling.memory_encoder.CXBlock
         dim: 256
         kernel_size: 7
         padding: 3
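Note: with the `_target_` keys commented out, Hydra (which SAM 2 uses to build these configs) no longer instantiates the `layer` nodes itself; each now resolves to a plain mapping of constructor kwargs (`activation`, `dim_feedforward`, `dropout`, ... and `dim`, `kernel_size`, `padding`, ... respectively), which the Python side expands directly via `MemoryAttentionLayer(**layer)` and `CXBlock(**layer)` in the memory_attention.py and memory_encoder.py hunks below.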
sam2/modeling/memory_attention.py CHANGED
@@ -11,7 +11,7 @@ from torch import nn, Tensor
 
 from sam2.modeling.sam.transformer import RoPEAttention
 
-from sam2.modeling.sam2_utils import get_activation_fn, get_clones
+from sam2.modeling.sam2_utils import get_activation_fn
 
 
 class MemoryAttentionLayer(nn.Module):
@@ -111,7 +111,9 @@ class MemoryAttention(nn.Module):
     ):
         super().__init__()
         self.d_model = d_model
-        self.layers = get_clones(layer, num_layers)
+        # NOTE: avoid using copy.deepcopy with zero3 or ZeroGPUs
+        self.layers = nn.ModuleList([MemoryAttentionLayer(**layer) for _ in range(num_layers)])
+        # self.layers = get_clones(layer, num_layers)
         self.num_layers = num_layers
         self.norm = nn.LayerNorm(d_model)
         self.pos_enc_at_input = pos_enc_at_input
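Note: in upstream SAM 2, `get_clones` is a deepcopy helper, roughly `nn.ModuleList([copy.deepcopy(module) for _ in range(n)])`. Under DeepSpeed ZeRO-3 the parameters are partitioned and wrapped with gather hooks, and on ZeroGPU they may not be fully materialized yet, so `copy.deepcopy` can produce broken clones; constructing each layer fresh from the config kwargs sidesteps the copy entirely. A self-contained sketch of the two approaches, with a toy layer standing in for `MemoryAttentionLayer`:

```python
import copy
import torch.nn as nn

def get_clones(module, n):
    # Upstream-style helper: relies on copy.deepcopy, which breaks when
    # parameters are partitioned (zero3) or not yet materialized (ZeroGPU).
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

class ToyLayer(nn.Module):
    """Stand-in for MemoryAttentionLayer; the real kwargs come from the yaml."""
    def __init__(self, d_model=256, dim_feedforward=2048):
        super().__init__()
        self.ff = nn.Linear(d_model, dim_feedforward)

layer_cfg = {'d_model': 256, 'dim_feedforward': 2048}  # the yaml's `layer:` node, sans _target_
layers = nn.ModuleList([ToyLayer(**layer_cfg) for _ in range(4)])  # the commit's approach
```

One behavioral difference worth noting: deepcopy gives every clone identical initial weights, while fresh construction gives each layer its own random init; that is harmless here because the weights are immediately overwritten by the pretrained checkpoint.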
sam2/modeling/memory_encoder.py CHANGED
@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
+from sam2.modeling.sam2_utils import DropPath, LayerNorm2d
 
 
 class MaskDownSampler(nn.Module):
@@ -119,7 +119,9 @@ class Fuser(nn.Module):
     def __init__(self, layer, num_layers, dim=None, input_projection=False):
         super().__init__()
         self.proj = nn.Identity()
-        self.layers = get_clones(layer, num_layers)
+        # NOTE: avoid using copy.deepcopy with zero3 or ZeroGPUs
+        self.layers = nn.ModuleList([CXBlock(**layer) for _ in range(num_layers)])
+        # self.layers = get_clones(layer, num_layers)
 
         if input_projection:
             assert dim is not None
@@ -154,6 +156,9 @@ class MemoryEncoder(nn.Module):
         if out_dim != in_dim:
             self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
 
+        # save out_dim to avoid accessing model weights (breaks zero3)
+        self.out_dim = out_dim
+
     def forward(
         self,
         pix_feat: torch.Tensor,
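Note: caching `self.out_dim` lets downstream code discover the memory channel width without touching `out_proj.weight`, whose visible shape under ZeRO-3 partitioning can be a placeholder with a zero first dimension. A minimal sketch of the pattern (illustrative class name, arbitrary dims):

```python
import torch.nn as nn

class EncoderSketch(nn.Module):
    def __init__(self, in_dim=256, out_dim=64):
        super().__init__()
        self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
        # Record the width as plain Python data at construction time;
        # under zero3 the weight tensor itself may be partitioned away.
        self.out_dim = out_dim

enc = EncoderSketch()
mem_dim = enc.out_dim                      # safe under zero3/ZeroGPU
# mem_dim = enc.out_proj.weight.shape[0]   # fragile: reads the sharded weight
```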
sam2/modeling/sam2_base.py CHANGED
@@ -126,7 +126,11 @@ class SAM2Base(torch.nn.Module):
         self.mem_dim = self.hidden_dim
         if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
             # if there is compression of memories along channel dim
-            self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+            # NOTE: avoid directly accessing weights under zero3
+            self.mem_dim = self.memory_encoder.out_dim
+            if self.memory_encoder.out_proj.weight.shape[0] != 0:
+                assert self.mem_dim == self.memory_encoder.out_proj.weight.shape[0]
+            # self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
         self.num_maskmem = num_maskmem  # Number of memories accessible
         # Temporal encoding of the memories
         self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
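Note: the `!= 0` guard above covers both runtimes: in a normal, non-partitioned run the weight is materialized and the assert cross-checks the cached `out_dim` against its first dimension, while under ZeRO-3 the partitioned weight reports a zero first dimension and the check is simply skipped.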