import sys
# sys.path.append("./")
import spaces
import os
import json
import time
import inspect
import functools
import psutil
import argparse

import cv2
import torch
import torchvision
import numpy as np
import gradio as gr

from tools.painter import mask_painter
from track_anything import TrackingAnything

from utils.misc import get_device
from utils.download_util import load_file_from_url

from transformers import AutoTokenizer
from omegaconf import OmegaConf
from torchvision.transforms import functional as TF
from torchvision.utils import save_image
from einops import rearrange
from PIL import Image
from huggingface_hub import snapshot_download

from rose.models import AutoencoderKLWan, CLIPModel, WanT5EncoderModel, WanTransformer3DModel
from rose.pipeline import WanFunInpaintPipeline
from diffusers import FlowMatchEulerDiscreteScheduler


def filter_kwargs(cls, kwargs):
    """Return the subset of *kwargs* whose keys are parameters of ``cls.__init__``.

    ``self`` and ``cls`` are never forwarded; every other unknown key is dropped.
    """
    valid_params = set(inspect.signature(cls.__init__).parameters) - {'self', 'cls'}
    return {k: v for k, v in kwargs.items() if k in valid_params}


@functools.lru_cache(maxsize=None)
def _snapshot_repo(repo_id, local_dir):
    """Snapshot a whole HF model repo into *local_dir*; memoised per (repo, dir)."""
    return snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )


def download_component_subfolder(repo_id, subfolder, local_dir=None):
    """Download *repo_id* locally and return the path of *subfolder* inside it.

    Args:
        repo_id: Hugging Face model repo id, e.g. "alibaba-pai/Wan2.1-Fun-1.3B-InP".
        subfolder: file or folder path relative to the repo root.
        local_dir: download target. Defaults to ``ckpt/<repo name>`` so that two
            different repos are never unpacked into (and collide in) one folder.

    Returns:
        ``os.path.join(<local snapshot dir>, subfolder)``; existence is not checked.
    """
    if local_dir is None:
        local_dir = os.path.join("ckpt", repo_id.split("/")[-1])
    # The whole repo is fetched; several components share one repo, so the
    # snapshot call is cached instead of being repeated per component.
    return os.path.join(_snapshot_repo(repo_id, local_dir), subfolder)


pretrained_model_path = "alibaba-pai/Wan2.1-Fun-1.3B-InP"
transformer_path = "Kunbyte/ROSE"
config_path = "configs/wan2.1/wan_civitai.yaml"

config = OmegaConf.load(config_path)

# Base Wan2.1-Fun components come from `pretrained_model_path`; only the
# transformer comes from the ROSE repo. Sub-paths fall back to conventional
# folder names when the config does not set them.
text_encoder_path = download_component_subfolder(pretrained_model_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder'))
tokenizer_path = download_component_subfolder(pretrained_model_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer'))
image_encoder_path = download_component_subfolder(pretrained_model_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'))
vae_path = download_component_subfolder(pretrained_model_path, config['vae_kwargs'].get('vae_subpath', 'vae'))
transformer_subpath = config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')
transformer_path = download_component_subfolder(transformer_path, transformer_subpath)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
text_encoder = WanT5EncoderModel.from_pretrained(
    text_encoder_path,
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=False,
    torch_dtype=torch.bfloat16
)
clip_image_encoder = CLIPModel.from_pretrained(image_encoder_path)
vae = AutoencoderKLWan.from_pretrained(
    vae_path,
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
)
transformer3d = WanTransformer3DModel.from_pretrained(
    transformer_path,
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
)
noise_scheduler = FlowMatchEulerDiscreteScheduler(
    **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)

# NOTE(review): the whole pipeline, including the text encoder loaded in
# bfloat16 above, is cast to float16 on CUDA here.
pipeline = WanFunInpaintPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    transformer=transformer3d,
    scheduler=noise_scheduler,
    clip_image_encoder=clip_image_encoder
).to("cuda", torch.float16)


def _str2bool(value):
    """argparse ``type=`` helper turning "true"/"false"-style strings into a bool."""
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def parse_augment():
    """Parse the demo's command-line options.

    ``--device`` falls back to ``get_device()`` when not given; ``--mask_save``
    accepts boolean-like strings (true/false, 1/0, yes/no, ...).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--sam_model_type', type=str, default="vit_h")
    parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
    # Without a converter, "--mask_save False" produced the truthy string "False".
    parser.add_argument('--mask_save', type=_str2bool, default=False)
    args = parser.parse_args()

    if not args.device:
        args.device = str(get_device())

    return args


# convert points input to prompt state
def get_prompt(click_state, click_input):
    """Append JSON-encoded clicks to *click_state* and build a SAM click prompt.

    Args:
        click_state: ``[points, labels]``; both lists are extended in place.
        click_input: JSON string of ``[[x, y, label], ...]`` entries.

    Returns:
        Prompt dict holding all accumulated points and labels.
    """
    points, labels = click_state[0], click_state[1]
    for click in json.loads(click_input):
        points.append(click[:2])
        labels.append(click[2])
    return {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": "True",
    }


# extract frames from upload video
@spaces.GPU
def get_frames_from_video(video_input, video_state):
    """Decode the uploaded video into RGB frames and build a fresh video state.

    Args:
        video_input: path of the uploaded video file.
        video_state: previous state; ignored, a new state is created.

    Returns:
        19 Gradio outputs: video_state, info text, first frame, the two
        frame-slider updates, 11 visibility updates, the mask-dropdown update
        and two operation-log updates.

    Raises:
        gr.Error: if no frame could be decoded.
    """
    video_path = video_input
    frames = []
    user_name = time.time()
    fps = 0.0
    operation_log = [("[Must Do]", "Click image"), (": Video uploaded! Try to click the image shown in step2 to add masks.\n", None)]
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            current_memory_usage = psutil.virtual_memory().percent
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if current_memory_usage > 90:
                operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
                print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
                break
    except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
        print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
    finally:
        # Release the decoder handle even when reading stops early or fails.
        if cap is not None:
            cap.release()

    if not frames:
        raise gr.Error("Could not read any frame from the uploaded video.")
    # Some containers report 0 or NaN fps; fall back to the app default of 30.
    if not (fps and fps > 0):
        fps = 30.0

    height, width = frames[0].shape[0], frames[0].shape[1]
    image_size = (height, width)
    # initialize video_state
    # NOTE: every "masks" entry is the SAME zero array object; code must replace
    # entries rather than modify them in place.
    video_state = {
        "user_name": user_name,
        "video_name": os.path.split(video_path)[-1],
        "origin_images": frames,
        "painted_images": frames.copy(),
        "masks": [np.zeros((height, width), np.uint8)] * len(frames),
        "logits": [None] * len(frames),
        "select_frame_number": 0,
        "fps": fps,
    }
    video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
    return (
        video_state,
        video_info,
        video_state["origin_images"][0],
        gr.update(visible=True, maximum=len(frames), value=1),
        gr.update(visible=True, maximum=len(frames), value=len(frames)),
        *[gr.update(visible=True) for _ in range(11)],
        gr.update(visible=True, choices=[], value=[]),
        gr.update(visible=True, value=operation_log),
        gr.update(visible=True, value=operation_log),
    )


# get the select frame from gradio slider
def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown):
    """Make the frame chosen on the 1-based slider the SAM template frame."""
    frame_index = image_selection_slider - 1
    video_state["select_frame_number"] = frame_index

    # once select a new template frame, set the image in sam
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][frame_index])

    operation_log = [("",""), ("Select tracking start frame {}. Try to click the image to add masks for tracking.".format(frame_index),"Normal")]
    return video_state["painted_images"][frame_index], video_state, interactive_state, operation_log, operation_log


# set the tracking end frame
def get_end_number(track_pause_number_slider, video_state, interactive_state):
    """Store the 1-based tracking end frame and preview the last tracked frame.

    ``track_end_number`` is later used as an exclusive slice end, so a slider
    value of k tracks frames up to index k - 1; that frame is displayed.
    """
    interactive_state["track_end_number"] = track_pause_number_slider
    operation_log = [("",""),("Select tracking finish frame {}.Try to click the image to add masks for tracking.".format(track_pause_number_slider),"Normal")]
    # The slider runs 1..len(frames) and defaults to its maximum; indexing with
    # the raw value raised IndexError there.
    preview_index = max(int(track_pause_number_slider) - 1, 0)
    return video_state["painted_images"][preview_index], interactive_state, operation_log, operation_log


# use sam to get the mask
@spaces.GPU
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
    """Add one click on the selected frame and re-run SAM on that frame.

    Args:
        video_state: per-video state dict; mask/logit/painted image of the
            selected frame are replaced.
        point_prompt: "Positive" for a foreground click, anything else is background.
        click_state: ``[[points], [labels]]``, extended in place via ``get_prompt``.
        interactive_state: UI counters; the matching click counter is incremented.
        evt: Gradio select event; ``evt.index[0]``/``evt.index[1]`` give the click point.
    """
    if point_prompt == "Positive":
        coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
        interactive_state["positive_click_times"] += 1
    else:
        coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
        interactive_state["negative_click_times"] += 1

    frame_index = video_state["select_frame_number"]
    template_frame = video_state["origin_images"][frame_index]

    # prompt for sam model
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(template_frame)
    prompt = get_prompt(click_state=click_state, click_input=coordinate)

    mask, logit, painted_image = model.first_frame_click(
        image=template_frame,
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )
    video_state["masks"][frame_index] = mask
    video_state["logits"][frame_index] = logit
    video_state["painted_images"][frame_index] = painted_image

    operation_log = [("[Must Do]", "Add mask"), (": add the current displayed mask for video segmentation.\n", None),
                     ("[Optional]", "Remove mask"), (": remove all added masks.\n", None),
                     ("[Optional]", "Clear clicks"), (": clear current displayed mask.\n", None),
                     ("[Optional]", "Click image"), (": Try to click the image shown in step2 if you want to generate more masks.\n", None)]
    return painted_image, video_state, interactive_state, operation_log, operation_log


@spaces.GPU
def add_multi_mask(video_state, interactive_state, mask_dropdown):
    """Save the selected frame's mask as a new named mask ("mask_NNN") and show it.

    On failure the image output is left unchanged and an error log is returned.
    """
    # No-op image update for the error path (previously `select_frame` was
    # unbound there, so the handler crashed instead of reporting the error).
    select_frame = gr.update()
    try:
        mask = video_state["masks"][video_state["select_frame_number"]]
        interactive_state["multi_mask"]["masks"].append(mask)
        mask_name = "mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))
        interactive_state["multi_mask"]["mask_names"].append(mask_name)
        mask_dropdown.append(mask_name)
        select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown)
        operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
    except (TypeError, KeyError, IndexError, AttributeError, ValueError) as e:
        print("add_multi_mask error: {}".format(e))
        operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")]
    return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log


def clear_click(video_state, click_state):
    """Drop all clicks and show the un-painted template frame again."""
    template_frame = video_state["origin_images"][video_state["select_frame_number"]]
    operation_log = [("",""), ("Cleared points history and refresh the image.","Normal")]
    return template_frame, [[],[]], operation_log, operation_log


def remove_multi_mask(interactive_state, mask_dropdown):
    """Forget every saved mask and empty the mask dropdown."""
    interactive_state["multi_mask"]["mask_names"] = []
    interactive_state["multi_mask"]["masks"] = []
    operation_log = [("",""), ("Remove all masks. Try to add new masks","Normal")]
    return interactive_state, gr.update(choices=[],value=[]), operation_log, operation_log


def _mask_id(mask_name):
    """Return the 1-based id encoded in a mask name such as "mask_003"."""
    return int(mask_name.split("_")[1])


@spaces.GPU
def show_mask(video_state, interactive_state, mask_dropdown):
    """Paint the selected named masks over the template frame.

    *mask_dropdown* is sorted in place.
    """
    mask_dropdown.sort()
    select_frame = video_state["origin_images"][video_state["select_frame_number"]]

    for mask_name in mask_dropdown:
        mask_number = _mask_id(mask_name) - 1
        mask = interactive_state["multi_mask"]["masks"][mask_number]
        select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)

    operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")]
    return select_frame, operation_log, operation_log


# tracking vos
@spaces.GPU
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
    """Propagate the template mask through the video and render the result.

    Frames from the selected frame up to (exclusive) ``track_end_number`` -- or
    to the end when it is unset -- are tracked; their masks, logits and painted
    images in *video_state* are replaced, and the painted video is written to
    ``./result/track/<video_name>``.
    """
    operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
    model.cutie.clear_memory()

    start = video_state["select_frame_number"]
    end = interactive_state["track_end_number"]
    # A falsy end (None/0) means "track to the last frame".
    track_slice = slice(start, end) if end else slice(start, None)
    following_frames = video_state["origin_images"][track_slice]

    multi_masks = interactive_state["multi_mask"]["masks"]
    if multi_masks:
        if not mask_dropdown:
            mask_dropdown = ["mask_001"]
        mask_dropdown.sort()
        # Merge the chosen masks into one label map where "mask_k" has label k;
        # where masks overlap the higher-numbered mask wins (clip to current id).
        first_id = _mask_id(mask_dropdown[0])
        template_mask = multi_masks[first_id - 1] * first_id
        for mask_name in mask_dropdown[1:]:
            mask_id = _mask_id(mask_name)
            template_mask = np.clip(template_mask + multi_masks[mask_id - 1] * mask_id, 0, mask_id)
        video_state["masks"][start] = template_mask
    else:
        template_mask = video_state["masks"][start]
    fps = float(video_state["fps"])

    # operation error: an empty mask cannot be tracked
    if len(np.unique(template_mask)) == 1:
        # Copy first: the initial per-frame masks share one zero array, so
        # marking a pixel in place would leak into every untracked frame.
        template_mask = np.array(template_mask, copy=True)
        template_mask[0][0] = 1
        operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]

    masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
    # clear GPU memory
    model.cutie.clear_memory()

    video_state["masks"][track_slice] = masks
    video_state["logits"][track_slice] = logits
    video_state["painted_images"][track_slice] = painted_images

    # import video_input to name the output video
    video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps)
    interactive_state["inference_times"] += 1

    print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(
        interactive_state["inference_times"],
        interactive_state["positive_click_times"] + interactive_state["negative_click_times"],
        interactive_state["positive_click_times"],
        interactive_state["negative_click_times"]))

    #### shanggao code for mask save
    if interactive_state["mask_save"]:
        mask_dir = './result/mask/{}'.format(video_state["video_name"].split('.')[0])
        os.makedirs(mask_dir, exist_ok=True)
        print("save mask")
        for i, mask in enumerate(video_state["masks"]):
            np.save(os.path.join(mask_dir, '{:05d}.npy'.format(i)), mask)
    #### shanggao code for mask save
    return video_output, video_state, interactive_state, operation_log, operation_log


@spaces.GPU(duration=180)
def inpaint_video(video_state, *_):
    """Remove the masked objects with the ROSE/Wan inpainting pipeline.

    The clip is truncated to the largest 16 * k + 1 frame count not exceeding
    its length, resized to 720x480, inpainted with 50 steps and written to
    ``./result/inpaint/<video_name>``.

    Returns:
        (output_path, operation_log, operation_log)

    Raises:
        gr.Error: if no video/masks are loaded.
    """
    operation_log = [("", ""), ("Inpainting finished!", "Normal")]

    frames = video_state["origin_images"]
    masks = video_state["masks"]
    if not frames or masks is None:
        raise gr.Error("Please upload a video and run tracking before inpainting.")
    fps = int(video_state["fps"])

    total_frames = len(frames)
    target_frame_count = (total_frames - 1) // 16 * 16 + 1
    frames = frames[:target_frame_count]
    masks = masks[:target_frame_count]

    frames_resized = [cv2.resize(frame, (720, 480), interpolation=cv2.INTER_CUBIC) for frame in frames]
    # Masks can be bool or label maps with ids > 1 (merged multi-masks).
    # Binarise first and resize with nearest-neighbour so labels are not
    # blended; the old `m * 255` on uint8 labels wrapped (2 * 255 -> 254).
    masks_resized = [
        cv2.resize((np.asarray(mask) > 0).astype(np.uint8), (720, 480), interpolation=cv2.INTER_NEAREST)
        for mask in masks
    ]

    with torch.no_grad():
        # Both tensors are stacked along dim 1 -> (1, C, T, 480, 720).
        video_tensor = torch.stack([TF.to_tensor(Image.fromarray(f)) for f in frames_resized], dim=1).unsqueeze(0).to("cuda", torch.float16)
        mask_tensor = torch.stack([TF.to_tensor(Image.fromarray(m * 255)) for m in masks_resized], dim=1).unsqueeze(0).to("cuda", torch.float16)

        output = pipeline(
            prompt="",
            video=video_tensor,
            mask_video=mask_tensor,
            num_frames=video_tensor.shape[2],
            num_inference_steps=50
        ).videos

    output = output.clamp(0, 1).cpu()
    # assumes output is (B, C, T, H, W) like the inputs -> (T, H, W, C) uint8
    output_np = (output[0].permute(1, 2, 3, 0).numpy() * 255).astype(np.uint8)
    output_path = f"./result/inpaint/{video_state['video_name']}"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    torchvision.io.write_video(output_path, torch.from_numpy(output_np), fps=fps, video_codec="libx264")

    return output_path, operation_log, operation_log


# generate video after vos inference
@spaces.GPU
def generate_video_from_frames(frames, output_path, fps=30):
    """
    Generates a video from a list of frames.

    Args:
        frames (list of numpy arrays): The frames to include in the video.
        output_path (str): The path to save the generated video.
        fps (int, optional): The frame rate of the output video. Defaults to 30.

    Returns:
        str: *output_path*.
    """
    frames = torch.from_numpy(np.asarray(frames))
    output_dir = os.path.dirname(output_path)
    # dirname is "" for a bare file name, and os.makedirs("") raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torchvision.io.write_video(output_path, frames, fps=int(fps), video_codec="libx264")
    return output_path


@spaces.GPU
def restart():
    """Reset all session state and hide the step-2/step-3 widgets.

    Returns 24 Gradio outputs: fresh video_state, fresh interactive_state, an
    empty click state, three ``None`` values, 15 hide updates, an empty string
    and two operation-log updates.
    """
    operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")]
    fresh_video_state = {
        "user_name": "",
        "video_name": "",
        "origin_images": None,
        "painted_images": None,
        "masks": None,
        "inpaint_masks": None,
        "logits": None,
        "select_frame_number": 0,
        "fps": 30
    }
    fresh_interactive_state = {
        "inference_times": 0,
        "negative_click_times" : 0,
        "positive_click_times": 0,
        "mask_save": args.mask_save,
        "multi_mask": {
            "mask_names": [],
            "masks": []
        },
        "track_end_number": None,
    }
    return (
        fresh_video_state,
        fresh_interactive_state,
        [[],[]],
        None, None, None,
        *[gr.update(visible=False) for _ in range(15)],
        "",
        gr.update(visible=True, value=operation_log),
        gr.update(visible=False, value=operation_log),
    )


# command-line arguments, parsed by parse_augment() above
args = parse_augment()

pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
sam_checkpoint_url_dict = {
    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
checkpoint_fodler = os.path.join('.', 'weights')
# Fetch the SAM checkpoint for the backbone chosen by --sam_model_type and the
# Cutie checkpoint from the ProPainter v0.1.0 release into checkpoint_fodler
# (./weights). NOTE(review): the returned values are passed straight to
# TrackingAnything -- presumably local file paths, and presumably the download
# is skipped when the file already exists; confirm in utils/download_util.py.
sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_fodler)
cutie_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'cutie-base-mega.pth'), checkpoint_fodler)

# initialize the SAM + Cutie tracking model (the Wan/ROSE inpainting pipeline
# is constructed separately, earlier in this file); also receives the parsed args
model = TrackingAnything(sam_checkpoint, cutie_checkpoint, args)

title = r"""