# Spaces:
# Runtime error
# Runtime error
import argparse
import gc
import os
import time

import gradio as gr
import torch
# Download Weights
from huggingface_hub import snapshot_download

from diffueraser.diffueraser import DiffuEraser
from propainter.inference import Propainter, get_device
# Hugging Face repositories to mirror locally, keyed by their folder name under
# "weights/". The folder names must match the paths `infer` loads from.
# Insertion order is the download order.
_WEIGHT_REPOS = {
    "diffuEraser": "lixiaowen/diffuEraser",
    "stable-diffusion-v1-5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "PCM_Weights": "wangfuyun/PCM_Weights",
    "propainter": "camenduru/ProPainter",
    "sd-vae-ft-mse": "stabilityai/sd-vae-ft-mse",
}

# List of subdirectories to create inside "weights"
subfolders = list(_WEIGHT_REPOS)

# Create each subdirectory and download its repository snapshot into it.
# (Previously the directories were created under a misspelled "weigths" folder
# that nothing ever read from.)
for subfolder in subfolders:
    local_dir = os.path.join("weights", subfolder)
    os.makedirs(local_dir, exist_ok=True)
    snapshot_download(
        repo_id=_WEIGHT_REPOS[subfolder],
        local_dir=local_dir,
    )
# ─────────────────────
def infer(input_video, input_mask):
    """Inpaint the masked region of a video with ProPainter + DiffuEraser.

    ProPainter first writes a prior result to ``results/priori.mp4``. DiffuEraser
    then reads that prior and writes the final video to
    ``results/diffueraser_result.mp4``. Both models are loaded for this call only
    and released (including the CUDA cache) before returning, even on failure.

    Args:
        input_video: Source video as provided by the Gradio ``gr.Video`` input
            (presumably a filesystem path; the UI asks for MP4).
        input_mask: Mask video as provided by the Gradio ``gr.Video`` input.

    Returns:
        Path of the DiffuEraser output video.

    Raises:
        gr.Error: If either the input video or the mask video is missing.
    """
    # Gradio passes None when Submit is clicked without an upload; report that
    # clearly instead of failing deep inside the model code.
    if not input_video or not input_mask:
        raise gr.Error("Please upload both an input video and a mask video.")

    video_length = 10  # The maximum length of output video
    mask_dilation_iter = 8  # Adjust it to change the degree of mask expansion
    max_img_size = 960  # The maximum length of output width and height
    save_path = "results"  # Path to the output
    ref_stride = 10
    neighbor_length = 10
    subvideo_length = 50
    base_model_path = "weights/stable-diffusion-v1-5"
    vae_path = "weights/sd-vae-ft-mse"
    diffueraser_path = "weights/diffuEraser"
    propainter_model_dir = "weights/propainter"

    # exist_ok avoids the check-then-create race of an explicit exists() test.
    os.makedirs(save_path, exist_ok=True)
    priori_path = os.path.join(save_path, "priori.mp4")
    output_path = os.path.join(save_path, "diffueraser_result.mp4")

    device = get_device()
    # PCM params
    ckpt = "2-Step"

    video_inpainting_sd = None
    propainter = None
    try:
        ## model initialization
        video_inpainting_sd = DiffuEraser(
            device, base_model_path, vae_path, diffueraser_path, ckpt=ckpt
        )
        propainter = Propainter(propainter_model_dir, device=device)

        start_time = time.perf_counter()

        ## priori
        propainter.forward(
            input_video,
            input_mask,
            priori_path,
            video_length=video_length,
            ref_stride=ref_stride,
            neighbor_length=neighbor_length,
            subvideo_length=subvideo_length,
            mask_dilation=mask_dilation_iter,
        )

        ## diffueraser
        # NOTE(review): None presumably lets DiffuEraser pick its own default
        # (original note: "The default value is 0.") -- confirm in DiffuEraser.
        guidance_scale = None
        video_inpainting_sd.forward(
            input_video,
            input_mask,
            priori_path,
            output_path,
            max_img_size=max_img_size,
            video_length=video_length,
            mask_dilation_iter=mask_dilation_iter,
            guidance_scale=guidance_scale,
        )

        inference_time = time.perf_counter() - start_time
        print(f"DiffuEraser inference time: {inference_time:.4f} s")
    finally:
        # Drop the model references *before* emptying the cache: while a model
        # is still referenced its GPU blocks are in use and cannot be released.
        # Running this in `finally` also frees memory when inference raises.
        video_inpainting_sd = None
        propainter = None
        gc.collect()
        torch.cuda.empty_cache()

    return output_path
# Gradio UI: upload a video plus a mask video, run `infer`, show the result.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# DiffuEraser: A Diffusion Model for Video Inpainting")
        gr.Markdown("DiffuEraser is a diffusion model for video inpainting, which outperforms state-of-the-art model Propainter in both content completeness and temporal consistency while maintaining acceptable efficiency.")
        # Badge links. The ArXiv badge previously pointed at the project page
        # (duplicating the Project badge); it now links to the paper itself.
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/lixiaowen-xw/DiffuEraser">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
            <a href="https://lixiaowen-xw.github.io/DiffuEraser-page">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
            <a href="https://arxiv.org/abs/2501.10018">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/DiffuEraser-demo?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video (MP4 ONLY)")
                input_mask = gr.Video(label="Input Mask Video (MP4 ONLY)")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_result = gr.Video(label="Result")
        # Clicking an example only fills the two inputs; it does not run `infer`.
        gr.Examples(
            examples=[
                ["./examples/example1/video.mp4", "./examples/example1/mask.mp4"],
                ["./examples/example2/video.mp4", "./examples/example2/mask.mp4"],
                ["./examples/example3/video.mp4", "./examples/example3/mask.mp4"],
            ],
            inputs=[input_video, input_mask],
        )

    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=infer,
        inputs=[input_video, input_mask],
        outputs=[video_result],
    )

demo.queue().launch(share=True, show_api=True, show_error=True)