import glob
import os
import shutil

import numpy as np
import librosa
import soundfile as sf
from moviepy.editor import *

import torch
import torchvision.transforms as transforms
from PIL import Image

import models
from utils import torch_utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vision_transform_list = [
    transforms.Resize((128, 128)),
    transforms.CenterCrop((112, 112)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
video_transform = transforms.Compose(vision_transform_list)
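
# Load a list of frame image paths into a single video tensor, applying
# resize, center crop, and ImageNet-statistics normalization.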
def read_image(frame_list):
    imgs = []
    convert_tensor = transforms.ToTensor()
    for img_path in frame_list:
        image = Image.open(img_path).convert('RGB')
        image = convert_tensor(image)
        imgs.append(image.unsqueeze(0))
    # (T, C, H, W)
    imgs = torch.cat(imgs, dim=0).squeeze()
    imgs = video_transform(imgs)
    imgs = imgs.permute(1, 0, 2, 3)
    # (C, T, H, W)
    return imgs
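
# Dump a video to JPEG frames with ffmpeg (360 px tall, 15 fps) and load the
# first 2 seconds (30 frames) as a single tensor.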
def get_video_frames(origin_video_path):
    save_path = 'results/temp'
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path)
    command = f'ffmpeg -v quiet -y -i "{origin_video_path}" -f image2 -vf "scale=-1:360,fps=15" -qscale:v 3 "{save_path}/frame%06d.jpg"'
    os.system(command)
    frame_list = glob.glob(f'{save_path}/*.jpg')
    frame_list.sort()
    frame_list = frame_list[:2 * 15]  # first 2 seconds at 15 fps
    frames = read_image(frame_list)
    return frames
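
# Convert per-frame onset probabilities into a sparse list of onset frame
# indices via greedy peak picking: repeatedly take the global argmax, zero
# out its +/- `nearest` neighborhood, and stop once nothing exceeds `thres`.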
def postprocess_video_onsets(probs, thres=0.5, nearest=5):
    video_onsets = []
    pred = np.array(probs, copy=True)
    while True:
        max_ind = np.argmax(pred)
        video_onsets.append(max_ind)
        low = max(max_ind - nearest, 0)
        high = min(max_ind + nearest, pred.shape[0])
        pred[low: high] = 0
        if (pred > thres).sum() == 0:
            break
    video_onsets.sort()
    video_onsets = np.array(video_onsets)
    return video_onsets
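
# Onset detection on raw audio, returning onset positions in samples.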
def detect_onset_of_audio(audio, sample_rate):
    onsets = librosa.onset.onset_detect(
        y=audio, sr=sample_rate, units='samples', delta=0.3)
    return onsets
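
# For onset i, compute how many samples before/after it can be copied without
# overlapping the neighboring onsets or running past the clip boundaries.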
def get_onset_audio_range(audio_len, onsets, i):
    if i == 0:
        prev_offset = int(onsets[i] // 3)
    else:
        prev_offset = int((onsets[i] - onsets[i - 1]) // 3)
    if i == onsets.shape[0] - 1:
        post_offset = int((audio_len - onsets[i]) // 4 * 2)
    else:
        post_offset = int((onsets[i + 1] - onsets[i]) // 4 * 2)
    return prev_offset, post_offset
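
# Build a 2-second soundtrack for the predicted video onsets by pasting, at
# each onset, a randomly chosen onset segment from the conditional audio.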
def generate_audio(con_videoclip, video_onsets):
    np.random.seed(2022)
    con_audioclip = con_videoclip.audio
    con_audio, con_sr = con_audioclip.to_soundarray(), con_audioclip.fps
    con_audio = con_audio.mean(-1)  # average channels to mono
    target_sr = 22050
    if target_sr != con_sr:
        con_audio = librosa.resample(con_audio, orig_sr=con_sr, target_sr=target_sr)
        con_sr = target_sr
    con_onsets = detect_onset_of_audio(con_audio, con_sr)
    gen_audio = np.zeros(int(2 * con_sr))
    for i in range(video_onsets.shape[0]):
        prev_offset, post_offset = get_onset_audio_range(
            int(con_sr * 2), video_onsets, i)
        j = np.random.choice(con_onsets.shape[0])
        prev_offset_con, post_offset_con = get_onset_audio_range(
            con_audio.shape[0], con_onsets, j)
        prev_offset = min(prev_offset, prev_offset_con)
        post_offset = min(post_offset, post_offset_con)
        gen_audio[video_onsets[i] - prev_offset: video_onsets[i] + post_offset] = \
            con_audio[con_onsets[j] - prev_offset: con_onsets[j] + post_offset]
    return gen_audio
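
# Run the onset network on each original video, then, for every conditional
# video, synthesize audio at the predicted onsets and mux it back onto the
# original clip.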
def generate_video(net, original_video_list, cond_video_lists):
    save_folder = 'results/onset_baseline_cxav/vis4'
    os.makedirs(save_folder, exist_ok=True)
    origin_video_folder = os.path.join(save_folder, '0_original')
    os.makedirs(origin_video_folder, exist_ok=True)
    for i in range(len(original_video_list)):
        shutil.copy(original_video_list[i], os.path.join(
            origin_video_folder, cond_video_lists[0][i].split('/')[-1]))
        ori_videoclip = VideoFileClip(original_video_list[i])
        frames = get_video_frames(original_video_list[i])
        inputs = {
            'frames': frames.unsqueeze(0).to(device)
        }
        pred = net(inputs).squeeze()
        pred = torch.sigmoid(pred).data.cpu().numpy()
        video_onsets = postprocess_video_onsets(pred, thres=0.5, nearest=4)
        # convert onset indices from 15 fps frames to 22050 Hz samples
        video_onsets = (video_onsets / 15 * 22050).astype(int)
        for ind, cond_video_list in enumerate(cond_video_lists):
            cond_video = cond_video_list[i]
            cond_video_folder = os.path.join(save_folder, f'{ind * 2 + 1}_conditional_{ind}')
            os.makedirs(cond_video_folder, exist_ok=True)
            shutil.copy(cond_video, os.path.join(
                cond_video_folder, cond_video.split('/')[-1]))
            con_videoclip = VideoFileClip(cond_video)
            gen_audio = generate_audio(con_videoclip, video_onsets)
            save_audio_path = 'results/gen_audio.wav'
            sf.write(save_audio_path, gen_audio, 22050)
            gen_audioclip = AudioFileClip(save_audio_path)
            gen_videoclip = ori_videoclip.set_audio(gen_audioclip)
            save_gen_folder = os.path.join(save_folder, f'{ind * 2 + 2}_generate_{ind}')
            os.makedirs(save_gen_folder, exist_ok=True)
            gen_videoclip.write_videofile(os.path.join(save_gen_folder, cond_video.split('/')[-1]))
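
# Entry point: restore the trained onset detector, pair each conditional clip
# with its original, and run generation.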
if __name__ == '__main__':
    net = models.VideoOnsetNet(pretrained=False).to(device)
    resume = 'checkpoints/cxav_train/checkpoint_ep100.pth.tar'
    net, _ = torch_utils.load_model(resume, net, device=device, strict=True)

    read_folder = ''  # path to a directory generated with `audio_generation.py`
    cond_video_list_0 = glob.glob(f'{read_folder}/2sec_full_cond_video_0/*.mp4')
    cond_video_list_0.sort()
    # map each conditional file name back to its original video path
    original_video_list = ['_to_'.join(v.replace('2sec_full_cond_video_0', '2sec_full_orig_video').split('_to_')[:2]) + '.mp4'
                           for v in cond_video_list_0]
    assert len(original_video_list) == len(cond_video_list_0)
    generate_video(net, original_video_list, [cond_video_list_0])
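
# To run: point `read_folder` at the output directory produced by
# `audio_generation.py`, then execute this script directly; results are
# written under results/onset_baseline_cxav/vis4.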