""" Structured3D Dataset using WAI format data. """ import os import numpy as np from mapanything.datasets.base.base_dataset import BaseDataset from mapanything.utils.wai.core import load_data, load_frame class Structured3DWAI(BaseDataset): """ Structured3D dataset containing large diversity of synthetic multi-room indoor scenes. """ def __init__( self, *args, ROOT, dataset_metadata_dir, split, overfit_num_sets=None, sample_specific_scene: bool = False, specific_scene_name: str = None, **kwargs, ): """ Initialize the dataset attributes. Args: ROOT: Root directory of the dataset. dataset_metadata_dir: Path to the dataset metadata directory. split: Dataset split (train, val, test). overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets. sample_specific_scene: Whether to sample a specific scene from the dataset. specific_scene_name: Name of the specific scene to sample. """ # Initialize the dataset attributes super().__init__(*args, **kwargs) self.ROOT = ROOT self.dataset_metadata_dir = dataset_metadata_dir self.split = split self.overfit_num_sets = overfit_num_sets self.sample_specific_scene = sample_specific_scene self.specific_scene_name = specific_scene_name self._load_data() # Define the dataset type flags self.is_metric_scale = True self.is_synthetic = True def _load_data(self): "Load the precomputed dataset metadata" # Load the dataset metadata corresponding to the split split_metadata_path = os.path.join( self.dataset_metadata_dir, self.split, f"structured3d_scene_list_{self.split}.npy", ) split_scene_list = np.load(split_metadata_path, allow_pickle=True) # Get the list of all scenes if not self.sample_specific_scene: self.scenes = list(split_scene_list) else: self.scenes = [self.specific_scene_name] self.num_of_scenes = len(self.scenes) def _get_views(self, sampled_idx, num_views_to_sample, resolution): # Get the scene name of the sampled index scene_index = sampled_idx scene_name = self.scenes[scene_index] # Get the metadata corresponding to the scene scene_root = os.path.join(self.ROOT, scene_name) scene_meta = load_data( os.path.join(scene_root, "scene_meta.json"), "scene_meta" ) scene_file_names = list(scene_meta["frame_names"].keys()) num_views_in_scene = len(scene_file_names) # Load the scene pairwise covisibility mmap covisibility_version_key = "v0" covisibility_map_dir = os.path.join( scene_root, "covisibility", covisibility_version_key ) # Assumes only npy file in directory is covisbility map covisibility_map_name = next( f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") ) covisibility_map_path = os.path.join( scene_root, "covisibility", covisibility_version_key, covisibility_map_name ) pairwise_covisibility = load_data(covisibility_map_path, "mmap") # Get the indices of the N views in the scene view_indices = self._sample_view_indices( num_views_to_sample, num_views_in_scene, pairwise_covisibility ) # Get the views corresponding to the selected view indices views = [] for view_index in view_indices: # Load the data corresponding to the view view_file_name = scene_file_names[view_index] view_data = load_frame( scene_root, view_file_name, modalities=["image", "depth"], scene_meta=scene_meta, ) # Convert necessary data to numpy image = view_data["image"].permute(1, 2, 0).numpy() image = image[:, :, :3] # RGBA to RGB image = (image * 255).astype(np.uint8) depthmap = view_data["depth"].numpy().astype(np.float32) intrinsics = view_data["intrinsics"].numpy().astype(np.float32) c2w_pose = 
view_data["extrinsics"].numpy().astype(np.float32) # Ensure that the depthmap has all valid values depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) # Resize the data to match the desired resolution image, depthmap, intrinsics = self._crop_resize_if_necessary( image=image, resolution=resolution, depthmap=depthmap, intrinsics=intrinsics, additional_quantities=None, ) # Append the view dictionary to the list of views views.append( dict( img=image, depthmap=depthmap, camera_pose=c2w_pose, # cam2world camera_intrinsics=intrinsics, dataset="Structured3D", label=scene_name, instance=os.path.join("images", str(view_file_name)), ) ) return views def get_parser(): import argparse parser = argparse.ArgumentParser() parser.add_argument( "-rd", "--root_dir", default="/fsx/xrtech/data/structured3d", type=str ) parser.add_argument( "-dmd", "--dataset_metadata_dir", default="/fsx/nkeetha/mapanything_dataset_metadata", type=str, ) parser.add_argument( "-nv", "--num_of_views", default=2, type=int, ) parser.add_argument("--viz", action="store_true") return parser if __name__ == "__main__": import rerun as rr from tqdm import tqdm from mapanything.datasets.base.base_dataset import view_name from mapanything.utils.image import rgb from mapanything.utils.viz import script_add_rerun_args parser = get_parser() script_add_rerun_args( parser ) # Options: --headless, --connect, --serve, --addr, --save, --stdout args = parser.parse_args() dataset = Structured3DWAI( num_views=args.num_of_views, split="train", covisibility_thres=0.25, ROOT=args.root_dir, dataset_metadata_dir=args.dataset_metadata_dir, resolution=(518, 294), aug_crop=16, transform="colorjitter+grayscale+gaublur", data_norm_type="dinov2", ) # dataset = Structured3DWAI( # num_views=args.num_of_views, # split="val", # covisibility_thres=0.25, # ROOT=args.root_dir, # dataset_metadata_dir=args.dataset_metadata_dir, # resolution=(518, 294), # seed=777, # transform="imgnorm", # data_norm_type="dinov2", # ) print(dataset.get_stats()) if args.viz: rr.script_setup(args, "Structured3D_Dataloader") rr.set_time("stable_time", sequence=0) rr.log("world", rr.ViewCoordinates.RDF, static=True) sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) for num, idx in enumerate(tqdm(sampled_indices)): views = dataset[idx] assert len(views) == args.num_of_views sample_name = f"{idx}" for view_idx in range(args.num_of_views): sample_name += f" {view_name(views[view_idx])}" print(sample_name) for view_idx in range(args.num_of_views): image = rgb( views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] ) depthmap = views[view_idx]["depthmap"] pose = views[view_idx]["camera_pose"] intrinsics = views[view_idx]["camera_intrinsics"] pts3d = views[view_idx]["pts3d"] valid_mask = views[view_idx]["valid_mask"] if "non_ambiguous_mask" in views[view_idx]: non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] else: non_ambiguous_mask = None if "prior_depth_along_ray" in views[view_idx]: prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] else: prior_depth_along_ray = None if args.viz: rr.set_time("stable_time", sequence=num) base_name = f"world/view_{view_idx}" pts_name = f"world/view_{view_idx}_pointcloud" # Log camera info and loaded data height, width = image.shape[0], image.shape[1] rr.log( base_name, rr.Transform3D( translation=pose[:3, 3], mat3x3=pose[:3, :3], ), ) rr.log( f"{base_name}/pinhole", rr.Pinhole( image_from_camera=intrinsics, height=height, width=width, camera_xyz=rr.ViewCoordinates.RDF, ), ) 
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
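# Example invocation of the visualization script (the module path and data
# locations are illustrative; adjust them to your local setup):
#   python -m mapanything.datasets.wai.structured3d \
#       -rd /path/to/structured3d -dmd /path/to/dataset_metadata -nv 2 \
#       --viz --serve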