Spaces:
Running
on
Zero
Running
on
Zero
Adding MPS support for Apple silicon
#1
by
clementlr
- opened
app.py
CHANGED
|
@@ -21,6 +21,7 @@ import spaces
|
|
| 21 |
import torch
|
| 22 |
|
| 23 |
from moviepy.editor import ImageSequenceClip
|
|
|
|
| 24 |
from PIL import Image
|
| 25 |
from sam2.build_sam import build_sam2_video_predictor
|
| 26 |
|
|
@@ -72,7 +73,9 @@ examples = [
|
|
| 72 |
OBJ_ID = 0
|
| 73 |
sam2_checkpoint = "checkpoints/edgetam.pt"
|
| 74 |
model_cfg = "edgetam.yaml"
|
|
|
|
| 75 |
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
def get_video_fps(video_path):
|
|
@@ -340,11 +343,26 @@ def propagate_to_all(
|
|
| 340 |
input_points,
|
| 341 |
inference_state,
|
| 342 |
):
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
if inference_state:
|
| 349 |
inference_state["device"] = predictor.device
|
| 350 |
|
|
@@ -374,15 +392,16 @@ def propagate_to_all(
|
|
| 374 |
out_mask = video_segments[out_frame_idx][OBJ_ID]
|
| 375 |
mask_image = show_mask(out_mask)
|
| 376 |
output_frame = Image.alpha_composite(transparent_background, mask_image)
|
|
|
|
| 377 |
output_frame = np.array(output_frame)
|
| 378 |
output_frames.append(output_frame)
|
| 379 |
|
| 380 |
-
|
|
|
|
| 381 |
|
| 382 |
# Create a video clip from the image sequence
|
| 383 |
original_fps = get_video_fps(video_in)
|
| 384 |
-
fps = original_fps # Frames per second
|
| 385 |
-
clip = ImageSequenceClip(output_frames, fps=fps)
|
| 386 |
# Write the result to a file
|
| 387 |
unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
|
| 388 |
final_vid_output_path = f"output_video_{unique_id}.mp4"
|
|
@@ -390,8 +409,21 @@ def propagate_to_all(
|
|
| 390 |
tempfile.gettempdir(), final_vid_output_path
|
| 391 |
)
|
| 392 |
|
| 393 |
-
# Write the result to a file
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
|
| 396 |
return gr.update(value=final_vid_output_path)
|
| 397 |
|
|
|
|
| 21 |
import torch
|
| 22 |
|
| 23 |
from moviepy.editor import ImageSequenceClip
|
| 24 |
+
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter # adding this for MPS compatibility
|
| 25 |
from PIL import Image
|
| 26 |
from sam2.build_sam import build_sam2_video_predictor
|
| 27 |
|
|
|
|
| 73 |
OBJ_ID = 0
|
| 74 |
sam2_checkpoint = "checkpoints/edgetam.pt"
|
| 75 |
model_cfg = "edgetam.yaml"
|
| 76 |
+
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu") # MPS support
|
| 77 |
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
|
| 78 |
+
predictor.to(device)
|
| 79 |
|
| 80 |
|
| 81 |
def get_video_fps(video_path):
|
|
|
|
| 343 |
input_points,
|
| 344 |
inference_state,
|
| 345 |
):
|
| 346 |
+
# Set boolean flags for CUDA and MPS support
|
| 347 |
+
use_cuda = torch.cuda.is_available()
|
| 348 |
+
use_mps = torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
| 349 |
+
|
| 350 |
+
if use_cuda:
|
| 351 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
| 352 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 353 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 354 |
+
autocast_kwargs = dict(device_type="cuda", dtype=torch.bfloat16)
|
| 355 |
+
|
| 356 |
+
elif use_mps:
|
| 357 |
+
autocast_kwargs = dict(device_type="mps", dtype=torch.float16)
|
| 358 |
+
|
| 359 |
+
with torch.autocast(**autocast_kwargs):
|
| 360 |
+
|
| 361 |
+
if use_cuda:
|
| 362 |
+
predictor.to("cuda")
|
| 363 |
+
elif use_mps:
|
| 364 |
+
predictor.to("mps")
|
| 365 |
+
|
| 366 |
if inference_state:
|
| 367 |
inference_state["device"] = predictor.device
|
| 368 |
|
|
|
|
| 392 |
out_mask = video_segments[out_frame_idx][OBJ_ID]
|
| 393 |
mask_image = show_mask(out_mask)
|
| 394 |
output_frame = Image.alpha_composite(transparent_background, mask_image)
|
| 395 |
+
output_frame = output_frame.convert("RGB")
|
| 396 |
output_frame = np.array(output_frame)
|
| 397 |
output_frames.append(output_frame)
|
| 398 |
|
| 399 |
+
if use_cuda:
|
| 400 |
+
torch.cuda.empty_cache()
|
| 401 |
|
| 402 |
# Create a video clip from the image sequence
|
| 403 |
original_fps = get_video_fps(video_in)
|
| 404 |
+
fps = float(original_fps) # Frames per second
|
|
|
|
| 405 |
# Write the result to a file
|
| 406 |
unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
|
| 407 |
final_vid_output_path = f"output_video_{unique_id}.mp4"
|
|
|
|
| 409 |
tempfile.gettempdir(), final_vid_output_path
|
| 410 |
)
|
| 411 |
|
| 412 |
+
# Write the result to a file using moviepy ImageSequenceClip
|
| 413 |
+
if use_cuda:
|
| 414 |
+
clip = ImageSequenceClip(output_frames, fps=fps)
|
| 415 |
+
clip.write_videofile(final_vid_output_path, codec="libx264")
|
| 416 |
+
|
| 417 |
+
# Write the result to a file using moviepy FFMPEG_VideoWriter for MPS compatibility
|
| 418 |
+
elif use_mps:
|
| 419 |
+
clip_array = output_frames # list of RGB numpy arrays
|
| 420 |
+
size = clip_array[0].shape[1], clip_array[0].shape[0] # (width, height)
|
| 421 |
+
writer = FFMPEG_VideoWriter(final_vid_output_path, size, fps=fps, codec="libx264")
|
| 422 |
+
|
| 423 |
+
for frame in clip_array:
|
| 424 |
+
writer.write_frame(frame)
|
| 425 |
+
|
| 426 |
+
writer.close()
|
| 427 |
|
| 428 |
return gr.update(value=final_vid_output_path)
|
| 429 |
|