Spaces:
Running
Running
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| from box import Box | |
| from mapanything.utils.wai.core import get_frame_index, load_data, load_frame | |
| from mapanything.utils.wai.ops import stack | |
| from mapanything.utils.wai.scene_frame import get_scene_frame_names | |
class BasicSceneframeDataset(torch.utils.data.Dataset):
    """Basic WAI dataset that iterates over the frames of every scene.

    Each dataset item corresponds to one (scene_name, frame_name) pair.
    Scene-level metadata is loaded lazily, and only the most recently used
    scene is kept in the cache.
    """

    @staticmethod
    def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Collate a list of per-frame samples into a single batch.

        Declared as a ``staticmethod`` so it works both as
        ``BasicSceneframeDataset.collate_fn`` and as ``dataset.collate_fn``
        (e.g. when passed to ``torch.utils.data.DataLoader``). Without the
        decorator, an instance-bound call would pass the dataset itself as
        ``batch`` and fail.

        Args:
            batch (list[dict]): Per-frame sample dicts as returned by
                ``__getitem__``.

        Returns:
            dict: The samples combined with ``mapanything.utils.wai.ops.stack``.
        """
        return stack(batch)

    def __init__(
        self,
        cfg: Box,
    ):
        """
        Initialize the BasicSceneframeDataset.

        Args:
            cfg (Box): Configuration object containing dataset parameters including:
                - root: Root directory containing scene data
                - frame_modalities: List of modalities to load for each frame
                - key_remap: Optional dictionary mapping original keys to new keys
                - use_keyframes: Optional flag forwarded to
                  ``get_scene_frame_names`` as ``keyframes`` (default: True)
                - scene_meta_path: Optional path of the scene metadata file,
                  relative to each scene directory (default: "scene_meta.json")
        """
        super().__init__()
        self.cfg = cfg
        self.root = cfg.root
        keyframes = cfg.get("use_keyframes", True)
        # Mapping from scene name to the frame names selected for that scene.
        self.scene_frame_names = get_scene_frame_names(cfg, keyframes=keyframes)
        # Flattened index: dataset item ``i`` is ``scene_frame_list[i]``.
        # Frames of the same scene are contiguous, which keeps the
        # single-scene cache effective for sequential access.
        self.scene_frame_list = [
            (scene_name, frame_name)
            for scene_name, frame_names in self.scene_frame_names.items()
            for frame_name in frame_names
        ]
        # Holds at most one entry: the most recently loaded scene.
        self._scene_cache: dict[str, dict[str, Any]] = {}

    def __len__(self) -> int:
        """
        Get the total number of scene-frame pairs in the dataset.

        Returns:
            int: The number of scene-frame pairs.
        """
        return len(self.scene_frame_list)

    def _load_scene(self, scene_name: str) -> dict[str, Any]:
        """
        Load scene-level data for a given scene name.

        Args:
            scene_name (str): The name of the scene to load.

        Returns:
            dict: A dictionary with a single ``"meta"`` entry holding the
                scene metadata read from ``<root>/<scene_name>/<scene_meta_path>``.
        """
        scene_data = {}
        scene_data["meta"] = load_data(
            Path(
                self.root,
                scene_name,
                self.cfg.get("scene_meta_path", "scene_meta.json"),
            ),
            "scene_meta",
        )
        return scene_data

    def _load_scene_frame(
        self, scene_name: str, frame_name: str | float
    ) -> dict[str, Any]:
        """
        Load data for a specific frame from a specific scene.

        Scene data is loaded if it is not already cached. The requested frame
        is then loaded with the modalities listed in ``cfg.frame_modalities``,
        and any keys listed in ``cfg.key_remap`` are renamed.

        Args:
            scene_name (str): The name of the scene containing the frame.
            frame_name (str or float): The name/timestamp of the frame to load.

        Returns:
            dict: The loaded frame data plus the bookkeeping entries
                ``scene_name``, ``frame_name``, ``scene_path`` and ``frame_idx``.
        """
        scene_data = self._scene_cache.get(scene_name)
        if scene_data is None:
            scene_data = self._load_scene(scene_name)
            # Only the last scene is cached, to bound memory use.
            self._scene_cache = {scene_name: scene_data}

        scene_path = Path(self.root, scene_name)
        scene_frame_data: dict[str, Any] = {
            "scene_name": scene_name,
            "frame_name": frame_name,
            "scene_path": str(scene_path),
            "frame_idx": get_frame_index(scene_data["meta"], frame_name),
        }
        scene_frame_data.update(
            load_frame(
                scene_path,
                frame_name,
                modalities=self.cfg.frame_modalities,
                scene_meta=scene_data["meta"],
            )
        )

        # Remap key names. `key_remap` may be missing or explicitly null
        # in the config; treat both as "no remapping".
        key_remap = self.cfg.get("key_remap") or {}
        for key, new_key in key_remap.items():
            if key in scene_frame_data:
                scene_frame_data[new_key] = scene_frame_data.pop(key)
        return scene_frame_data

    def __getitem__(self, index: int) -> dict[str, Any]:
        """
        Get a specific scene-frame pair by index.

        Args:
            index (int): The index of the scene-frame pair to retrieve.

        Returns:
            dict: A dictionary containing the loaded frame data with requested modalities.
        """
        scene_name, frame_name = self.scene_frame_list[index]
        return self._load_scene_frame(scene_name, frame_name)