| """ | |
| input: json file with video, audio, motion paths | |
| output: igraph object with nodes containing video, audio, motion, position, velocity, axis_angle, previous, next, frame, fps | |
| preprocess: | |
| 1. assume you have a video for one speaker in folder, listed in | |
| -- video_a.mp4 | |
| -- video_b.mp4 | |
| run process_video.py to extract frames and audio | |
| """ | |
import os
import smplx
import torch
import numpy as np
import cv2
import librosa
import igraph
import json
import subprocess
import utils.rotation_conversions as rc
from moviepy.editor import VideoClip, AudioFileClip, VideoFileClip
from tqdm import tqdm
import imageio
import tempfile
import argparse

def get_motion_reps_tensor(motion_tensor, smplx_model, pose_fps=30, device='cuda'):
    """Convert axis-angle SMPL-X poses (bs, n, 165) into joint positions, velocities, 6D rotations, and angular velocities."""
    bs, n, _ = motion_tensor.shape
    motion_tensor = motion_tensor.float().to(device)
    motion_tensor_reshaped = motion_tensor.reshape(bs * n, 165)
    output = smplx_model(
        betas=torch.zeros(bs * n, 300, device=device),
        transl=torch.zeros(bs * n, 3, device=device),
        expression=torch.zeros(bs * n, 100, device=device),
        jaw_pose=torch.zeros(bs * n, 3, device=device),
        global_orient=torch.zeros(bs * n, 3, device=device),
        body_pose=motion_tensor_reshaped[:, 3:21 * 3 + 3],
        left_hand_pose=motion_tensor_reshaped[:, 25 * 3:40 * 3],
        right_hand_pose=motion_tensor_reshaped[:, 40 * 3:55 * 3],
        return_joints=True,
        leye_pose=torch.zeros(bs * n, 3, device=device),
        reye_pose=torch.zeros(bs * n, 3, device=device),
    )
    joints = output['joints'].reshape(bs, n, 127, 3)[:, :, :55, :]
    dt = 1 / pose_fps
    # finite-difference joint velocities: one-sided at the boundaries, central in the middle
    init_vel = (joints[:, 1:2] - joints[:, 0:1]) / dt
    middle_vel = (joints[:, 2:] - joints[:, :-2]) / (2 * dt)
    final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
    vel = torch.cat([init_vel, middle_vel, final_vel], dim=1)
    position = joints
    rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3))
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6)
    # angular velocity from axis-angle differences, same finite-difference scheme
    init_vel_ang = (motion_tensor[:, 1:2] - motion_tensor[:, 0:1]) / dt
    middle_vel_ang = (motion_tensor[:, 2:] - motion_tensor[:, :-2]) / (2 * dt)
    final_vel_ang = (motion_tensor[:, -1:] - motion_tensor[:, -2:-1]) / dt
    angular_velocity = torch.cat([init_vel_ang, middle_vel_ang, final_vel_ang], dim=1).reshape(bs, n, 55, 3)
    # 15-D per-joint feature: 3 position + 3 velocity + 6 rot6d + 3 angular velocity
    rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15)
    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": motion_tensor,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }

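# Minimal usage sketch for get_motion_reps_tensor (not part of the original script; the zero motion
# tensor is a placeholder, and smplx_model is assumed to be built as in __main__ below):
#
#   motion = torch.zeros(1, 64, 165)  # (batch, frames, 55 joints x 3 axis-angle)
#   reps = get_motion_reps_tensor(motion, smplx_model, pose_fps=30, device='cuda')
#   print(reps['rep15d'].shape)  # expected: torch.Size([1, 64, 825]), i.e. 55 joints x 15 dims
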
def get_motion_reps(motion, smplx_model, pose_fps=30):
    # numpy variant of get_motion_reps_tensor; note that it relies on the module-level `device` set in __main__
    gt_motion_tensor = motion["poses"]
    n = gt_motion_tensor.shape[0]
    bs = 1
    gt_motion_tensor = torch.from_numpy(gt_motion_tensor).float().to(device).unsqueeze(0)
    gt_motion_tensor_reshaped = gt_motion_tensor.reshape(bs * n, 165)
    output = smplx_model(
        betas=torch.zeros(bs * n, 300).to(device),
        transl=torch.zeros(bs * n, 3).to(device),
        expression=torch.zeros(bs * n, 100).to(device),
        jaw_pose=torch.zeros(bs * n, 3).to(device),
        global_orient=torch.zeros(bs * n, 3).to(device),
        body_pose=gt_motion_tensor_reshaped[:, 3:21 * 3 + 3],
        left_hand_pose=gt_motion_tensor_reshaped[:, 25 * 3:40 * 3],
        right_hand_pose=gt_motion_tensor_reshaped[:, 40 * 3:55 * 3],
        return_joints=True,
        leye_pose=torch.zeros(bs * n, 3).to(device),
        reye_pose=torch.zeros(bs * n, 3).to(device),
    )
    joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :]
    dt = 1 / pose_fps
    init_vel = (joints[1:2] - joints[0:1]) / dt
    middle_vel = (joints[2:] - joints[:-2]) / (2 * dt)
    final_vel = (joints[-1:] - joints[-2:-1]) / dt
    vel = np.concatenate([init_vel, middle_vel, final_vel], axis=0)
    position = joints
    rot_matrices = rc.axis_angle_to_matrix(gt_motion_tensor.reshape(1, n, 55, 3))[0]
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()
    init_vel = (motion["poses"][1:2] - motion["poses"][0:1]) / dt
    middle_vel = (motion["poses"][2:] - motion["poses"][:-2]) / (2 * dt)
    final_vel = (motion["poses"][-1:] - motion["poses"][-2:-1]) / dt
    angular_velocity = np.concatenate([init_vel, middle_vel, final_vel], axis=0).reshape(n, 55, 3)
    rep15d = np.concatenate([
        position,
        vel,
        rot6d,
        angular_velocity],
        axis=2
    ).reshape(n, 55 * 15)
    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": motion["poses"],
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
        "trans": motion["trans"],
    }

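# Expected SMPL-X motion file layout (inferred from get_motion_reps and create_graph, stated here as
# an assumption rather than a guarantee): each <video_id>.npz provides "poses" with shape
# (n_frames, 165) and "trans" with shape (n_frames, 3).
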
def create_graph(json_path, smplx_model):
    fps = 30
    data_meta = json.load(open(json_path, "r"))
    graph = igraph.Graph(directed=True)
    global_i = 0
    for data_item in data_meta:
        video_path = os.path.join(data_item['video_path'], data_item['video_id'] + ".mp4")
        # audio_path = os.path.join(data_item['audio_path'], data_item['video_id'] + ".wav")
        motion_path = os.path.join(data_item['motion_path'], data_item['video_id'] + ".npz")
        video_id = data_item.get("video_id", "")
        motion = np.load(motion_path, allow_pickle=True)
        motion_reps = get_motion_reps(motion, smplx_model)
        position = motion_reps['position']
        velocity = motion_reps['velocity']
        trans = motion_reps['trans']
        axis_angle = motion_reps['axis_angle']
        # audio, sr = librosa.load(audio_path, sr=None)
        # audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
        reader = imageio.get_reader(video_path)
        all_frames = []
        for frame in reader:
            all_frames.append(frame)
        video_frames = np.array(all_frames)
        # clip everything to the shorter of video and motion
        min_frames = min(len(video_frames), position.shape[0])
        position = position[:min_frames]
        velocity = velocity[:min_frames]
        video_frames = video_frames[:min_frames]
        # print(min_frames)
        for i in tqdm(range(min_frames)):
            if i == 0:
                previous = -1
                next_node = global_i + 1
            elif i == min_frames - 1:
                previous = global_i - 1
                next_node = -1
            else:
                previous = global_i - 1
                next_node = global_i + 1
            graph.add_vertex(
                idx=global_i,
                name=video_id,
                motion=motion_reps,
                position=position[i],
                velocity=velocity[i],
                axis_angle=axis_angle[i],
                trans=trans[i],
                # audio=audio[],
                video=video_frames[i],
                previous=previous,
                next=next_node,
                frame=i,
                fps=fps,
            )
            global_i += 1
    return graph

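# Edge construction summary (descriptive comment, not original documentation): create_edges adds
# is_continue=1 edges between temporally adjacent frames of the same clip, and is_continue=0
# "transition" edges between frames whose root translation is close and whose per-joint position and
# velocity distances fall below threshold_edges times an adaptive per-node average (computed over up
# to 8 temporal neighbours) for at least 45 of the 55 joints.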
def create_edges(graph, threshold_edges):
    adaptive_length = [-4, -3, -2, -1, 1, 2, 3, 4]
    # print()
    for i, node in enumerate(graph.vs):
        current_position = node['position']
        current_velocity = node['velocity']
        current_trans = node['trans']
        # print(current_position.shape, current_velocity.shape)
        avg_position = np.zeros(current_position.shape[0])
        avg_velocity = np.zeros(current_position.shape[0])
        avg_trans = 0
        count = 0
        for node_offset in adaptive_length:
            idx = i + node_offset
            if idx < 0 or idx >= len(graph.vs):
                continue
            if node_offset < 0:
                if graph.vs[idx]['next'] == -1:
                    continue
            else:
                if graph.vs[idx]['previous'] == -1:
                    continue
            # add check
            other_node = graph.vs[idx]
            other_position = other_node['position']
            other_velocity = other_node['velocity']
            other_trans = other_node['trans']
            # print(other_position.shape, other_velocity.shape)
            avg_position += np.linalg.norm(current_position - other_position, axis=1)
            avg_velocity += np.linalg.norm(current_velocity - other_velocity, axis=1)
            avg_trans += np.linalg.norm(current_trans - other_trans, axis=0)
            count += 1
        if count == 0:
            continue
        threshold_position = avg_position / count
        threshold_velocity = avg_velocity / count
        threshold_trans = avg_trans / count
        # print(threshold_position, threshold_velocity, threshold_trans)
        for j, other_node in enumerate(graph.vs):
            if i == j:
                continue
            if j == node['previous'] or j == node['next']:
                graph.add_edge(i, j, is_continue=1)
                continue
            other_position = other_node['position']
            other_velocity = other_node['velocity']
            other_trans = other_node['trans']
            position_similarity = np.linalg.norm(current_position - other_position, axis=1)
            velocity_similarity = np.linalg.norm(current_velocity - other_velocity, axis=1)
            trans_similarity = np.linalg.norm(current_trans - other_trans, axis=0)
            if trans_similarity < threshold_trans:
                # require at least 45 of the 55 joints to be within the adaptive thresholds
                if np.sum(position_similarity < threshold_edges * threshold_position) >= 45 and \
                        np.sum(velocity_similarity < threshold_edges * threshold_velocity) >= 45:
                    graph.add_edge(i, j, is_continue=0)
    print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    in_degrees = graph.indegree()
    out_degrees = graph.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    # igraph.plot(graph, target="/content/test.png", bbox=(1000, 1000), vertex_size=10)
    return graph

def random_walk(graph, walk_length, start_node=None):
    if start_node is None:
        start_node = np.random.choice(graph.vs)
    walk = [start_node]
    is_continue = [1]
    for _ in range(walk_length):
        current_node = walk[-1]
        neighbor_indices = graph.neighbors(current_node.index, mode='OUT')
        if not neighbor_indices:
            break
        next_idx = np.random.choice(neighbor_indices)
        edge_id = graph.get_eid(current_node.index, next_idx)
        is_cont = graph.es[edge_id]['is_continue']
        walk.append(graph.vs[next_idx])
        is_continue.append(is_cont)
    return walk, is_continue

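# Hedged usage sketch (mirrors the commented-out single-test code in __main__ below; the output path
# is a placeholder):
#
#   walk, is_continue = random_walk(graph, 100)
#   motion = path_visualization(graph, walk, is_continue, "./test.mp4",
#                               audio_path=None, verbose_continue=True, return_motion=True)
#   print(motion.shape)  # (num_frames_in_walk, 165)
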
def path_visualization(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    all_frames = [node['video'] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)
    fps = graph.vs[0]['fps']
    duration = len(all_frames) / fps

    def make_frame(t):
        idx = min(int(t * fps), len(all_frames) - 1)
        return all_frames[idx]

    video_only_path = 'video_only.mp4'  # temporary file
    video_clip = VideoClip(make_frame, duration=duration)
    video_clip.write_videofile(
        video_only_path,
        codec='libx264',
        fps=fps,
        audio=False
    )
    # Optionally, ensure audio and video durations match
    if audio_path is not None:
        audio_clip = AudioFileClip(audio_path)
        video_duration = video_clip.duration
        audio_duration = audio_clip.duration
        if audio_duration > video_duration:
            # Trim the audio
            trimmed_audio_path = 'trimmed_audio.aac'
            audio_clip = audio_clip.subclip(0, video_duration)
            audio_clip.write_audiofile(trimmed_audio_path)
            audio_input = trimmed_audio_path
        else:
            audio_input = audio_path
        # Use FFmpeg to combine video and audio
        ffmpeg_command = [
            'ffmpeg', '-y',
            '-i', video_only_path,
            '-i', audio_input,
            '-c:v', 'copy',
            '-c:a', 'aac',
            '-strict', 'experimental',
            save_path
        ]
        subprocess.check_call(ffmpeg_command)
        # Clean up temporary files if necessary
        os.remove(video_only_path)
        if audio_input != audio_path:
            os.remove(audio_input)
    if return_motion:
        all_motion = [node['axis_angle'] for node in path]
        all_motion = np.stack(all_motion, 0)
        return all_motion

def generate_transition_video(frame_start_path, frame_end_path, output_video_path):
    # Paths to the frame-interpolation (FiLM) model checkpoint and inference script
    model_path = "./frame-interpolation-pytorch/film_net_fp32.pt"
    inference_script = "./frame-interpolation-pytorch/inference.py"
    # Build the command to run the inference script
    command = [
        "python",
        inference_script,
        model_path,
        frame_start_path,
        frame_end_path,
        "--save_path", output_video_path,
        "--gpu",
        "--frames", "3",
        "--fps", "30"
    ]
    # Run the command
    try:
        subprocess.run(command, check=True)
        print(f"Generated transition video saved at {output_video_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error occurred while generating transition video: {e}")

def path_visualization_v2(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    '''
    Fast interpolation used in the Hugging Face demo; the paper uses a diffusion-based interpolation method instead.
    '''
    all_frames = [node['video'] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)
    duration = len(all_frames) / graph.vs[0]['fps']
    # First loop: confirm where blending is needed
    discontinuity_indices = []
    for i, cont in enumerate(is_continue):
        if cont == 0:
            discontinuity_indices.append(i)
    # Identify blending positions without overlapping
    blend_positions = []
    processed_frames = set()
    for i in discontinuity_indices:
        # Define the frames for blending: i-2 to i+2
        start_idx = i - 2
        end_idx = i + 2
        # Check index boundaries
        if start_idx < 0 or end_idx >= len(all_frames):
            continue  # Skip if indices are out of bounds
        # Check for overlapping frames
        overlap = any(idx in processed_frames for idx in range(i - 1, i + 2))
        if overlap:
            continue  # Skip if frames have been processed
        # Mark frames as processed
        processed_frames.update(range(i - 1, i + 2))
        blend_positions.append(i)
    # Second loop: perform blending
    temp_dir = tempfile.mkdtemp(prefix='blending_frames_')
    for i in tqdm(blend_positions):
        start_frame_idx = i - 2
        end_frame_idx = i + 2
        frame_start = all_frames[start_frame_idx]
        frame_end = all_frames[end_frame_idx]
        frame_start_path = os.path.join(temp_dir, f'frame_{start_frame_idx}.png')
        frame_end_path = os.path.join(temp_dir, f'frame_{end_frame_idx}.png')
        # Save the start and end frames as images
        imageio.imwrite(frame_start_path, frame_start)
        imageio.imwrite(frame_end_path, frame_end)
        # Call FiLM to generate the transition video
        generated_video_path = os.path.join(temp_dir, f'generated_{start_frame_idx}_{end_frame_idx}.mp4')
        generate_transition_video(frame_start_path, frame_end_path, generated_video_path)
        # Read the generated video frames
        reader = imageio.get_reader(generated_video_path)
        generated_frames = [frame for frame in reader]
        reader.close()
        # Replace the middle three frames (i-1, i, i+1) in all_frames
        total_generated_frames = len(generated_frames)
        if total_generated_frames < 5:
            print(f"Generated video has insufficient frames ({total_generated_frames}). Skipping blending at position {i}.")
            continue
        middle_start = 1  # start index for the middle 3 frames
        middle_frames = generated_frames[middle_start:middle_start + 3]
        for idx, frame_idx in enumerate(range(i - 1, i + 2)):
            all_frames[frame_idx] = middle_frames[idx]

    # Create the video clip
    def make_frame(t):
        idx = min(int(t * graph.vs[0]['fps']), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)
    if audio_path is not None:
        audio_clip = AudioFileClip(audio_path)
        video_clip = video_clip.set_audio(audio_clip)
    video_clip.write_videofile(save_path, codec='libx264', fps=graph.vs[0]['fps'], audio_codec='aac')
    if return_motion:
        all_motion = [node['axis_angle'] for node in path]
        all_motion = np.stack(all_motion, 0)
        return all_motion

def graph_pruning(graph):
    ascc = graph.clusters(mode="STRONG")
    lascc = ascc.giant()
    print(f"before nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    print(f"after nodes: {len(lascc.vs)}, edges: {len(lascc.es)}")
    in_degrees = lascc.indegree()
    out_degrees = lascc.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    return lascc

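# Pruning note (descriptive comment, not original documentation): keeping only the giant strongly
# connected component guarantees every remaining node has at least one outgoing edge inside the
# component, so a random walk started anywhere in the pruned graph can always continue.
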
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--json_save_path", type=str, default="")
    parser.add_argument("--graph_save_path", type=str, default="")
    parser.add_argument("--threshold", type=float, default=1.0)
    args = parser.parse_args()

    json_path = args.json_save_path
    graph_path = args.graph_save_path
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    smplx_model = smplx.create(
        "./emage/smplx_models/",
        model_type='smplx',
        gender='NEUTRAL_2020',
        use_face_contour=False,
        num_betas=300,
        num_expression_coeffs=100,
        ext='npz',
        use_pca=False,
    ).to(device).eval()

    # single test
    # graph = create_graph('/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/Abortion_Laws_-_Last_Week_Tonight_with_John_Oliver_HBO-DRauXXz6t0Y.webm.json')
    graph = create_graph(json_path, smplx_model)
    graph = create_edges(graph, args.threshold)

    # pool_path = "/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/show-oliver-test.pkl"
    # graph = igraph.Graph.Read_Pickle(fname=pool_path)
    # graph = igraph.Graph.Read_Pickle(fname="/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/test.pkl")
    # walk, is_continue = random_walk(graph, 100)
    # motion = path_visualization(graph, walk, is_continue, "./test.mp4", audio_path=None, verbose_continue=True, return_motion=True)
    # print(motion.shape)

    graph.write_pickle(fname=graph_path)
    # graph = graph_pruning(graph)

    # show-oliver batch processing
    # json_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/"
    # pre_node_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/cached_graph/show_oliver_test/"
    # for json_file in tqdm(os.listdir(json_path)):
    #     graph = create_graph(os.path.join(json_path, json_file))
    #     graph = create_edges(graph)
    #     if not len(graph.vs) >= 1500:
    #         print(f"skip: {len(graph.vs)}", json_file)
    #     graph.write_pickle(fname=os.path.join(pre_node_path, json_file.split(".")[0] + ".pkl"))
    #     print(f"Graph saved at {json_file.split('.')[0]}.pkl")