# Hugging Face Spaces page banner captured with this source: "Spaces: Runtime error"
import atexit
import os
import shutil
import subprocess
import tempfile

import av
import gradio as gr
import torch

from convert import convert_video
def get_video_length_av(video_path):
    """Return the duration of *video_path* in seconds.

    Prefers the container-level duration (an integer count of
    ``av.time_base`` units); falls back to the first video stream's own
    duration when the container does not report one.
    """
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        if container.duration is not None:
            duration_in_seconds = float(container.duration) / av.time_base
        else:
            # stream.duration * stream.time_base is a fractions.Fraction;
            # cast so both branches return the same type (float).
            duration_in_seconds = float(stream.duration * stream.time_base)
    return duration_in_seconds
def get_video_dimensions(video_path):
    """Return ``(width, height)`` in pixels of the first video stream."""
    with av.open(video_path) as container:
        first_video = container.streams.video[0]
        dimensions = (first_video.width, first_video.height)
    return dimensions
def get_free_memory_gb():
    """Free memory (total minus allocated) on the current CUDA device, in GiB.

    NOTE(review): "free" here means total minus *allocated-by-this-process*;
    memory held by other processes is not accounted for.
    """
    device = torch.cuda.current_device()
    total = torch.cuda.get_device_properties(device).total_memory
    unused = total - torch.cuda.memory_allocated(device)
    return unused / (1024 ** 3)
def cleanup_temp_directories():
    """atexit hook: remove every directory registered in ``temp_directories``.

    A directory that is already gone is the success case and is ignored;
    any other OS-level failure (permissions, files still in use) is
    reported instead of propagating out of interpreter shutdown.
    """
    print("Deleting temporary files")
    for temp_dir in temp_directories:
        try:
            shutil.rmtree(temp_dir)
        except FileNotFoundError:
            pass  # already removed — nothing to do
        except OSError as e:
            print(f"Could not delete directory {temp_dir}: {e}")
def ffmpeg_remux_audio(source_video_path, dest_video_path, output_path):
    """Remux the audio of *source_video_path* into *dest_video_path*.

    Both streams are stream-copied (no re-encode). Returns *output_path*
    on success; on failure — ffmpeg exits non-zero OR ffmpeg is not
    installed at all — returns *dest_video_path* (the video without the
    remuxed audio) so callers still get a playable file.
    """
    command = [
        "ffmpeg",
        "-i", dest_video_path,    # input 0: the video track to keep
        "-i", source_video_path,  # input 1: the audio donor
        "-c:v", "copy",           # copy the video stream as is
        "-c:a", "copy",           # copy the audio stream as is
        "-map", "0:v:0",          # video stream from input 0
        "-map", "1:a:0",          # audio stream from input 1
        output_path,
    ]
    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # CalledProcessError: ffmpeg ran but failed (e.g. bad streams).
        # FileNotFoundError: ffmpeg is missing from PATH entirely —
        # previously this escaped the handler and crashed the request.
        print(f"An error occurred: {e}")
        return dest_video_path
    return output_path
def inference(video):
    """Gradio handler: matte *video* and return the path to the result.

    Validates length (<= 30 s) and resolution (each side <= 1920 px),
    runs RobustVideoMatting into a fresh temp directory, then remuxes
    the original audio back into the matted output.
    """
    if get_video_length_av(video) > 30:
        raise gr.Error("Length of video cannot be over 30 seconds")
    # NOTE: a tuple comparison `(w, h) > (1920, 1920)` is lexicographic and
    # would let e.g. a 1280x2160 clip through; check each axis explicitly.
    width, height = get_video_dimensions(video)
    if width > 1920 or height > 1920:
        raise gr.Error("Video resolution must not be higher than 1920x1920")
    temp_dir = tempfile.mkdtemp()
    temp_directories.append(temp_dir)  # cleaned up by the atexit hook
    output_composition = temp_dir + "/matted_video.mp4"
    convert_video(
        model,  # The loaded model, can be on any device (cpu or cuda).
        input_source=video,  # A video file or an image sequence directory.
        downsample_ratio=0.25,  # [Optional] If None, make downsampled max size be 512px.
        output_composition=output_composition,  # File path if video; directory path if png sequence.
        output_alpha=None,  # [Optional] Output the raw alpha prediction.
        output_foreground=None,  # [Optional] Output the raw foreground prediction.
        output_video_mbps=4,  # Output video mbps. Not needed for png sequence.
        seq_chunk=12,  # Process n frames at once for better parallelism.
        num_workers=1,  # Only for image sequence input. Reader threads.
        progress=True,  # Print conversion progress.
    )
    # Name the final file after the upload, then restore its audio track.
    resulting_video = f"{temp_dir}/matted_{os.path.split(video)[1]}"
    return ffmpeg_remux_audio(video, output_composition, resulting_video)
if __name__ == "__main__":
    # Per-request temp directories; removed by the atexit hook below.
    temp_directories = []
    atexit.register(cleanup_temp_directories)
    # Pretrained RobustVideoMatting (MobileNetV3 backbone) via torch.hub.
    model = torch.hub.load(
        "PeterL1n/RobustVideoMatting", "mobilenetv3", trust_repo=True
    )
    if torch.cuda.is_available():
        free_memory = get_free_memory_gb()
        # Budget: roughly 7 GB of VRAM per concurrent matting job —
        # TODO confirm this empirical figure for the deployed GPU.
        concurrency_count = int(free_memory // 7)
        print(f"Using GPU with concurrency: {concurrency_count}")
        print(f"Available video memory: {free_memory} GB")
        model = model.cuda()
    else:
        print("Using CPU")
        concurrency_count = 1
    # --- Gradio UI -------------------------------------------------------
    with gr.Blocks(title="Robust Video Matting") as block:
        gr.Markdown("# Robust Video Matting")
        gr.Markdown(
            "Gradio demo for Robust Video Matting. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."
        )
        with gr.Row():
            inp = gr.Video(label="Input Video", sources=["upload"], include_audio=True)
            out = gr.Video(label="Output Video")
        btn = gr.Button("Run")
        btn.click(inference, inputs=inp, outputs=out)
        gr.Examples(
            examples=[["example.mp4"]],
            inputs=[inp],
        )
        gr.HTML(
            "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.11515'>Robust High-Resolution Video Matting with Temporal Guidance</a> | <a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a></p>"
        )
    # Bounded queue keeps GPU memory usage within the concurrency budget.
    block.queue(api_open=False, max_size=5, concurrency_count=concurrency_count).launch(
        share=False
    )