Spaces: Configuration error
| import os | |
| import torch | |
| from argparse import ArgumentParser, Namespace | |
| import json | |
| from typing import Any, Dict, List, Mapping, Tuple | |
| from easydict import EasyDict | |
| from video_to_video.video_to_video_model import VideoToVideo_sr | |
| from video_to_video.utils.seed import setup_seed | |
| from video_to_video.utils.logger import get_logger | |
| from video_super_resolution.color_fix import adain_color_fix | |
| from inference_utils import * | |
# Module-level logger obtained from the project's ``get_logger`` helper; used by
# ``VEnhancer_sr`` to report paths, resolutions and fps while processing a video.
logger = get_logger()
class VEnhancer_sr():
    """Video super-resolution wrapper around ``VideoToVideo_sr``.

    The model is built once in ``__init__``; :meth:`enhance_a_video` then
    upscales one video per call and writes it to
    ``os.path.join(result_dir, file_name)``.
    """

    # Output file name used when the caller passes no (or an empty) file name.
    _DEFAULT_FILE_NAME = '000_video.mp4'

    def __init__(self,
                 result_dir='./results/',
                 file_name=_DEFAULT_FILE_NAME,
                 model_path='./pretrained_weight',
                 solver_mode='fast',
                 steps=15,
                 guide_scale=7.5,
                 upscale=4,
                 max_chunk_len=32,
                 variant_info=None,
                 ):
        """Create the output directory and load the super-resolution model.

        Args:
            result_dir: directory the enhanced video is written to (created if
                missing).
            file_name: output file name inside ``result_dir``. ``None`` or
                ``''`` falls back to ``'000_video.mp4'`` instead of failing
                only after inference has finished.
            model_path: stored as ``model_cfg.model_path`` for ``VideoToVideo_sr``.
            solver_mode: ``'fast'`` forces ``steps`` to 15; any other value
                keeps the given ``steps``. Forwarded to ``self.model.test``.
            steps: number of sampling steps forwarded to ``self.model.test``.
            guide_scale: forwarded to ``self.model.test`` as ``guide_scale``.
            upscale: integer factor applied to input height and width.
            max_chunk_len: forwarded to ``self.model.test``.
            variant_info: stored on the instance; not read anywhere in this class.
        """
        self.model_path = model_path
        logger.info('checkpoint_path: {}'.format(self.model_path))

        self.result_dir = result_dir
        # Guard here: a None name would only surface in os.path.join at the very end.
        self.file_name = file_name if file_name else self._DEFAULT_FILE_NAME
        os.makedirs(self.result_dir, exist_ok=True)

        model_cfg = EasyDict(__name__='model_cfg')
        model_cfg.model_path = self.model_path
        self.model = VideoToVideo_sr(model_cfg)

        if solver_mode == 'fast':
            # The fast solver always runs 15 steps; make the override visible.
            if steps != 15:
                logger.info('solver_mode is fast: ignoring steps={}, using 15'.format(steps))
            steps = 15
        self.solver_mode = solver_mode
        self.steps = steps
        self.guide_scale = guide_scale
        self.upscale = upscale
        self.max_chunk_len = max_chunk_len
        self.variant_info = variant_info

    def enhance_a_video(self, video_path, prompt, *, seed=666,
                        total_noise_levels=900, device='cuda:0'):
        """Super-resolve one video and save it.

        Args:
            video_path: path handed to ``load_video``.
            prompt: text description; ``self.model.positive_prompt`` is appended
                to form the conditioning caption. ``None`` is treated as ``''``.
            seed: value passed to ``setup_seed`` before sampling (presumably for
                reproducibility); defaults to the previously hard-coded 666.
            total_noise_levels: forwarded to ``self.model.test`` (default 900,
                as before).
            device: device string passed to ``collate_fn`` (default ``'cuda:0'``,
                as before).

        Returns:
            ``os.path.join(self.result_dir, self.file_name)``, the saved video path.
        """
        logger.info('input video path: {}'.format(video_path))
        text = '' if prompt is None else prompt
        logger.info('text: {}'.format(text))
        caption = text + self.model.positive_prompt

        input_frames, input_fps = load_video(video_path)
        logger.info('input frames length: {}'.format(len(input_frames)))
        logger.info('input fps: {}'.format(input_fps))

        video_data = preprocess(input_frames)
        # Only the last two of the four dimensions (height, width) are used here.
        _, _, h, w = video_data.shape
        logger.info('input resolution: {}'.format((h, w)))
        target_h, target_w = h * self.upscale, w * self.upscale
        logger.info('target resolution: {}'.format((target_h, target_w)))

        pre_data = {
            'video_data': video_data,
            'y': caption,
            'target_res': (target_h, target_w),
        }

        setup_seed(seed)
        with torch.no_grad():
            data_tensor = collate_fn(pre_data, device)
            output = self.model.test(
                data_tensor,
                total_noise_levels,
                steps=self.steps,
                solver_mode=self.solver_mode,
                guide_scale=self.guide_scale,
                max_chunk_len=self.max_chunk_len,
            )

        output = tensor2vid(output)
        # Color fix: apply adain_color_fix using the (pre-upscale) input frames.
        output = adain_color_fix(output, video_data)

        save_video(output, self.result_dir, self.file_name, fps=input_fps)
        return os.path.join(self.result_dir, self.file_name)
def parse_args(argv=None):
    """Parse the command-line options for the video super-resolution script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` holding the options declared below.
    """
    parser = ArgumentParser()
    parser.add_argument("--input_path", required=True, type=str, help="input video path")
    parser.add_argument("--save_dir", type=str, default='results', help="save directory")
    # Same default as VEnhancer_sr; without one the value is None and building the
    # output path fails only after the whole inference has run.
    parser.add_argument("--file_name", type=str, default='000_video.mp4', help="file name")
    parser.add_argument("--model_path", type=str, default='./pretrained_weight/model.pt', help="model path")
    parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
    parser.add_argument("--upscale", type=int, default=4, help='up-scale')
    parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
    parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')
    parser.add_argument("--cfg", type=float, default=7.5)
    # Let argparse reject unknown modes with a proper usage message.
    parser.add_argument("--solver_mode", type=str, default='fast', choices=('fast', 'normal'),
                        help='fast | normal')
    # Note: VEnhancer_sr forces 15 steps when solver_mode is 'fast'.
    parser.add_argument("--steps", type=int, default=15)
    return parser.parse_args(argv)
def main():
    """Command-line entry point: parse options and super-resolve one video.

    Builds a ``VEnhancer_sr`` from the parsed options and calls
    ``enhance_a_video(args.input_path, args.prompt)`` on it.

    Raises:
        ValueError: if ``--solver_mode`` is neither ``'fast'`` nor ``'normal'``.
    """
    args = parse_args()

    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if args.solver_mode not in ('fast', 'normal'):
        raise ValueError(
            "solver_mode must be 'fast' or 'normal', got {!r}".format(args.solver_mode))

    venhancer_sr = VEnhancer_sr(
        result_dir=args.save_dir,
        file_name=args.file_name,
        model_path=args.model_path,
        solver_mode=args.solver_mode,
        steps=args.steps,
        guide_scale=args.cfg,
        upscale=args.upscale,
        max_chunk_len=args.max_chunk_len,
        # Forward the CLI value rather than silently discarding it.
        variant_info=args.variant_info,
    )
    venhancer_sr.enhance_a_video(args.input_path, args.prompt)


if __name__ == '__main__':
    main()