Spaces:
Running
Running
| import os | |
| import pandas as pd | |
| from PIL import Image, ImageFile | |
| from torch.utils.data import Dataset | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
| from .transforms import task_transform | |
| class TaskonomyDataset(Dataset): | |
| def __init__(self, | |
| data_root, | |
| tasks, | |
| split='train', | |
| variant='tiny', | |
| image_size=256, | |
| max_images=None): | |
| """ | |
| Taskonomy dataloader. | |
| Args: | |
| data_root: Root of Taskonomy data directory | |
| tasks: List of tasks. Any of ['rgb', 'depth_euclidean', 'depth_zbuffer', | |
| 'edge_occlusion', 'edge_texture', 'keypoints2d', 'keypoints3d', 'normal', | |
| 'principal_curvature', 'reshading', 'mask_valid']. | |
| split: One of {'train', 'val', 'test'} | |
| variant: One of {'debug', 'tiny', 'medium', 'full', 'fullplus'} | |
| image_size: Target image size | |
| max_images: Optional subset selection | |
| """ | |
| super(TaskonomyDataset, self).__init__() | |
| self.data_root = data_root | |
| self.tasks = tasks | |
| self.split = split | |
| self.variant = variant | |
| self.image_size=image_size | |
| self.max_images = max_images | |
| self.image_ids = pd.read_csv( | |
| os.path.join(os.path.dirname(__file__), 'splits', f'{self.variant}_{self.split}.csv') | |
| ).to_numpy() | |
| if isinstance(self.max_images, int): | |
| self.image_ids = self.image_ids[:self.max_images] | |
| print(f'Initialized TaskonomyDataset with {len(self.image_ids)} images from variant {self.variant} in split {self.split}.') | |
| def __len__(self): | |
| return len(self.image_ids) | |
| def __getitem__(self, index): | |
| # building / point / view | |
| building, point, view = self.image_ids[index] | |
| result = {} | |
| for task in self.tasks: | |
| task_id = 'depth_zbuffer' if task == 'mask_valid' else task | |
| path = os.path.join( | |
| self.data_root, task, building, f'point_{point}_view_{view}_domain_{task_id}.png' | |
| ) | |
| img = Image.open(path) | |
| # Perform transformations | |
| img = task_transform(img, task=task, image_size=self.image_size) | |
| result[task] = img | |
| return result | |