Spaces: merve / Running on Zero

Adding MPS support for Apple silicon

#1
by clementlr - opened
Files changed (1)
  1. app.py +42 -10
app.py CHANGED
@@ -21,6 +21,7 @@ import spaces
 import torch
 
 from moviepy.editor import ImageSequenceClip
+from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter # adding this for MPS compatibility
 from PIL import Image
 from sam2.build_sam import build_sam2_video_predictor
 
@@ -72,7 +73,9 @@ examples = [
 OBJ_ID = 0
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
+device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu") # MPS support
 predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+predictor.to(device)
 
 
 def get_video_fps(video_path):
@@ -340,11 +343,26 @@ def propagate_to_all(
     input_points,
     inference_state,
 ):
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-        predictor.to("cuda")
+    # Set boolean for cuda or mps support
+    use_cuda = torch.cuda.is_available()
+    use_mps = torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+    if use_cuda:
+        if torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+        autocast_kwargs = dict(device_type="cuda", dtype=torch.bfloat16)
+
+    elif use_mps:
+        autocast_kwargs = dict(device_type="mps", dtype=torch.float16)
+
+    with torch.autocast(**autocast_kwargs):
+
+        if use_cuda:
+            predictor.to("cuda")
+        elif use_mps:
+            predictor.to("mps")
+
         if inference_state:
             inference_state["device"] = predictor.device
 
@@ -374,15 +392,16 @@ def propagate_to_all(
             out_mask = video_segments[out_frame_idx][OBJ_ID]
             mask_image = show_mask(out_mask)
             output_frame = Image.alpha_composite(transparent_background, mask_image)
+            output_frame = output_frame.convert("RGB")
             output_frame = np.array(output_frame)
             output_frames.append(output_frame)
 
-        torch.cuda.empty_cache()
+        if use_cuda:
+            torch.cuda.empty_cache()
 
         # Create a video clip from the image sequence
         original_fps = get_video_fps(video_in)
-        fps = original_fps # Frames per second
-        clip = ImageSequenceClip(output_frames, fps=fps)
+        fps = float(original_fps) # Frames per second
         # Write the result to a file
         unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
         final_vid_output_path = f"output_video_{unique_id}.mp4"
@@ -390,8 +409,21 @@ def propagate_to_all(
             tempfile.gettempdir(), final_vid_output_path
         )
 
-        # Write the result to a file
-        clip.write_videofile(final_vid_output_path, codec="libx264")
+        # Write the result to a file using moviepy ImageSequenceClip
+        if use_cuda:
+            clip = ImageSequenceClip(output_frames, fps=fps)
+            clip.write_videofile(final_vid_output_path, codec="libx264")
+
+        # Write the result to a file using moviepy FFMPEG_VideoWriter for MPS compatibility
+        elif use_mps:
+            clip_array = output_frames # list of RGB numpy arrays
+            size = clip_array[0].shape[1], clip_array[0].shape[0] # (width, height)
+            writer = FFMPEG_VideoWriter(final_vid_output_path, size, fps=fps, codec="libx264")
+
+            for frame in clip_array:
+                writer.write_frame(frame)
+
+            writer.close()
 
         return gr.update(value=final_vid_output_path)
 
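
For anyone who wants to exercise the backend selection outside the Space, here is a minimal standalone sketch of the pattern this diff introduces. It assumes a recent PyTorch build in which torch.autocast accepts the "mps" device type. The pick_autocast_settings helper, the toy matmul, and the explicit CPU fallback are illustrative additions and not part of the diff (as written, the diff leaves autocast_kwargs undefined when neither CUDA nor MPS is available).

# Illustrative sketch, not part of the diff: backend selection mirroring the PR,
# plus a CPU fallback so the autocast settings are always defined.
import torch


def pick_autocast_settings():
    """Return (device, autocast kwargs) for CUDA, MPS, or CPU."""
    if torch.cuda.is_available():
        # TF32 matmuls are only enabled on Ampere (compute capability 8.x) or newer,
        # matching the check in the diff.
        if torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        return torch.device("cuda"), dict(device_type="cuda", dtype=torch.bfloat16)
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        # The diff uses float16 under MPS autocast rather than bfloat16.
        return torch.device("mps"), dict(device_type="mps", dtype=torch.float16)
    # CPU fallback (illustrative addition, not in the diff).
    return torch.device("cpu"), dict(device_type="cpu", dtype=torch.bfloat16)


if __name__ == "__main__":
    device, autocast_kwargs = pick_autocast_settings()
    x = torch.randn(64, 64, device=device)
    with torch.autocast(**autocast_kwargs):
        y = x @ x  # matmul runs under the selected autocast dtype
    print(f"device={device}, result dtype={y.dtype}")

The resulting device object plays the same role as the module-level device in the diff, which is passed to predictor.to(device) right after the predictor is built on CPU.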