# this script is modified from https://github.com/MCG-NKU/AMT/blob/main/demos/demo_2x.py
import os
import sys
import argparse
import os.path as osp

import cv2
import torch
import numpy as np

sys.path.append('.')
from utils.utils import img2tensor, tensor2img, InputPadder
from utils.build_utils import build_from_cfg
AMT_G = {
    'name': 'networks.AMT-G.Model',
    'params': {
        'corr_radius': 3,
        'corr_lvls': 4,
        'num_flows': 5,
    }
}
def init(device="cuda"):
    '''
    Initialize the anchor resolution and memory constants used to decide
    how much the input video must be downscaled to fit in VRAM.
    '''
    if device == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
        print("VRAM available: {:.1f} MB".format(vram_avail / 1024 ** 2))
    else:
        # Do not resize in CPU mode
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    return anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail
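
# Rough worked example of the downscaling decision (hypothetical 8 GB card,
# 1920x1080 input; see get_input_video_from_path below):
#   raw scale = (1024*512) / (1920*1080) * sqrt((8192 - 2500) / 1500) ~= 0.49
#   snapped   = 1 / floor(1 / sqrt(0.49) * 16) * 16 = 16 / 22 ~= 0.73
# i.e. frames would be processed at about 0.73x their original side length.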
def get_input_video_from_path(input_path, device="cuda"):
    '''
    Load the input video from input_path.
    params:
        input_path: str, the path of the input video.
        device: str, the device to run the model on.
    returns:
        inputs: list, the list of input frame tensors.
        scale: float, the scale factor applied to the input frames.
        padder: InputPadder, the padder used to pad the input frames.
    '''
    anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail = init(device)
    if osp.splitext(input_path)[-1].lower() in ['.mp4', '.avi', '.mov', '.mkv',
                                                '.flv', '.wmv', '.webm']:
        vcap = cv2.VideoCapture(input_path)
        inputs = []
        w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Estimate how much the frames must be downscaled to fit in VRAM,
        # then snap the scale so the scaled size stays divisible by 16.
        scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
        scale = 1 if scale > 1 else scale
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        if scale < 1:
            print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
        padding = int(16 / scale)
        padder = InputPadder((h, w), padding)
        while True:
            ret, frame = vcap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_t = img2tensor(frame).to(device)
            frame_t = padder.pad(frame_t)
            inputs.append(frame_t)
        vcap.release()
        print(f'Loaded the [video] from {input_path}; number of frames: [{len(inputs)}]')
    else:
        raise TypeError("Input should be a video.")
    return inputs, scale, padder
def load_model(ckpt_path, device="cuda"):
    '''
    Load the frame interpolation model.
    '''
    network_cfg = AMT_G
    network_name = network_cfg['name']
    print(f'Loading [{network_name}] from [{ckpt_path}]...')
    model = build_from_cfg(network_cfg)
    # map_location lets the checkpoint load on CPU-only machines as well.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['state_dict'])
    model = model.to(device)
    model.eval()
    return model
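
# Note: this demo is hard-wired to the AMT-G config above. The upstream AMT
# repository also publishes smaller variants (e.g. AMT-S); loading one of
# those checkpoints would require a matching network config, not AMT_G.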
def interpolater(model, inputs, scale, padder, iters=1):
    '''
    Interpolate frames with the frame interpolation model.
    params:
        model: nn.Module, the frame interpolation model.
        inputs: list, the list of input frames.
        scale: float, the scale factor of the input frames.
        padder: InputPadder, the padder used to pad/unpad the frames.
        iters: int, the number of interpolation iterations. With m input
               frames, the model generates 2 ** iters * (m - 1) + 1 frames in total.
    returns:
        outputs: list, the list of output frames.
    '''
    print('Start frame interpolation:')
    device = next(model.parameters()).device
    # embt = 0.5 asks the model for the temporal midpoint of each frame pair.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            in_0 = in_0.to(device)
            in_1 = in_1.to(device)
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            outputs += [imgt_pred.cpu(), in_1.cpu()]
        inputs = outputs
    outputs = padder.unpad(*outputs)
    return outputs
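
# Frame-count sanity check: each iteration doubles the gaps between frames,
# so with m = 30 input frames and iters = 2 the output contains
# 2 ** 2 * (30 - 1) + 1 = 117 frames.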
def write(outputs, input_path, output_path, frame_rate=30):
    '''
    Write the interpolated frames to a video file under output_path.
    '''
    os.makedirs(output_path, exist_ok=True)
    # cv2.VideoWriter expects (width, height), so reverse the (H, W) shape.
    size = outputs[0].shape[2:][::-1]
    _, file_name_with_extension = os.path.split(input_path)
    file_name, _ = os.path.splitext(file_name_with_extension)
    save_video_path = f'{output_path}/output_{file_name}.mp4'
    writer = cv2.VideoWriter(save_video_path, cv2.VideoWriter_fourcc(*"mp4v"),
                             frame_rate, size)
    for imgt_pred in outputs:
        imgt_pred = tensor2img(imgt_pred)
        imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
        writer.write(imgt_pred)
    writer.release()
    print(f"Demo video is saved to [{save_video_path}]")
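
# "mp4v" (MPEG-4 Part 2) is used because it ships with most OpenCV builds;
# where an H.264 encoder is available, fourcc "avc1" would compress better
# at comparable visual quality.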
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, default='amt-g.pth', help="The pretrained model.")
    parser.add_argument('--niters', type=int, default=1, help="Number of interpolation iterations; the frame count roughly doubles per iteration.")
    parser.add_argument('--input', default="test.mp4", help="Input video.")
    parser.add_argument('--output_path', type=str, default='results', help="Output path.")
    parser.add_argument('--frame_rate', type=int, default=30, help="Frame rate of the output video.")
    args = parser.parse_args()

    # Use a plain string so the `device == 'cuda'` comparison in init()
    # behaves as expected.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ckpt_path = args.ckpt
    input_path = args.input
    output_path = args.output_path
    iters = int(args.niters)
    frame_rate = int(args.frame_rate)

    inputs, scale, padder = get_input_video_from_path(input_path, device)
    model = load_model(ckpt_path, device)
    outputs = interpolater(model, inputs, scale, padder, iters)
    write(outputs, input_path, output_path, frame_rate)
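
# Example invocation (hypothetical file/script names):
#   python demo_video.py --ckpt amt-g.pth --input test.mp4 --niters 2 --frame_rate 60
# With --niters 2, each gap between consecutive frames is subdivided twice,
# yielding roughly 4x as many frames as the input clip.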