"""
Base dataset class that enables easy resizing and combining
References: DUSt3R
"""
import numpy as np
from mapanything.datasets.base.batched_sampler import (
BatchedMultiFeatureRandomSampler,
DynamicBatchedMultiFeatureRandomSampler,
)
class EasyDataset:
"""
Dataset that can be easily resized and combined.
Examples:
---------
2 * dataset ==> Duplicate each element 2x
10 @ dataset ==> Set the size to 10 (random sampling, duplicates if necessary)
Dataset1 + Dataset2 ==> Concatenate datasets
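    These operators compose; a hypothetical combination such as
        1_000 @ (2 * dataset_a + dataset_b)
    duplicates dataset_a, concatenates it with dataset_b, and randomly resamples the result to 1_000 elements.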
"""
def __add__(self, other):
"""
Concatenate this dataset with another dataset.
Args:
other (EasyDataset): Another dataset to concatenate with this one
Returns:
CatDataset: A new dataset that is the concatenation of this dataset and the other
"""
return CatDataset([self, other])
def __rmul__(self, factor):
"""
Multiply the dataset by a factor, duplicating each element.
Args:
factor (int): Number of times to duplicate each element
Returns:
MulDataset: A new dataset with each element duplicated 'factor' times
"""
return MulDataset(factor, self)
def __rmatmul__(self, factor):
"""
Resize the dataset to a specific size using random sampling.
Args:
factor (int): The new size of the dataset
Returns:
ResizedDataset: A new dataset with the specified size
"""
return ResizedDataset(factor, self)
def set_epoch(self, epoch):
"""
Set the current epoch for all constituent datasets.
Args:
epoch (int): The current epoch number
"""
pass # nothing to do by default
def make_sampler(
self,
batch_size=None,
shuffle=True,
world_size=1,
rank=0,
drop_last=True,
max_num_of_images_per_gpu=None,
use_dynamic_sampler=True,
):
"""
Create a sampler for this dataset.
Args:
batch_size (int, optional): Number of samples per batch (used for non-dynamic sampler). Defaults to None.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
world_size (int, optional): Number of distributed processes. Defaults to 1.
rank (int, optional): Rank of the current process. Defaults to 0.
drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
max_num_of_images_per_gpu (int, optional): Maximum number of images per GPU for dynamic batching. Defaults to None.
use_dynamic_sampler (bool, optional): Whether to use the dynamic sampler. Defaults to True.
Returns:
DynamicBatchedMultiFeatureRandomSampler or BatchedMultiFeatureRandomSampler: A sampler for this dataset
Raises:
NotImplementedError: If shuffle is False
ValueError: If num_views has an invalid type or required parameters are missing
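        Example (hypothetical values; exact wiring into a DataLoader depends on the training loop):
            sampler = dataset.make_sampler(max_num_of_images_per_gpu=8, world_size=2, rank=0)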
"""
        if not shuffle:
            raise NotImplementedError()  # non-shuffled sampling is not supported yet
if isinstance(self.num_views, int):
num_of_aspect_ratios = len(self._resolutions)
feature_pool_sizes = [num_of_aspect_ratios]
scaling_feature_idx = 0 # Use aspect ratio as scaling feature
elif isinstance(self.num_views, list):
num_of_aspect_ratios = len(self._resolutions)
num_of_num_views = len(self.num_views)
feature_pool_sizes = [num_of_aspect_ratios, num_of_num_views]
scaling_feature_idx = 1 # Use num_views as scaling feature
else:
raise ValueError(
f"Bad type for {self.num_views=}, should be int or list of ints"
)
if use_dynamic_sampler:
if max_num_of_images_per_gpu is None:
raise ValueError(
"max_num_of_images_per_gpu must be provided when using dynamic sampler"
)
# Create feature-to-batch-size mapping
if isinstance(self.num_views, list):
                # Map each num_views_idx to a per-GPU batch size: max(1, max_num_of_images_per_gpu // num_views)
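                # (e.g., a hypothetical max_num_of_images_per_gpu=24 with num_views of 2/4/8
                # yields per-feature batch sizes of 12/6/3 multi-view sets, respectively)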
feature_to_batch_size_map = {}
for num_views_idx, num_views in enumerate(self.num_views):
batch_size_for_multi_view_sets = max(
1, max_num_of_images_per_gpu // num_views
)
feature_to_batch_size_map[num_views_idx] = (
batch_size_for_multi_view_sets
)
else:
# For fixed num_views, use a simple mapping
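                # (e.g., a hypothetical max_num_of_images_per_gpu=16 with num_views=4
                # gives a batch size of 16 // 4 = 4 multi-view sets per GPU)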
feature_to_batch_size_map = {
0: max(1, max_num_of_images_per_gpu // self.num_views)
}
return DynamicBatchedMultiFeatureRandomSampler(
self,
pool_sizes=feature_pool_sizes,
scaling_feature_idx=scaling_feature_idx,
feature_to_batch_size_map=feature_to_batch_size_map,
world_size=world_size,
rank=rank,
drop_last=drop_last,
)
else:
if batch_size is None:
raise ValueError(
"batch_size must be provided when not using dynamic sampler"
)
return BatchedMultiFeatureRandomSampler(
self,
batch_size,
feature_pool_sizes,
world_size=world_size,
rank=rank,
drop_last=drop_last,
)
class MulDataset(EasyDataset):
"""Artifically augmenting the size of a dataset."""
multiplicator: int
def __init__(self, multiplicator, dataset):
"""
Initialize a dataset that artificially augments the size of another dataset.
Args:
multiplicator (int): Factor by which to multiply the dataset size
dataset (EasyDataset): The dataset to augment
"""
assert isinstance(multiplicator, int) and multiplicator > 0
self.multiplicator = multiplicator
self.dataset = dataset
def __len__(self):
"""
Get the length of the dataset.
Returns:
int: The number of samples in the dataset
"""
return self.multiplicator * len(self.dataset)
def __repr__(self):
"""
Get a string representation of the dataset.
Returns:
str: String representation showing the multiplication factor and the original dataset
"""
return f"{self.multiplicator}*{repr(self.dataset)}"
def __getitem__(self, idx):
"""
Get an item from the dataset.
Args:
idx: Index or tuple of indices to retrieve
Returns:
The item at the specified index from the original dataset
"""
if isinstance(idx, tuple):
other = idx[1:]
idx = idx[0]
new_idx = (idx // self.multiplicator, *other)
return self.dataset[new_idx]
else:
return self.dataset[idx // self.multiplicator]
@property
def _resolutions(self):
"""
Get the resolutions of the dataset.
Returns:
The resolutions from the original dataset
"""
return self.dataset._resolutions
@property
def num_views(self):
"""
Get the number of views used for the dataset.
Returns:
int or list: The number of views parameter from the original dataset
"""
return self.dataset.num_views
class ResizedDataset(EasyDataset):
"""Artifically changing the size of a dataset."""
new_size: int
def __init__(self, new_size, dataset):
"""
Initialize a dataset with an artificially changed size.
Args:
new_size (int): The new size of the dataset
dataset (EasyDataset): The original dataset
"""
assert isinstance(new_size, int) and new_size > 0
self.new_size = new_size
self.dataset = dataset
def __len__(self):
"""
Get the length of the dataset.
Returns:
int: The new size of the dataset
"""
return self.new_size
def __repr__(self):
"""
Get a string representation of the dataset.
Returns:
str: String representation showing the new size and the original dataset
"""
size_str = str(self.new_size)
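        # Insert '_' as a thousands separator, e.g. 1000000 -> "1_000_000"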
for i in range((len(size_str) - 1) // 3):
sep = -4 * i - 3
size_str = size_str[:sep] + "_" + size_str[sep:]
return f"{size_str} @ {repr(self.dataset)}"
def set_epoch(self, epoch):
"""
Set the current epoch and generate a new random mapping of indices.
This method must be called before using __getitem__.
Args:
epoch (int): The current epoch number
"""
# This random shuffle only depends on the epoch
rng = np.random.default_rng(seed=epoch + 777)
# Shuffle all indices
perm = rng.permutation(len(self.dataset))
# Calculate how many repetitions we need
num_repetitions = 1 + (len(self) - 1) // len(self.dataset)
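        # (e.g., a hypothetical new_size=25 with len(self.dataset)=10 gives
        # num_repetitions=3, so the permutation is tiled 3x and truncated to 25)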
        # Tile the shuffled permutation until the target size is covered
shuffled_idxs = np.concatenate([perm] * num_repetitions)
self._idxs_mapping = shuffled_idxs[: self.new_size]
# Generate the seed offset for each repetition
# This is needed to ensure we see unique samples when we repeat a scene
seed_offset_per_repetition = [
np.full(len(self.dataset), i) for i in range(num_repetitions)
]
seed_offset_idxs = np.concatenate(seed_offset_per_repetition)
self._idxs_seed_offset = seed_offset_idxs[: self.new_size]
assert len(self._idxs_mapping) == self.new_size
assert len(self._idxs_seed_offset) == self.new_size
def __getitem__(self, idx):
"""
Get an item from the dataset.
Args:
idx: Index or tuple of indices to retrieve
Returns:
The item at the mapped index from the original dataset
Raises:
AssertionError: If set_epoch has not been called
"""
assert hasattr(self, "_idxs_mapping"), (
"You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
)
if isinstance(idx, tuple):
other = idx[1:]
idx = idx[0]
self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
new_idx = (self._idxs_mapping[idx], *other)
return self.dataset[new_idx]
else:
self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
return self.dataset[self._idxs_mapping[idx]]
@property
def _resolutions(self):
"""
Get the resolutions of the dataset.
Returns:
The resolutions from the original dataset
"""
return self.dataset._resolutions
@property
def num_views(self):
"""
Get the number of views used for the dataset.
Returns:
int or list: The number of views parameter from the original dataset
"""
return self.dataset.num_views
class CatDataset(EasyDataset):
"""Concatenation of several datasets"""
def __init__(self, datasets):
"""
Initialize a dataset that is a concatenation of several datasets.
Args:
datasets (list): List of EasyDataset instances to concatenate
"""
for dataset in datasets:
assert isinstance(dataset, EasyDataset)
self.datasets = datasets
self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
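        # (e.g., hypothetical dataset lengths [5, 7, 3] give cumulative sizes [5, 12, 15])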
def __len__(self):
"""
Get the length of the concatenated dataset.
Returns:
int: Total number of samples across all datasets
"""
return self._cum_sizes[-1]
def __repr__(self):
"""
Get a string representation of the concatenated dataset.
Returns:
str: String representation showing all concatenated datasets joined by '+'
"""
        # Strip the verbose default transform description from each dataset's repr
return " + ".join(
repr(dataset).replace(
",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
"",
)
for dataset in self.datasets
)
def set_epoch(self, epoch):
"""
Set the current epoch for all constituent datasets.
Args:
epoch (int): The current epoch number
"""
for dataset in self.datasets:
dataset.set_epoch(epoch)
def __getitem__(self, idx):
"""
Get an item from the concatenated dataset.
Args:
idx: Index or tuple of indices to retrieve
Returns:
The item at the specified index from the appropriate constituent dataset
Raises:
IndexError: If the index is out of range
"""
other = None
if isinstance(idx, tuple):
other = idx[1:]
idx = idx[0]
if not (0 <= idx < len(self)):
raise IndexError()
db_idx = np.searchsorted(self._cum_sizes, idx, "right")
dataset = self.datasets[db_idx]
new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
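        # (e.g., with hypothetical _cum_sizes=[5, 12], a global idx of 7 maps to db_idx=1, local idx 2)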
if other is not None:
new_idx = (new_idx, *other)
return dataset[new_idx]
@property
def _resolutions(self):
"""
Get the resolutions of the dataset.
Returns:
The resolutions from the first dataset (all datasets must have the same resolutions)
Raises:
AssertionError: If datasets have different resolutions
"""
resolutions = self.datasets[0]._resolutions
for dataset in self.datasets[1:]:
assert tuple(dataset._resolutions) == tuple(resolutions), (
"All datasets must have the same resolutions"
)
return resolutions
@property
def num_views(self):
"""
Get the number of views used for the dataset.
Returns:
int or list: The number of views parameter from the first dataset
Raises:
AssertionError: If datasets have different num_views
"""
num_views = self.datasets[0].num_views
for dataset in self.datasets[1:]:
assert dataset.num_views == num_views, (
"All datasets must have the same num_views and variable_num_views parameters"
)
return num_views