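"""Inference demo for AetherV1-CogVideoX.

Supports three tasks: reconstruction from a video, prediction from an image,
and planning from an image and a goal image.
"""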
import argparse
import os
import random
from typing import List, Optional, Tuple

import imageio.v3 as iio
import numpy as np
import PIL
import rootutils
import torch
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXTransformer3DModel,
)
from transformers import AutoTokenizer, T5EncoderModel

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from aether.pipelines.aetherv1_pipeline_cogvideox import (  # noqa: E402
    AetherV1PipelineCogVideoX,
    AetherV1PipelineOutput,
)
from aether.utils.postprocess_utils import (  # noqa: E402
    align_camera_extrinsics,
    apply_transformation,
    colorize_depth,
    compute_scale,
    get_intrinsics,
    interpolate_poses,
    postprocess_pointmap,
    project,
    raymap_to_poses,
)
from aether.utils.visualize_utils import predictions_to_glb  # noqa: E402

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def seed_all(seed: int = 0) -> None:
    """Set random seeds of all components."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="AetherV1-CogVideoX Inference Demo")
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        choices=["reconstruction", "prediction", "planning"],
        help="Task to perform: 'reconstruction', 'prediction' or 'planning'.",
    )
    parser.add_argument(
        "--video",
        type=str,
        default=None,
        help="Path to a video file. Only used for the 'reconstruction' task.",
    )
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="Path to an image file. Only used for the 'prediction' and 'planning' tasks.",
    )
    parser.add_argument(
        "--goal",
        type=str,
        default=None,
        help="Path to a goal image file. Only used for the 'planning' task.",
    )
    parser.add_argument(
        "--raymap_action",
        type=str,
        default=None,
        help="Path to a raymap action file: a numpy array of shape (num_frame, 6, latent_height, latent_width).",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs",
        help="Path to save the outputs.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=12,
        choices=[8, 10, 12, 15, 24],
        help="Frames per second. Options: 8, 10, 12, 15, 24.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=None,
        help="Number of inference steps. If not specified, the default number of steps for the task is used.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=None,
        help="Guidance scale. If not specified, the default guidance scale for the task is used.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        default=True,
        help="Use dynamic classifier-free guidance.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Height of the output video.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=720,
        help="Width of the output video.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=41,
        help="Number of frames to predict.",
    )
    parser.add_argument(
        "--max_depth",
        type=float,
        default=100.0,
        help="Maximum depth of the scene in meters.",
    )
    parser.add_argument(
        "--rtol",
        type=float,
        default=0.03,
        help="Relative tolerance for depth edge detection.",
    )
    parser.add_argument(
        "--cogvideox_pretrained_model_name_or_path",
        type=str,
        default="THUDM/CogVideoX-5b-I2V",
        help="Name or path of the CogVideoX model to use.",
    )
    parser.add_argument(
        "--aether_pretrained_model_name_or_path",
        type=str,
        default="AetherWorldModel/AetherV1-CogVideoX",
        help="Name or path of the Aether model to use.",
    )
    parser.add_argument(
        "--smooth_camera",
        action="store_true",
        default=True,
        help="Smooth the camera trajectory.",
    )
    parser.add_argument(
        "--smooth_method",
        type=str,
        default="kalman",
        choices=["kalman", "simple"],
        help="Camera trajectory smoothing method.",
    )
    parser.add_argument(
        "--sliding_window_stride",
        type=int,
        default=24,
        help="Sliding window stride (the window size equals num_frames). Only used for the 'reconstruction' task.",
    )
    parser.add_argument(
        "--post_reconstruction",
        action="store_true",
        default=True,
        help="Run reconstruction after prediction for better quality. Only used for the 'prediction' and 'planning' tasks.",
    )
    parser.add_argument(
        "--pointcloud_save_frame_interval",
        type=int,
        default=10,
        help="Interval (in frames) at which point cloud GLB files are saved.",
    )
    parser.add_argument(
        "--align_pointmaps",
        action="store_true",
        default=False,
        help="Align pointmaps across sliding windows.",
    )
    return parser.parse_args()

def build_pipeline(args: argparse.Namespace) -> AetherV1PipelineCogVideoX:
    pipeline = AetherV1PipelineCogVideoX(
        tokenizer=AutoTokenizer.from_pretrained(
            args.cogvideox_pretrained_model_name_or_path,
            subfolder="tokenizer",
        ),
        text_encoder=T5EncoderModel.from_pretrained(
            args.cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
        ),
        vae=AutoencoderKLCogVideoX.from_pretrained(
            args.cogvideox_pretrained_model_name_or_path, subfolder="vae"
        ),
        scheduler=CogVideoXDPMScheduler.from_pretrained(
            args.cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
        ),
        transformer=CogVideoXTransformer3DModel.from_pretrained(
            args.aether_pretrained_model_name_or_path, subfolder="transformer"
        ),
    )
    pipeline.vae.enable_slicing()
    pipeline.vae.enable_tiling()
    pipeline.to(device)
    return pipeline

def get_window_starts(
    total_frames: int, sliding_window_size: int, temporal_stride: int
) -> List[int]:
    """Calculate window start indices."""
    starts = list(
        range(
            0,
            total_frames - sliding_window_size + 1,
            temporal_stride,
        )
    )
    if (
        total_frames > sliding_window_size
        and (total_frames - sliding_window_size) % temporal_stride != 0
    ):
        starts.append(total_frames - sliding_window_size)
    return starts

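# Worked example for get_window_starts (illustrative numbers, not defaults used
# anywhere in this script): total_frames=100, sliding_window_size=41,
# temporal_stride=24 gives regular starts [0, 24, 48]; since (100 - 41) % 24 != 0,
# a final start at 59 is appended so the last window ends exactly at frame 99.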
def blend_and_merge_window_results(
    window_results: List[AetherV1PipelineOutput],
    window_indices: List[int],
    args: argparse.Namespace,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Blend and merge window results."""
    merged_rgb = None
    merged_disparity = None
    merged_poses = None
    merged_focals = None
    if args.align_pointmaps:
        merged_pointmaps = None
    w1 = window_results[0].disparity

    for idx, (window_result, t_start) in enumerate(
        zip(window_results, window_indices)
    ):
        t_end = t_start + window_result.rgb.shape[0]
        if idx == 0:
            merged_rgb = window_result.rgb
            merged_disparity = window_result.disparity
            pointmap_dict = postprocess_pointmap(
                window_result.disparity,
                window_result.raymap,
                vae_downsample_scale=8,
                ray_o_scale_inv=0.1,
                smooth_camera=args.smooth_camera,
                smooth_method=args.smooth_method if args.smooth_camera else "none",
            )
            merged_poses = pointmap_dict["camera_pose"]
            merged_focals = (
                pointmap_dict["intrinsics"][:, 0, 0]
                + pointmap_dict["intrinsics"][:, 1, 1]
            ) / 2
            if args.align_pointmaps:
                merged_pointmaps = pointmap_dict["pointmap"]
        else:
            overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start

            window_disparity = window_result.disparity

            # Align disparity
            disp_mask = (
                window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1
            )
            scale = compute_scale(
                window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]),
                merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]),
                disp_mask.reshape(1, -1, w1.shape[-1]),
            )
            window_disparity = scale * window_disparity

            # Blend disparity
            result_disparity = np.ones((t_end, *w1.shape[1:]))
            result_disparity[:t_start] = merged_disparity[:t_start]
            result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:]
            weight = np.linspace(1, 0, overlap_t)[:, None, None]
            result_disparity[t_start : t_start + overlap_t] = merged_disparity[
                t_start : t_start + overlap_t
            ] * weight + window_disparity[:overlap_t] * (1 - weight)
            merged_disparity = result_disparity

            # Blend RGB
            result_rgb = np.ones((t_end, *w1.shape[1:], 3))
            result_rgb[:t_start] = merged_rgb[:t_start]
            result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:]
            weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None]
            result_rgb[t_start : t_start + overlap_t] = merged_rgb[
                t_start : t_start + overlap_t
            ] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb)
            merged_rgb = result_rgb

            # Align poses
            window_raymap = window_result.raymap
            window_poses, window_Fov_x, window_Fov_y = raymap_to_poses(
                window_raymap, ray_o_scale_inv=0.1
            )
            rel_r, rel_t, rel_s = align_camera_extrinsics(
                torch.from_numpy(window_poses[:overlap_t]),
                torch.from_numpy(merged_poses[-overlap_t:]),
            )
            aligned_window_poses = (
                apply_transformation(
                    torch.from_numpy(window_poses),
                    rel_r,
                    rel_t,
                    rel_s,
                    return_extri=True,
                )
                .cpu()
                .numpy()
            )

            result_poses = np.ones((t_end, 4, 4))
            result_poses[:t_start] = merged_poses[:t_start]
            result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:]

            # Interpolate poses in overlap region
            weights = np.linspace(1, 0, overlap_t)
            for t in range(overlap_t):
                weight = weights[t]
                pose1 = merged_poses[t_start + t]
                pose2 = aligned_window_poses[t]
                result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight)
            merged_poses = result_poses

            # Align intrinsics
            window_intrinsics, _ = get_intrinsics(
                batch_size=window_poses.shape[0],
                h=window_result.disparity.shape[1],
                w=window_result.disparity.shape[2],
                fovx=window_Fov_x,
                fovy=window_Fov_y,
            )
            window_focals = (
                window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1]
            ) / 2
            scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean()
            window_focals = scale * window_focals
            result_focals = np.ones((t_end,))
            result_focals[:t_start] = merged_focals[:t_start]
            result_focals[t_start + overlap_t :] = window_focals[overlap_t:]
            weight = np.linspace(1, 0, overlap_t)
            result_focals[t_start : t_start + overlap_t] = merged_focals[
                t_start : t_start + overlap_t
            ] * weight + window_focals[:overlap_t] * (1 - weight)
            merged_focals = result_focals

            if args.align_pointmaps:
                # Align pointmaps
                window_pointmaps = postprocess_pointmap(
                    result_disparity[t_start:],
                    window_raymap,
                    vae_downsample_scale=8,
                    camera_pose=aligned_window_poses,
                    focal=window_focals,
                    ray_o_scale_inv=0.1,
                    smooth_camera=args.smooth_camera,
                    smooth_method=args.smooth_method if args.smooth_camera else "none",
                )
                result_pointmaps = np.ones((t_end, *w1.shape[1:], 3))
                result_pointmaps[:t_start] = merged_pointmaps[:t_start]
                result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][
                    overlap_t:
                ]
                weight = np.linspace(1, 0, overlap_t)[:, None, None, None]
                result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[
                    t_start : t_start + overlap_t
                ] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight)
                merged_pointmaps = result_pointmaps

    # project to pointmaps
    intrinsics = [
        np.array([[f, 0, 0.5 * args.width], [0, f, 0.5 * args.height], [0, 0, 1]])
        for f in merged_focals
    ]
    if args.align_pointmaps:
        pointmaps = merged_pointmaps
    else:
        pointmaps = np.stack(
            [
                project(
                    1 / np.clip(merged_disparity[i], 1e-8, 1e8),
                    intrinsics[i],
                    merged_poses[i],
                )
                for i in range(merged_poses.shape[0])
            ]
        )

    return merged_rgb, merged_disparity, merged_poses, pointmaps

def save_output(
    rgb: np.ndarray,
    disparity: np.ndarray,
    poses: Optional[np.ndarray] = None,
    raymap: Optional[np.ndarray] = None,
    pointmap: Optional[np.ndarray] = None,
    args: argparse.Namespace = None,
) -> None:
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    if pointmap is None:
        assert raymap is not None, "Raymap is required for saving pointmap."
        pointmap_dict = postprocess_pointmap(
            disparity,
            raymap,
            vae_downsample_scale=8,
            ray_o_scale_inv=0.1,
            smooth_camera=args.smooth_camera,
            smooth_method=args.smooth_method,
        )
        pointmap = pointmap_dict["pointmap"]

    if poses is None:
        assert raymap is not None, "Raymap is required for saving poses."
        poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)

    if args.task == "reconstruction":
        filename = f"reconstruction_{args.video.split('/')[-1].split('.')[0]}"
    elif args.task == "prediction":
        filename = f"prediction_{args.image.split('/')[-1].split('.')[0]}"
    elif args.task == "planning":
        filename = f"planning_{args.image.split('/')[-1].split('.')[0]}_{args.goal.split('/')[-1].split('.')[0]}"
    filename = os.path.join(output_dir, filename)

    iio.imwrite(
        f"{filename}_rgb.mp4",
        (np.clip(rgb, 0, 1) * 255).astype(np.uint8),
        fps=12,
    )
    iio.imwrite(
        f"{filename}_disparity.mp4",
        (colorize_depth(disparity) * 255).astype(np.uint8),
        fps=12,
    )

    print("Building GLB scene")
    for frame_idx in range(pointmap.shape[0])[:: args.pointcloud_save_frame_interval]:
        predictions = {
            "world_points": pointmap[frame_idx : frame_idx + 1],
            "images": rgb[frame_idx : frame_idx + 1],
            "depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
            "camera_poses": poses[frame_idx : frame_idx + 1],
        }
        scene_3d = predictions_to_glb(
            predictions,
            filter_by_frames="all",
            show_cam=True,
            max_depth=args.max_depth,
            rtol=args.rtol,
            frame_rel_idx=float(frame_idx) / pointmap.shape[0],
        )
        scene_3d.export(f"{filename}_pointcloud_frame_{frame_idx}.glb")
    print("GLB Scene built")

def main() -> None:
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args = parse_args()
    seed_all(args.seed)

    if args.num_inference_steps is None:
        args.num_inference_steps = 4 if args.task == "reconstruction" else 50
    if args.guidance_scale is None:
        args.guidance_scale = 1.0 if args.task == "reconstruction" else 3.0

    pipeline = build_pipeline(args)

    if args.task == "reconstruction":
        assert args.video is not None, "Video is required for the reconstruction task."
        assert args.image is None, "Image should not be provided for the reconstruction task."
        assert args.goal is None, "Goal should not be provided for the reconstruction task."
        video = iio.imread(args.video).astype(np.float32) / 255.0
        image, goal = None, None
    elif args.task == "prediction":
        assert args.image is not None, "Image is required for the prediction task."
        assert args.goal is None, "Goal should not be provided for the prediction task."
        image = PIL.Image.open(args.image)
        video, goal = None, None
    elif args.task == "planning":
        assert args.image is not None, "Image is required for the planning task."
        assert args.goal is not None, "Goal is required for the planning task."
        image = PIL.Image.open(args.image)
        goal = PIL.Image.open(args.goal)
        video = None

    if args.raymap_action is not None:
        raymap = np.load(args.raymap_action)
    else:
        raymap = None

    if args.task != "reconstruction":
        output = pipeline(
            task=args.task,
            image=image,
            video=video,
            goal=goal,
            raymap=raymap,
            height=args.height,
            width=args.width,
            num_frames=args.num_frames,
            fps=args.fps,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            use_dynamic_cfg=args.use_dynamic_cfg,
            generator=torch.Generator(device=device).manual_seed(args.seed),
            return_dict=True,
        )
        if not args.post_reconstruction:
            save_output(
                rgb=output.rgb,
                disparity=output.disparity,
                raymap=output.raymap,
                args=args,
            )
        else:
            recon_output = pipeline(
                task="reconstruction",
                video=output.rgb,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                fps=args.fps,
                num_inference_steps=4,
                guidance_scale=1.0,  # guidance is not needed for the reconstruction task
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(args.seed),
            )
            save_output(
                rgb=output.rgb,
                disparity=recon_output.disparity,
                raymap=recon_output.raymap,
                args=args,
            )
    else:
        # for the reconstruction task, run a sliding window over long videos
        window_results = []
        window_indices = get_window_starts(
            len(video), args.num_frames, args.sliding_window_stride
        )
        for start_idx in window_indices:
            output = pipeline(
                task=args.task,
                image=None,
                goal=None,
                video=video[start_idx : start_idx + args.num_frames],
                raymap=raymap[start_idx : start_idx + args.num_frames]
                if raymap is not None
                else None,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                fps=args.fps,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=1.0,  # guidance is not needed for the reconstruction task
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(args.seed),
            )
            window_results.append(output)

        # merge window results
        (
            merged_rgb,
            merged_disparity,
            merged_poses,
            pointmaps,
        ) = blend_and_merge_window_results(window_results, window_indices, args)
        save_output(
            rgb=merged_rgb,
            disparity=merged_disparity,
            poses=merged_poses,
            pointmap=pointmaps,
            args=args,
        )

if __name__ == "__main__":
    main()
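# Example invocations (a sketch only; the script filename and input paths are
# placeholders, not files shipped with the project):
#   python inference_demo.py --task reconstruction --video /path/to/video.mp4
#   python inference_demo.py --task prediction --image /path/to/frame.png
#   python inference_demo.py --task planning --image /path/to/start.png --goal /path/to/goal.png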