Spaces:
Running
Running
File size: 4,388 Bytes
37de32d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from pathlib import Path
from typing import Any
import torch
from box import Box
from mapanything.utils.wai.core import get_frame_index, load_data, load_frame
from mapanything.utils.wai.ops import stack
from mapanything.utils.wai.scene_frame import get_scene_frame_names
class BasicSceneframeDataset(torch.utils.data.Dataset):
    """Basic wai dataset to iterate over the frames of scenes.

    Every (scene_name, frame_name) pair produced by ``get_scene_frame_names``
    becomes one dataset item. Scene metadata is loaded lazily on first access
    and only the most recently used scene is kept in the cache.
    """

    @staticmethod
    def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
        """Collate a list of per-frame sample dicts into one batch via ``stack``.

        Args:
            batch (list[dict]): Samples as returned by ``__getitem__``.

        Returns:
            dict: The result of ``stack(batch)``.
        """
        return stack(batch)

    def __init__(
        self,
        cfg: Box,
    ):
        """
        Initialize the BasicSceneframeDataset.

        Args:
            cfg (Box): Configuration object containing dataset parameters including:
                - root: Root directory containing scene data
                - frame_modalities: List of modalities to load for each frame
                - key_remap: Optional dictionary mapping original keys to new keys
                - use_keyframes: Optional flag forwarded to get_scene_frame_names
                  as ``keyframes`` (defaults to True)
                - scene_meta_path: Optional path of the scene metadata file,
                  relative to each scene directory (defaults to "scene_meta.json")
        """
        super().__init__()
        self.cfg = cfg
        self.root = cfg.root
        keyframes = cfg.get("use_keyframes", True)
        self.scene_frame_names = get_scene_frame_names(cfg, keyframes=keyframes)
        # Flatten {scene_name: [frame_name, ...]} into one flat list so that an
        # integer dataset index maps directly to a (scene, frame) pair.
        self.scene_frame_list = [
            (scene_name, frame_name)
            for scene_name, frame_names in self.scene_frame_names.items()
            for frame_name in frame_names
        ]
        # Holds at most one entry: the most recently loaded scene.
        self._scene_cache = {}

    def __len__(self):
        """
        Get the total number of scene-frame pairs in the dataset.

        Returns:
            int: The number of scene-frame pairs.
        """
        return len(self.scene_frame_list)

    def _load_scene(self, scene_name: str) -> dict[str, Any]:
        """
        Load scene data for a given scene name.

        Args:
            scene_name (str): The name of the scene to load.

        Returns:
            dict: A dictionary whose "meta" entry holds the scene metadata read
            from ``<root>/<scene_name>/<scene_meta_path>``.
        """
        meta_path = Path(
            self.root,
            scene_name,
            self.cfg.get("scene_meta_path", "scene_meta.json"),
        )
        return {"meta": load_data(meta_path, "scene_meta")}

    @staticmethod
    def _remap_keys(
        data: dict[str, Any], key_remap: dict[str, str] | None
    ) -> dict[str, Any]:
        """Rename keys of ``data`` in place according to ``key_remap``.

        All renames are resolved against the ORIGINAL keys at once, so swaps
        such as ``{"a": "b", "b": "a"}`` exchange values instead of losing one
        of them. Keys of ``key_remap`` that are absent from ``data`` are skipped.
        A missing/None/empty ``key_remap`` leaves ``data`` unchanged.

        Args:
            data (dict): Sample dict to modify.
            key_remap (dict or None): Mapping from original key to new key.

        Returns:
            dict: ``data`` itself, after renaming.
        """
        if not key_remap:
            return data
        # Pop every source key first, then re-insert under the new names, so a
        # later rename never sees a value that an earlier rename just moved.
        moved = {}
        for key, new_key in key_remap.items():
            if key in data:
                moved[new_key] = data.pop(key)
        data.update(moved)
        return data

    def _load_scene_frame(
        self, scene_name: str, frame_name: str | float
    ) -> dict[str, Any]:
        """
        Load data for a specific frame from a specific scene.

        Scene data is loaded if it is not already cached. The requested frame is
        then loaded from that scene with the modalities listed in
        ``cfg.frame_modalities``, and ``cfg.key_remap`` (if any) is applied.

        Args:
            scene_name (str): The name of the scene containing the frame.
            frame_name (str or float): The name/timestamp of the frame to load.

        Returns:
            dict: A dictionary with "scene_name", "frame_name", "scene_path" and
            "frame_idx" entries plus the loaded modalities (keys possibly renamed
            via key_remap).
        """
        scene_data = self._scene_cache.get(scene_name)
        if scene_data is None:
            scene_data = self._load_scene(scene_name)
            # For now only cache the last scene: replacing the whole dict drops
            # any previously cached scene.
            self._scene_cache = {scene_name: scene_data}

        scene_path = Path(self.root, scene_name)
        scene_frame_data = {
            "scene_name": scene_name,
            "frame_name": frame_name,
            "scene_path": str(scene_path),
            "frame_idx": get_frame_index(scene_data["meta"], frame_name),
        }
        scene_frame_data.update(
            load_frame(
                scene_path,
                frame_name,
                modalities=self.cfg.frame_modalities,
                scene_meta=scene_data["meta"],
            )
        )
        return self._remap_keys(scene_frame_data, self.cfg.get("key_remap"))

    def __getitem__(self, index: int) -> dict[str, Any]:
        """
        Get a specific scene-frame pair by index.

        Args:
            index (int): The index of the scene-frame pair to retrieve.

        Returns:
            dict: A dictionary containing the loaded frame data with requested modalities.
        """
        scene_name, frame_name = self.scene_frame_list[index]
        return self._load_scene_frame(scene_name, frame_name)
|