| """ | |
| Utilities for random sampling under a single or multiple constraints | |
| References: DUSt3R | |
| """ | |
| import numpy as np | |
| import torch | |
def round_by(total, multiple, up=False):
    """
    Round a number down (or up) to a multiple of another number.

    Args:
        total (int): The number to round
        multiple (int): The multiple to round to
        up (bool, optional): Whether to round up instead of down. Defaults to False.

    Returns:
        int: The rounded number
    """
    if up:
        total = total + multiple - 1
    return (total // multiple) * multiple
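# Quick illustrative check of the rounding helper (added as a sketch, not part
# of the original module): with multiple=4, 10 rounds down to 8 and up to 12.
if __name__ == "__main__":
    assert round_by(10, 4) == 8
    assert round_by(10, 4, up=True) == 12
    assert round_by(12, 4) == round_by(12, 4, up=True) == 12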
class BatchedRandomSampler:
    """
    Random sampling under a constraint: each sample in the batch has the same feature,
    which is chosen randomly from a known pool of 'features' for each batch.

    For instance, the 'feature' could be the image aspect-ratio.

    The index returned is a tuple (sample_idx, feat_idx).
    This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
    """

    def __init__(
        self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True
    ):
        """
        Args:
            dataset: Dataset to sample from
            batch_size: Number of samples per batch
            pool_size: Integer representing the size of the feature pool
            world_size: Number of distributed processes
            rank: Rank of the current process
            drop_last: Whether to drop the last incomplete batch
        """
        self.batch_size = batch_size
        self.pool_size = pool_size

        self.len_dataset = N = len(dataset)
        self.total_size = round_by(N, batch_size * world_size) if drop_last else N
        assert world_size == 1 or drop_last, (
            "must drop the last batch in distributed mode"
        )

        # Distributed sampler
        self.world_size = world_size
        self.rank = rank
        self.epoch = None

    def __len__(self):
        """
        Get the length of the sampler.

        Returns:
            int: The number of samples in the sampler for the current process
        """
        return self.total_size // self.world_size

    def set_epoch(self, epoch):
        """
        Set the epoch for this sampler.
        This should be called before each epoch to ensure proper shuffling of the data.

        Args:
            epoch (int): The current epoch number
        """
        self.epoch = epoch

    def __iter__(self):
        """
        Iterate over the indices.
        This method generates random indices for each batch, ensuring that all samples
        within a batch share the same feature index drawn from the feature pool.

        Yields:
            tuple: A tuple containing (sample_idx, feat_idx)
        """
        # Prepare RNG
        if self.epoch is None:
            assert self.world_size == 1 and self.rank == 0, (
                "use set_epoch() if distributed mode is used"
            )
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        else:
            seed = self.epoch + 777
        rng = np.random.default_rng(seed=seed)

        # Shuffled sample indices (total_size <= len(dataset), so no wrap-around is needed)
        sample_idxs = np.arange(self.total_size)
        rng.shuffle(sample_idxs)

        # Random feat_idxs (same across each batch)
        n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
        feat_idxs = rng.integers(self.pool_size, size=n_batches)
        feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
        feat_idxs = feat_idxs.ravel()[: self.total_size]

        # Put them together
        idxs = np.c_[sample_idxs, feat_idxs]  # shape = (total_size, 2)

        # Distributed sampler: we select a subset of batches
        # Make sure the slice for each process is aligned with batch_size
        size_per_proc = self.batch_size * (
            (self.total_size + self.world_size * self.batch_size - 1)
            // (self.world_size * self.batch_size)
        )
        idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]

        yield from (tuple(idx) for idx in idxs)
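# Illustrative usage sketch (an addition for clarity, not original code): the toy
# list below stands in for a real dataset whose __getitem__ accepts a
# (sample_idx, feat_idx) tuple; every run of `batch_size` consecutive indices
# shares the same feat_idx.
if __name__ == "__main__":
    _toy_dataset = list(range(10))
    _sampler = BatchedRandomSampler(_toy_dataset, batch_size=4, pool_size=3)
    for _sample_idx, _feat_idx in _sampler:
        print(_sample_idx, _feat_idx)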
class BatchedMultiFeatureRandomSampler:
    """
    Random sampling under multiple constraints: each sample in the batch has the same features,
    which are chosen randomly from known pools of 'features' for each batch.

    For instance, the 'features' could be the image aspect-ratio and the scene type.

    The index returned is a tuple (sample_idx, feat_idx_1, feat_idx_2, ...).
    This sampler ensures that each series of `batch_size` indices has the same feature indices.
    """

    def __init__(
        self, dataset, batch_size, pool_sizes, world_size=1, rank=0, drop_last=True
    ):
        """
        Args:
            dataset: Dataset to sample from
            batch_size: Number of samples per batch
            pool_sizes: List of integers representing the size of each feature pool
            world_size: Number of distributed processes
            rank: Rank of the current process
            drop_last: Whether to drop the last incomplete batch
        """
        self.batch_size = batch_size
        self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]

        self.len_dataset = N = len(dataset)
        self.total_size = round_by(N, batch_size * world_size) if drop_last else N
        assert world_size == 1 or drop_last, (
            "must drop the last batch in distributed mode"
        )

        # Distributed sampler
        self.world_size = world_size
        self.rank = rank
        self.epoch = None

    def __len__(self):
        """
        Get the length of the sampler.

        Returns:
            int: The number of samples in the sampler for the current process
        """
        return self.total_size // self.world_size

    def set_epoch(self, epoch):
        """
        Set the epoch for this sampler.
        This should be called before each epoch to ensure proper shuffling of the data.

        Args:
            epoch (int): The current epoch number
        """
        self.epoch = epoch

    def __iter__(self):
        """
        Iterate over the indices.
        This method generates random indices for each batch, ensuring that all samples
        within a batch share the same feature indices across all feature pools.

        Yields:
            tuple: A tuple containing (sample_idx, feat_idx_1, feat_idx_2, ...)
        """
        # Prepare RNG
        if self.epoch is None:
            assert self.world_size == 1 and self.rank == 0, (
                "use set_epoch() if distributed mode is used"
            )
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        else:
            seed = self.epoch + 777
        rng = np.random.default_rng(seed=seed)

        # Shuffled sample indices (total_size <= len(dataset), so no wrap-around is needed)
        sample_idxs = np.arange(self.total_size)
        rng.shuffle(sample_idxs)

        # Random feat_idxs (same across each batch)
        n_batches = (self.total_size + self.batch_size - 1) // self.batch_size

        # Generate feature indices for each feature pool
        all_feat_idxs = []
        for pool_size in self.pool_sizes:
            feat_idxs = rng.integers(pool_size, size=n_batches)
            feat_idxs = np.broadcast_to(
                feat_idxs[:, None], (n_batches, self.batch_size)
            )
            feat_idxs = feat_idxs.ravel()[: self.total_size]
            all_feat_idxs.append(feat_idxs)

        # Put them together
        idxs = np.column_stack(
            [sample_idxs] + all_feat_idxs
        )  # shape = (total_size, 1 + len(pool_sizes))

        # Distributed sampler: we select a subset of batches
        # Make sure the slice for each process is aligned with batch_size
        size_per_proc = self.batch_size * (
            (self.total_size + self.world_size * self.batch_size - 1)
            // (self.world_size * self.batch_size)
        )
        idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]

        yield from (tuple(idx) for idx in idxs)
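# Illustrative usage sketch (an addition, with made-up numbers): two feature
# pools, e.g. an aspect-ratio pool of size 3 and a scene-type pool of size 2;
# every batch of 4 shares one index drawn from each pool.
if __name__ == "__main__":
    _toy_dataset = list(range(16))
    _multi_sampler = BatchedMultiFeatureRandomSampler(
        _toy_dataset, batch_size=4, pool_sizes=[3, 2]
    )
    for _sample_idx, _feat_idx_1, _feat_idx_2 in _multi_sampler:
        print(_sample_idx, _feat_idx_1, _feat_idx_2)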
class DynamicBatchedMultiFeatureRandomSampler:
    """
    Random sampling under multiple constraints with a dynamic batch size:
    each sample in the batch has the same features, which are chosen randomly
    from known pools of 'features' for each batch.

    The batch size is determined on the fly from a designated feature index,
    using a direct mapping from feature values to batch sizes.

    For instance, if one of the features is the number of images in a multi-view set,
    you can specify different batch sizes for different numbers of images to optimize
    GPU memory usage. This is achieved by using the feature_to_batch_size_map parameter
    to directly specify which batch size to use for each feature value.

    The returned index is a list of tuples [(sample_idx, feat_idx_1, feat_idx_2, ...), ...].
    """

    def __init__(
        self,
        dataset,
        pool_sizes,
        scaling_feature_idx=0,
        feature_to_batch_size_map=None,
        world_size=1,
        rank=0,
        drop_last=True,
    ):
        """
        Args:
            dataset: Dataset to sample from
            pool_sizes: List of integers representing the size of each feature pool
            scaling_feature_idx: Index of the feature used to determine the batch size
                (0-based index into pool_sizes)
            feature_to_batch_size_map: Optional function or dict that maps feature values directly
                to batch sizes. For example, if the feature represents the number of views,
                this maps the number of views to a batch size that fits in GPU memory.
                If None, a default batch size of 1 is used for all feature values.
            world_size: Number of distributed processes
            rank: Rank of the current process
            drop_last: Whether to drop the last incomplete batch
        """
        self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
        self.scaling_feature_idx = scaling_feature_idx

        # Ensure scaling_feature_idx is valid
        if scaling_feature_idx < 0 or scaling_feature_idx >= len(self.pool_sizes):
            raise ValueError(
                f"scaling_feature_idx must be between 0 and {len(self.pool_sizes) - 1}"
            )

        # Set up the mapping from feature values to batch sizes
        self.feature_to_batch_size_map = feature_to_batch_size_map
        if self.feature_to_batch_size_map is None:
            # Default: batch size of 1 for all feature values
            self.feature_to_batch_size_map = {
                i: 1 for i in range(self.pool_sizes[scaling_feature_idx])
            }

        self.len_dataset = N = len(dataset)
        # The exact batch sizes are not known in advance, so the whole dataset is used
        # as the pool of indices; the effective count is determined during iteration
        self.total_size = N

        # Distributed sampler
        self.world_size = world_size
        self.rank = rank
        self.epoch = None
        self.drop_last = drop_last

    def __len__(self):
        """
        Get the approximate length of the sampler.
        Since the batch size varies, this is an estimate based on the largest batch size
        in the mapping, which yields a lower bound on the number of batches.

        Returns:
            int: The estimated minimum number of batches yielded for the current process
        """
        # Find the largest batch size in the mapping
        if callable(self.feature_to_batch_size_map):
            # If it's a function, evaluate it on every feature value to find the maximum
            batch_sizes = [
                self.feature_to_batch_size_map(i)
                for i in range(self.pool_sizes[self.scaling_feature_idx])
            ]
            max_batch_size = max(batch_sizes)
        else:
            # If it's a dict or similar, take the maximum directly
            max_batch_size = max(self.feature_to_batch_size_map.values())

        # Ensure a minimum batch size of 1
        max_batch_size = max(1, max_batch_size)

        # Estimate the total number of batches using the largest batch size
        # This gives a lower bound on the number of batches
        total_batches = self.total_size // max_batch_size
        if not self.drop_last and self.total_size % max_batch_size > 0:
            total_batches += 1

        # Distribute among processes
        return total_batches // self.world_size
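    # Worked example with illustrative numbers (not from the original code):
    # with 32 samples, a largest mapped batch size of 8, drop_last=True and
    # world_size=2, __len__ returns (32 // 8) // 2 == 2 batches per process,
    # even though smaller actual batches could have produced more.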
    def set_epoch(self, epoch):
        """
        Set the epoch for this sampler.
        This should be called before each epoch to ensure proper shuffling of the data.

        Args:
            epoch (int): The current epoch number
        """
        self.epoch = epoch

    def __iter__(self):
        """
        Iterate over the indices with dynamic batch sizes.
        This method generates random indices for each batch, ensuring that all samples
        within a batch share the same feature indices across all feature pools.
        The batch size is looked up directly in feature_to_batch_size_map.
        The iterator enforces the length returned by __len__() by stopping after
        exactly that many batches have been yielded for this process.

        Yields:
            list of tuples: A batch of tuples, each containing (sample_idx, feat_idx_1, feat_idx_2, ...)
        """
        # Prepare RNG
        if self.epoch is None:
            assert self.world_size == 1 and self.rank == 0, (
                "use set_epoch() if distributed mode is used"
            )
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        else:
            seed = self.epoch + 777
        rng = np.random.default_rng(seed=seed)

        # Random indices for the entire dataset
        sample_idxs = np.arange(self.total_size)
        rng.shuffle(sample_idxs)

        # Get the target number of batches for this process (enforce strict length)
        target_batches_for_process = len(self)
        batches_yielded_for_process = 0

        # Process indices in batches with dynamic sizing
        idx = 0
        batch_idx = 0  # Track the batch index for round-robin distribution across processes
        while idx < len(sample_idxs) and (
            batches_yielded_for_process < target_batches_for_process
        ):
            # Randomly select feature indices for this batch
            feat_idxs = [rng.integers(pool_size) for pool_size in self.pool_sizes]

            # Get the scaling feature value
            scaling_feat = feat_idxs[self.scaling_feature_idx]

            # Get the batch size directly from the mapping
            if callable(self.feature_to_batch_size_map):
                batch_size = self.feature_to_batch_size_map(scaling_feat)
            else:
                batch_size = self.feature_to_batch_size_map.get(scaling_feat, 1)

            # Ensure a minimum batch size of 1
            batch_size = max(1, batch_size)

            # Ensure we don't go beyond the available samples
            remaining = len(sample_idxs) - idx
            if remaining < batch_size:
                if self.drop_last:
                    break
                batch_size = remaining

            # Create the batch with consistent feature indices
            batch = []
            for i in range(batch_size):
                if idx + i < len(sample_idxs):
                    sample_idx = sample_idxs[idx + i]
                    batch.append(tuple([sample_idx] + feat_idxs))

            # Distribute batches among processes in round-robin fashion
            if len(batch) > 0 and (batch_idx % self.world_size == self.rank):
                yield batch
                batches_yielded_for_process += 1

            batch_idx += 1  # Increment the batch index
            idx += batch_size
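# Illustrative usage sketch (an addition with made-up values): the scaling
# feature stands for a view-count bucket, and the hypothetical mapping below
# gives larger view counts smaller batches so they fit in GPU memory. Each
# yielded item is a full batch of (sample_idx, feat_idx_1, feat_idx_2) tuples.
if __name__ == "__main__":
    _toy_dataset = list(range(32))
    _dyn_sampler = DynamicBatchedMultiFeatureRandomSampler(
        _toy_dataset,
        pool_sizes=[4, 3],  # e.g. 4 view-count buckets and 3 aspect-ratios
        scaling_feature_idx=0,  # the first feature controls the batch size
        feature_to_batch_size_map={0: 8, 1: 4, 2: 2, 3: 1},
    )
    for _batch in _dyn_sampler:
        print(len(_batch), _batch[0])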