File size: 4,388 Bytes
37de32d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from pathlib import Path
from typing import Any

import torch
from box import Box

from mapanything.utils.wai.core import get_frame_index, load_data, load_frame
from mapanything.utils.wai.ops import stack
from mapanything.utils.wai.scene_frame import get_scene_frame_names


class BasicSceneframeDataset(torch.utils.data.Dataset):
    """Basic wai dataset to iterative over frames of scenes"""

    @staticmethod
    def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
        return stack(batch)

    def __init__(
        self,
        cfg: Box,
    ):
        """
        Initialize the BasicSceneframeDataset.

        Args:
            cfg (Box): Configuration object containing dataset parameters including:
                - root: Root directory containing scene data
                - frame_modalities: List of modalities to load for each frame
                - key_remap: Optional dictionary mapping original keys to new keys
        """
        super().__init__()
        self.cfg = cfg
        self.root = cfg.root
        keyframes = cfg.get("use_keyframes", True)
        self.scene_frame_names = get_scene_frame_names(cfg, keyframes=keyframes)
        self.scene_frame_list = [
            (scene_name, frame_name)
            for scene_name, frame_names in self.scene_frame_names.items()
            for frame_name in frame_names
        ]
        self._scene_cache = {}

    def __len__(self):
        """
        Get the total number of scene-frame pairs in the dataset.

        Returns:
            int: The number of scene-frame pairs.
        """
        return len(self.scene_frame_list)

    def _load_scene(self, scene_name: str) -> dict[str, Any]:
        """
        Load scene data for a given scene name.

        Args:
            scene_name (str): The name of the scene to load.

        Returns:
            dict: A dictionary containing scene data, including scene metadata.
        """
        # load scene data
        scene_data = {}
        scene_data["meta"] = load_data(
            Path(
                self.root,
                scene_name,
                self.cfg.get("scene_meta_path", "scene_meta.json"),
            ),
            "scene_meta",
        )

        return scene_data

    def _load_scene_frame(
        self, scene_name: str, frame_name: str | float
    ) -> dict[str, Any]:
        """
        Load data for a specific frame from a specific scene.

        This method loads scene data if not already cached, then loads the specified frame
        from that scene with the modalities specified in the configuration.

        Args:
            scene_name (str): The name of the scene containing the frame.
            frame_name (str or float): The name/timestamp of the frame to load.

        Returns:
            dict: A dictionary containing the loaded frame data with requested modalities.
        """
        scene_frame_data = {}
        if not (scene_data := self._scene_cache.get(scene_name)):
            scene_data = self._load_scene(scene_name)
            # for now only cache the last scene
            self._scene_cache = {}
            self._scene_cache[scene_name] = scene_data

        frame_idx = get_frame_index(scene_data["meta"], frame_name)

        scene_frame_data["scene_name"] = scene_name
        scene_frame_data["frame_name"] = frame_name
        scene_frame_data["scene_path"] = str(Path(self.root, scene_name))
        scene_frame_data["frame_idx"] = frame_idx
        scene_frame_data.update(
            load_frame(
                Path(self.root, scene_name),
                frame_name,
                modalities=self.cfg.frame_modalities,
                scene_meta=scene_data["meta"],
            )
        )
        # Remap key names
        for key, new_key in self.cfg.get("key_remap", {}).items():
            if key in scene_frame_data:
                scene_frame_data[new_key] = scene_frame_data.pop(key)

        return scene_frame_data

    def __getitem__(self, index: int) -> dict[str, Any]:
        """
        Get a specific scene-frame pair by index.

        Args:
            index (int): The index of the scene-frame pair to retrieve.

        Returns:
            dict: A dictionary containing the loaded frame data with requested modalities.
        """
        scene_frame = self._load_scene_frame(*self.scene_frame_list[index])
        return scene_frame