Spaces: Runtime error
import time
from functools import partial

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving MNIST resized to ``img_dim`` x ``img_dim``.

    Images are resized, converted to tensors and normalized with mean 0.5 /
    std 0.5. The MNIST training split is divided into train / val / discarded
    subsets with a seeded generator; the MNIST test split backs
    ``test_dataloader``.

    Args:
        data_dir: Root directory passed to ``MNIST`` (downloaded there if missing).
        batch_size: Batch size used by every dataloader.
        num_workers: Worker processes per dataloader (0 = load in the main process).
        seed: Seed for the train/val split so it is reproducible.
        train_ratio: Fraction of the MNIST training split used for training.
            Clamped to at most ``1 - 0.01`` because 0.01 is always reserved
            for validation; whatever remains is discarded.
        img_dim: Side length the images are resized to.
    """

    # Fraction of the MNIST training split reserved for validation.
    _VAL_RATIO = 0.01
    # Number of attempts to construct/download a split before the error is
    # re-raised, and the base delay (seconds) that grows with each retry.
    download_attempts = 5
    download_retry_delay = 2.0

    def __init__(
        self,
        data_dir: str = "./",
        batch_size: int = 32,
        num_workers: int = 0,
        seed: int = 42,
        train_ratio: float = 0.99,
        img_dim: int = 32,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Leave at least _VAL_RATIO of the training split for validation.
        self.train_ratio = min(train_ratio, 1 - self._VAL_RATIO)
        self.seed = seed
        self.transform = transforms.Compose(
            [
                transforms.Resize((img_dim, img_dim)),
                transforms.ToTensor(),
                # One-element tuples: one (grayscale) channel's mean/std.
                transforms.Normalize(mean=(0.5,), std=(0.5,)),
            ]
        )

    def loader(self, dataset, **kwargs) -> DataLoader:
        """Build a ``DataLoader`` over *dataset* from the module's current settings.

        Settings are read at call time rather than captured in ``__init__``,
        so later changes to ``batch_size`` / ``num_workers`` take effect.
        Extra keyword arguments override the defaults below.
        """
        options = dict(
            batch_size=self.batch_size,
            # Pinned host memory only speeds up copies to a CUDA device.
            pin_memory=torch.cuda.is_available(),
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True with
            # num_workers == 0, which is this module's default.
            persistent_workers=self.num_workers > 0,
        )
        options.update(kwargs)
        return DataLoader(dataset, **options)

    def _load_mnist(self, train: bool) -> MNIST:
        """Construct one MNIST split (downloading if needed) with bounded retries.

        Construction may fail transiently (e.g. a network error while
        downloading). Failed attempts are retried after a growing delay; the
        last failure is re-raised instead of looping forever.
        """
        for attempt in range(1, self.download_attempts + 1):
            try:
                return MNIST(
                    root=self.data_dir,
                    train=train,
                    transform=self.transform,
                    download=True,
                )
            except Exception:
                # Not a bare except: KeyboardInterrupt/SystemExit must still
                # be able to stop the process.
                if attempt >= self.download_attempts:
                    raise
                time.sleep(self.download_retry_delay * attempt)
        # Only reachable if download_attempts < 1.
        raise ValueError("download_attempts must be >= 1")

    def setup(self, stage: str):
        """Create the datasets needed for *stage*.

        ``"fit"`` and ``"validate"`` build the train/val subsets; any other
        stage builds the test set; ``None`` builds both.
        """
        if stage in ("fit", "validate", None):
            mnist_full = self._load_mnist(train=True)
            self.mnist_train, self.mnist_val, _ = random_split(
                dataset=mnist_full,
                # Fractional lengths; the third subset is the unused remainder.
                lengths=[
                    self.train_ratio,
                    self._VAL_RATIO,
                    1 - self._VAL_RATIO - self.train_ratio,
                ],
                generator=torch.Generator().manual_seed(self.seed),
            )
        if stage not in ("fit", "validate"):
            self.mnist_test = self._load_mnist(train=False)

    def train_dataloader(self):
        # Reshuffle training samples every epoch; val/test keep a fixed order.
        return self.loader(dataset=self.mnist_train, shuffle=True)

    def val_dataloader(self):
        return self.loader(dataset=self.mnist_val)

    def test_dataloader(self):
        return self.loader(dataset=self.mnist_test)