aknapitsch user
simpler inference and refactoring
37de32d
from pathlib import Path
from typing import Any
import torch
from box import Box
from mapanything.utils.wai.core import get_frame_index, load_data, load_frame
from mapanything.utils.wai.ops import stack
from mapanything.utils.wai.scene_frame import get_scene_frame_names
class BasicSceneframeDataset(torch.utils.data.Dataset):
    """Basic WAI dataset to iterate over the frames of scenes.

    Each item corresponds to one (scene_name, frame_name) pair. Scene-level
    metadata is loaded lazily and only the most recently used scene is kept
    in memory.
    """

    @staticmethod
    def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
        """Collate a list of per-frame sample dicts by delegating to ``stack``.

        Args:
            batch: Samples as returned by ``__getitem__``.

        Returns:
            dict: Whatever ``stack`` produces for the given samples.
        """
        return stack(batch)

    def __init__(
        self,
        cfg: Box,
    ):
        """
        Initialize the BasicSceneframeDataset.

        Args:
            cfg (Box): Configuration object containing dataset parameters including:
                - root: Root directory containing scene data
                - frame_modalities: List of modalities to load for each frame
                - key_remap: Optional dictionary mapping original keys to new keys
                - use_keyframes: Optional flag forwarded to
                  ``get_scene_frame_names`` as ``keyframes`` (default True)
                - scene_meta_path: Optional scene-relative path of the scene
                  metadata file (default "scene_meta.json")
                ``cfg`` is also passed whole to ``get_scene_frame_names``, which
                may read further keys (not visible here).
        """
        super().__init__()
        self.cfg = cfg
        self.root = cfg.root
        keyframes = cfg.get("use_keyframes", True)
        self.scene_frame_names = get_scene_frame_names(cfg, keyframes=keyframes)
        # Flatten {scene_name: [frame_name, ...]} into an index-addressable list.
        self.scene_frame_list = [
            (scene_name, frame_name)
            for scene_name, frame_names in self.scene_frame_names.items()
            for frame_name in frame_names
        ]
        # scene_name -> scene data; holds at most one entry (see _load_scene_frame).
        self._scene_cache: dict[str, dict[str, Any]] = {}

    def __len__(self) -> int:
        """
        Get the total number of scene-frame pairs in the dataset.

        Returns:
            int: The number of scene-frame pairs.
        """
        return len(self.scene_frame_list)

    def _load_scene(self, scene_name: str) -> dict[str, Any]:
        """
        Load scene-level data for a given scene name.

        Args:
            scene_name (str): The name of the scene to load.

        Returns:
            dict: ``{"meta": ...}`` where the value is the result of
            ``load_data(<root>/<scene_name>/<scene_meta_path>, "scene_meta")``.
        """
        meta_path = Path(
            self.root,
            scene_name,
            self.cfg.get("scene_meta_path", "scene_meta.json"),
        )
        return {"meta": load_data(meta_path, "scene_meta")}

    @staticmethod
    def _remap_keys(
        data: dict[str, Any], key_remap: dict[str, str] | None
    ) -> dict[str, Any]:
        """Rename keys of ``data`` in place according to ``key_remap``.

        All renames are resolved against the ORIGINAL keys at once, so chained
        or swapped mappings (e.g. ``{"a": "b", "b": "a"}``) do not clobber each
        other's source values. Keys of ``key_remap`` missing from ``data`` are
        ignored; renamed entries are re-inserted at the end of ``data``.

        Args:
            data: Dictionary to modify in place.
            key_remap: Mapping of original key -> new key; ``None`` or empty
                means no renaming.

        Returns:
            dict: ``data`` itself, for convenience.
        """
        if not key_remap:
            return data
        moved = {}
        for old_key, new_key in key_remap.items():
            if old_key in data:
                moved[new_key] = data.pop(old_key)
        data.update(moved)
        return data

    def _load_scene_frame(
        self, scene_name: str, frame_name: str | float
    ) -> dict[str, Any]:
        """
        Load data for a specific frame from a specific scene.

        Scene data is loaded (and cached) if it is not already cached, then the
        frame is loaded with the modalities listed in ``cfg.frame_modalities``
        and its keys are renamed according to ``cfg.key_remap``.

        Args:
            scene_name (str): The name of the scene containing the frame.
            frame_name (str or float): The name/timestamp of the frame to load.

        Returns:
            dict: ``scene_name``, ``frame_name``, ``scene_path``, ``frame_idx``
            plus everything returned by ``load_frame`` (after key remapping).
        """
        scene_data = self._scene_cache.get(scene_name)
        if scene_data is None:
            scene_data = self._load_scene(scene_name)
            # For now only the last scene is cached, keeping memory bounded.
            self._scene_cache = {scene_name: scene_data}

        scene_path = Path(self.root, scene_name)
        scene_frame_data: dict[str, Any] = {
            "scene_name": scene_name,
            "frame_name": frame_name,
            "scene_path": str(scene_path),
            "frame_idx": get_frame_index(scene_data["meta"], frame_name),
        }
        scene_frame_data.update(
            load_frame(
                scene_path,
                frame_name,
                modalities=self.cfg.frame_modalities,
                scene_meta=scene_data["meta"],
            )
        )
        return self._remap_keys(scene_frame_data, self.cfg.get("key_remap", {}))

    def __getitem__(self, index: int) -> dict[str, Any]:
        """
        Get a specific scene-frame pair by index.

        Args:
            index (int): The index of the scene-frame pair to retrieve.

        Returns:
            dict: A dictionary containing the loaded frame data with requested modalities.
        """
        scene_name, frame_name = self.scene_frame_list[index]
        return self._load_scene_frame(scene_name, frame_name)