"""
Base dataset class that enables easy resizing and combining of datasets

References: DUSt3R
"""

import numpy as np

from mapanything.datasets.base.batched_sampler import (
    BatchedMultiFeatureRandomSampler,
    DynamicBatchedMultiFeatureRandomSampler,
)


class EasyDataset:
    """
    Dataset that can be easily resized and combined.

    Examples:
    ---------
        2 * dataset ==> Duplicate each element 2x

        10 @ dataset ==> Set the size to 10 (random sampling, duplicates if necessary)

        Dataset1 + Dataset2 ==> Concatenate datasets
    """

    def __add__(self, other):
        """
        Concatenate this dataset with another dataset.

        Args:
            other (EasyDataset): Another dataset to concatenate with this one

        Returns:
            CatDataset: A new dataset that is the concatenation of this dataset and the other
        """
        return CatDataset([self, other])

    def __rmul__(self, factor):
        """
        Multiply the dataset by a factor, duplicating each element.

        Args:
            factor (int): Number of times to duplicate each element

        Returns:
            MulDataset: A new dataset with each element duplicated 'factor' times
        """
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        """
        Resize the dataset to a specific size using random sampling.

        Args:
            factor (int): The new size of the dataset

        Returns:
            ResizedDataset: A new dataset with the specified size
        """
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        """
        Set the current epoch for all constituent datasets.

        Args:
            epoch (int): The current epoch number
        """
        pass  # nothing to do by default

    def make_sampler(
        self,
        batch_size=None,
        shuffle=True,
        world_size=1,
        rank=0,
        drop_last=True,
        max_num_of_images_per_gpu=None,
        use_dynamic_sampler=True,
    ):
        """
        Create a sampler for this dataset.

        Args:
            batch_size (int, optional): Number of samples per batch (used for non-dynamic sampler). Defaults to None.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
            world_size (int, optional): Number of distributed processes. Defaults to 1.
            rank (int, optional): Rank of the current process. Defaults to 0.
            drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
            max_num_of_images_per_gpu (int, optional): Maximum number of images per GPU for dynamic batching. Defaults to None.
            use_dynamic_sampler (bool, optional): Whether to use the dynamic sampler. Defaults to True.

        Returns:
            DynamicBatchedMultiFeatureRandomSampler or BatchedMultiFeatureRandomSampler: A sampler for this dataset

        Raises:
            NotImplementedError: If shuffle is False
            ValueError: If num_views has an invalid type or required parameters are missing
        """
        if not shuffle:
            raise NotImplementedError()  # non-shuffled sampling is not supported yet

        if isinstance(self.num_views, int):
            num_of_aspect_ratios = len(self._resolutions)
            feature_pool_sizes = [num_of_aspect_ratios]
            scaling_feature_idx = 0  # Use aspect ratio as scaling feature
        elif isinstance(self.num_views, list):
            num_of_aspect_ratios = len(self._resolutions)
            num_of_num_views = len(self.num_views)
            feature_pool_sizes = [num_of_aspect_ratios, num_of_num_views]
            scaling_feature_idx = 1  # Use num_views as scaling feature
        else:
            raise ValueError(
                f"Bad type for {self.num_views=}, should be int or list of ints"
            )

        if use_dynamic_sampler:
            if max_num_of_images_per_gpu is None:
                raise ValueError(
                    "max_num_of_images_per_gpu must be provided when using dynamic sampler"
                )

            # Create feature-to-batch-size mapping
            if isinstance(self.num_views, list):
                # Map each num_views index to a batch size:
                # max(1, max_num_of_images_per_gpu // num_views)
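                # Worked example (assumed config): max_num_of_images_per_gpu=16
                # with num_views=[2, 4, 8] yields {0: 8, 1: 4, 2: 2}, i.e. the
                # batch size shrinks as each multi-view set holds more images.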
                feature_to_batch_size_map = {}
                for num_views_idx, num_views in enumerate(self.num_views):
                    batch_size_for_multi_view_sets = max(
                        1, max_num_of_images_per_gpu // num_views
                    )
                    feature_to_batch_size_map[num_views_idx] = (
                        batch_size_for_multi_view_sets
                    )
            else:
                # For fixed num_views, use a simple mapping
                feature_to_batch_size_map = {
                    0: max(1, max_num_of_images_per_gpu // self.num_views)
                }

            return DynamicBatchedMultiFeatureRandomSampler(
                self,
                pool_sizes=feature_pool_sizes,
                scaling_feature_idx=scaling_feature_idx,
                feature_to_batch_size_map=feature_to_batch_size_map,
                world_size=world_size,
                rank=rank,
                drop_last=drop_last,
            )
        else:
            if batch_size is None:
                raise ValueError(
                    "batch_size must be provided when not using dynamic sampler"
                )

            return BatchedMultiFeatureRandomSampler(
                self,
                batch_size,
                feature_pool_sizes,
                world_size=world_size,
                rank=rank,
                drop_last=drop_last,
            )


class MulDataset(EasyDataset):
    """Artifically augmenting the size of a dataset."""

    multiplicator: int

    def __init__(self, multiplicator, dataset):
        """
        Initialize a dataset that artificially augments the size of another dataset.

        Args:
            multiplicator (int): Factor by which to multiply the dataset size
            dataset (EasyDataset): The dataset to augment
        """
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.multiplicator = multiplicator
        self.dataset = dataset

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: The number of samples in the dataset
        """
        return self.multiplicator * len(self.dataset)

    def __repr__(self):
        """
        Get a string representation of the dataset.

        Returns:
            str: String representation showing the multiplication factor and the original dataset
        """
        return f"{self.multiplicator}*{repr(self.dataset)}"

    def __getitem__(self, idx):
        """
        Get an item from the dataset.

        Args:
            idx: Index or tuple of indices to retrieve

        Returns:
            The item at the specified index from the original dataset
        """
        if isinstance(idx, tuple):
            other = idx[1:]
            idx = idx[0]
            new_idx = (idx // self.multiplicator, *other)
            return self.dataset[new_idx]
        else:
            return self.dataset[idx // self.multiplicator]

    @property
    def _resolutions(self):
        """
        Get the resolutions of the dataset.

        Returns:
            The resolutions from the original dataset
        """
        return self.dataset._resolutions

    @property
    def num_views(self):
        """
        Get the number of views used for the dataset.

        Returns:
            int or list: The number of views parameter from the original dataset
        """
        return self.dataset.num_views


class ResizedDataset(EasyDataset):
    """Artifically changing the size of a dataset."""

    new_size: int

    def __init__(self, new_size, dataset):
        """
        Initialize a dataset with an artificially changed size.

        Args:
            new_size (int): The new size of the dataset
            dataset (EasyDataset): The original dataset
        """
        assert isinstance(new_size, int) and new_size > 0
        self.new_size = new_size
        self.dataset = dataset

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: The new size of the dataset
        """
        return self.new_size

    def __repr__(self):
        """
        Get a string representation of the dataset.

        Returns:
            str: String representation showing the new size and the original dataset
        """
        size_str = str(self.new_size)
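        # Insert "_" as a thousands separator, e.g. 1000000 -> "1_000_000"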
        for i in range((len(size_str) - 1) // 3):
            sep = -4 * i - 3
            size_str = size_str[:sep] + "_" + size_str[sep:]
        return f"{size_str} @ {repr(self.dataset)}"

    def set_epoch(self, epoch):
        """
        Set the current epoch and generate a new random mapping of indices.

        This method must be called before using __getitem__.

        Args:
            epoch (int): The current epoch number
        """
        # This random shuffle only depends on the epoch
        rng = np.random.default_rng(seed=epoch + 777)

        # Shuffle all indices
        perm = rng.permutation(len(self.dataset))

        # Calculate how many repetitions we need
        num_repetitions = 1 + (len(self) - 1) // len(self.dataset)

        # Tile the shuffled permutation, then truncate to the target size
        shuffled_idxs = np.concatenate([perm] * num_repetitions)
        self._idxs_mapping = shuffled_idxs[: self.new_size]

        # Generate the seed offset for each repetition
        # This is needed to ensure we see unique samples when we repeat a scene
        seed_offset_per_repetition = [
            np.full(len(self.dataset), i) for i in range(num_repetitions)
        ]
        seed_offset_idxs = np.concatenate(seed_offset_per_repetition)
        self._idxs_seed_offset = seed_offset_idxs[: self.new_size]
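        # Worked example: len(dataset)=3 and new_size=7 gives num_repetitions=3,
        # an index mapping of perm tiled and truncated to 7 entries, and seed
        # offsets [0, 0, 0, 1, 1, 1, 2], so repeats of a scene draw differently.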

        assert len(self._idxs_mapping) == self.new_size
        assert len(self._idxs_seed_offset) == self.new_size

    def __getitem__(self, idx):
        """
        Get an item from the dataset.

        Args:
            idx: Index or tuple of indices to retrieve

        Returns:
            The item at the mapped index from the original dataset

        Raises:
            AssertionError: If set_epoch has not been called
        """
        assert hasattr(self, "_idxs_mapping"), (
            "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
        )
        if isinstance(idx, tuple):
            other = idx[1:]
            idx = idx[0]
            self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
            new_idx = (self._idxs_mapping[idx], *other)
            return self.dataset[new_idx]
        else:
            self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
            return self.dataset[self._idxs_mapping[idx]]

    @property
    def _resolutions(self):
        """
        Get the resolutions of the dataset.

        Returns:
            The resolutions from the original dataset
        """
        return self.dataset._resolutions

    @property
    def num_views(self):
        """
        Get the number of views used for the dataset.

        Returns:
            int or list: The number of views parameter from the original dataset
        """
        return self.dataset.num_views


class CatDataset(EasyDataset):
    """Concatenation of several datasets"""

    def __init__(self, datasets):
        """
        Initialize a dataset that is a concatenation of several datasets.

        Args:
            datasets (list): List of EasyDataset instances to concatenate
        """
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        """
        Get the length of the concatenated dataset.

        Returns:
            int: Total number of samples across all datasets
        """
        return self._cum_sizes[-1]

    def __repr__(self):
        """
        Get a string representation of the concatenated dataset.

        Returns:
            str: String representation showing all concatenated datasets joined by '+'
        """
        # Strip the verbose default transform string from each repr for readability
        return " + ".join(
            repr(dataset).replace(
                ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
                "",
            )
            for dataset in self.datasets
        )

    def set_epoch(self, epoch):
        """
        Set the current epoch for all constituent datasets.

        Args:
            epoch (int): The current epoch number
        """
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        """
        Get an item from the concatenated dataset.

        Args:
            idx: Index or tuple of indices to retrieve

        Returns:
            The item at the specified index from the appropriate constituent dataset

        Raises:
            IndexError: If the index is out of range
        """
        other = None
        if isinstance(idx, tuple):
            other = idx[1:]
            idx = idx[0]

        if not (0 <= idx < len(self)):
            raise IndexError()

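        # Locate the owning dataset and its local index, e.g. with
        # _cum_sizes = [10, 22]: global idx 15 -> db_idx 1, new_idx 5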
        db_idx = np.searchsorted(self._cum_sizes, idx, "right")
        dataset = self.datasets[db_idx]
        new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)

        if other is not None:
            new_idx = (new_idx, *other)
        return dataset[new_idx]

    @property
    def _resolutions(self):
        """
        Get the resolutions of the dataset.

        Returns:
            The resolutions from the first dataset (all datasets must have the same resolutions)

        Raises:
            AssertionError: If datasets have different resolutions
        """
        resolutions = self.datasets[0]._resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(resolutions), (
                "All datasets must have the same resolutions"
            )
        return resolutions

    @property
    def num_views(self):
        """
        Get the number of views used for the dataset.

        Returns:
            int or list: The number of views parameter from the first dataset

        Raises:
            AssertionError: If datasets have different num_views
        """
        num_views = self.datasets[0].num_views
        for dataset in self.datasets[1:]:
            assert dataset.num_views == num_views, (
                "All datasets must have the same num_views parameter"
            )
        return num_views
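

if __name__ == "__main__":
    # Minimal self-test sketch. ToyDataset is a hypothetical stand-in (not part
    # of this module) implementing just enough of the EasyDataset contract to
    # exercise the operator overloads above.
    class ToyDataset(EasyDataset):
        _resolutions = [(224, 224)]
        num_views = 2

        def __init__(self, size):
            self.size = size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            # Accept both plain and tuple indices, like the real datasets
            return idx[0] if isinstance(idx, tuple) else idx

        def _set_seed_offset(self, offset):
            pass  # real datasets vary their sampling across repetitions here

    combined = 2 * ToyDataset(5) + 12 @ ToyDataset(3)  # 10 + 12 = 22 samples
    combined.set_epoch(0)  # required before indexing into the resized part
    assert len(combined) == 22
    assert 0 <= combined[15] < 3  # routed to the 12 @ ToyDataset(3) part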