import os
import shutil
import subprocess
import tempfile
from typing import Mapping

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as transforms_F
from einops import rearrange
from PIL import Image

from video_to_video.utils.logger import get_logger

logger = get_logger()
def tensor2vid(video, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Denormalize a (b, c, f, h, w) video tensor into (f, h, w, c) frames in [0, 255]."""
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
    video = video.mul_(std).add_(mean)  # invert the (x - mean) / std normalization
    video.clamp_(0, 1)
    video = video * 255.0
    images = rearrange(video, 'b c f h w -> b f h w c')[0]  # drop the batch dim
    return images
def preprocess(input_frames):
    """Convert a list of BGR uint8 frames into a normalized (f, c, h, w) float tensor."""
    out_frame_list = []
    for frame in input_frames:
        frame = frame[:, :, ::-1]  # BGR (OpenCV) -> RGB
        frame = Image.fromarray(frame.astype('uint8')).convert('RGB')
        frame = transforms_F.to_tensor(frame)  # (c, h, w), values in [0, 1]
        out_frame_list.append(frame)
    out_frames = torch.stack(out_frame_list, dim=0)
    out_frames.clamp_(0, 1)
    # Normalize to [-1, 1]; tensor2vid applies the exact inverse.
    mean = out_frames.new_tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
    std = out_frames.new_tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
    out_frames.sub_(mean).div_(std)
    return out_frames
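
# Note (illustrative): preprocess and tensor2vid are inverses up to float rounding.
# For a list of frames, tensor2vid(preprocess(frames).permute(1, 0, 2, 3).unsqueeze(0))
# recovers the original RGB pixel values as floats in [0, 255].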
def adjust_resolution(h, w, up_scale):
    """Choose an even output resolution: height at least 720, at most 1280*2048 pixels."""
    if h * up_scale < 720:
        # Too small after upscaling: raise the scale so the output height reaches 720.
        up_s = 720 / h
        target_h = int(up_s * h // 2 * 2)
        target_w = int(up_s * w // 2 * 2)
    elif h * w * up_scale * up_scale > 1280 * 2048:
        # Too large after upscaling: shrink the scale to cap the total pixel count.
        up_s = np.sqrt(1280 * 2048 / (h * w))
        target_h = int(up_s * h // 2 * 2)
        target_w = int(up_s * w // 2 * 2)
    else:
        target_h = int(up_scale * h // 2 * 2)
        target_w = int(up_scale * w // 2 * 2)
    # The `// 2 * 2` rounds down to even dimensions, as required by yuv420p encoding.
    return (target_h, target_w)
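
# Worked example (illustrative): adjust_resolution(256, 256, 2) takes the first
# branch (256 * 2 < 720), so up_s = 720 / 256 = 2.8125 and the result is (720, 720).
# A 4x upscale of a 360x640 clip instead hits the pixel cap and yields (1214, 2158).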
def make_mask_cond(in_f_num, interp_f_num):
    """Build a frame-index mask: real frame indices separated by -1 placeholders."""
    mask_cond = []
    interp_cond = [-1 for _ in range(interp_f_num)]
    for i in range(in_f_num):
        mask_cond.append(i)
        if i != in_f_num - 1:  # no placeholders after the last real frame
            mask_cond += interp_cond
    return mask_cond
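
# For example (illustrative): make_mask_cond(2, 3) returns [0, -1, -1, -1, 1],
# i.e. two real frames with three frames to interpolate between them.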
def load_video(vid_path):
    """Read every frame of a video; returns (list of BGR uint8 frames, fps)."""
    capture = cv2.VideoCapture(vid_path)
    _fps = capture.get(cv2.CAP_PROP_FPS)
    _total_frame_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_list = []
    while len(frame_list) < _total_frame_num:
        ret, frame = capture.read()
        if (not ret) or (frame is None):
            break  # metadata frame counts can overestimate; stop at the real end
        frame_list.append(frame)
    capture.release()
    return frame_list, _fps
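
# Example (hypothetical path): frames, fps = load_video('clip.mp4') yields
# OpenCV-style BGR uint8 arrays ready to be passed to preprocess().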
def save_video(video, save_dir, file_name, fps=16.0):
    """Encode a sequence of (h, w, c) RGB frame tensors to an H.264 mp4 via ffmpeg."""
    output_path = os.path.join(save_dir, file_name)
    images = [(img.numpy()).astype('uint8') for img in video]
    temp_dir = tempfile.mkdtemp()
    for fid, frame in enumerate(images):
        tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1))
        cv2.imwrite(tpth, frame[:, :, ::-1])  # RGB -> BGR for OpenCV
    tmp_path = os.path.join(save_dir, 'tmp.mp4')
    # -crf 0 requests lossless H.264; quote paths so spaces do not break the command.
    cmd = (f'ffmpeg -y -f image2 -framerate {fps} -i "{temp_dir}/%06d.png" '
           f'-vcodec libx264 -preset ultrafast -crf 0 -pix_fmt yuv420p "{tmp_path}"')
    status, output = subprocess.getstatusoutput(cmd)
    if status != 0:
        logger.error(f'Save video failed with: {output}')
    shutil.rmtree(temp_dir)
    os.rename(tmp_path, output_path)
def collate_fn(data, device):
    """Prepare the input just before the forward function.

    This method moves the tensors to the right device.
    Usually it does not need to be overridden.

    Args:
        data: The data out of the dataloader.
        device: The device to move data to.

    Returns:
        The processed data.
    """
    from torch.utils.data.dataloader import default_collate

    if isinstance(data, Mapping):  # covers dict and any other mapping type
        return type(data)({
            k: collate_fn(v, device) if k != 'img_metas' else v
            for k, v in data.items()
        })
    elif isinstance(data, (tuple, list)):
        if len(data) == 0:
            return torch.Tensor([])
        if isinstance(data[0], (int, float)):
            return default_collate(data).to(device)
        else:
            return type(data)(collate_fn(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        if data.dtype.type is np.str_:
            return data  # string arrays cannot be moved to a device
        else:
            return collate_fn(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    elif isinstance(data, (bytes, str, int, float, bool, type(None))):
        return data
    else:
        raise ValueError(f'Unsupported data type {type(data)}')
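
# A minimal end-to-end sketch of how these helpers compose (illustrative only;
# 'input.mp4' and the 4x factor are assumptions, not part of the module):
if __name__ == '__main__':
    frames, fps = load_video('input.mp4')  # BGR uint8 frames
    h, w = frames[0].shape[:2]
    # In the real pipeline, the model would upsample to this resolution.
    target_h, target_w = adjust_resolution(h, w, up_scale=4)
    video = preprocess(frames)  # (f, c, h, w), normalized to [-1, 1]
    # A (b, c, f, h, w) model output goes through tensor2vid before saving.
    out = tensor2vid(video.permute(1, 0, 2, 3).unsqueeze(0))
    save_video(out, '.', 'roundtrip.mp4', fps=fps)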