Spaces:
Runtime error
Runtime error
| import pytorch_lightning as pl | |
| import torch | |
| import os | |
| from PIL import Image | |
| from torch.utils.data import DataLoader, Dataset, random_split | |
| from torchvision import transforms | |
| from functools import partial | |
class CelebADataset(Dataset):
    """Image-folder dataset yielding resized, normalized RGB tensors.

    Every regular file directly inside ``data_dir`` is one sample. File names
    are sorted, so index ``i`` refers to the same image on every machine. This
    keeps seeded ``random_split`` calls on this dataset reproducible.

    Args:
        data_dir: Directory whose files are the samples.
        img_dim: Side length every image is resized to. Images are resized to
            exactly ``(img_dim, img_dim)``, so the aspect ratio is not kept.
    """

    def __init__(
        self,
        data_dir: str,
        img_dim: int = 64
    ):
        self.data_dir = data_dir
        # os.listdir order is filesystem-dependent; a sorted list gives a
        # stable index -> file mapping.
        self.list_path = self._list_images(data_dir)
        self.transform = transforms.Compose(
            [
                transforms.Resize((img_dim, img_dim)),
                transforms.ToTensor(),
                # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] to [-1, 1].
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ]
        )

    @staticmethod
    def _list_images(data_dir: str) -> list:
        """Return the sorted names of the regular files directly inside *data_dir*."""
        return sorted(
            name
            for name in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, name))
        )

    def __len__(self) -> int:
        """Return the number of image files found at construction time."""
        return len(self.list_path)

    def __getitem__(self, index):
        """Load image *index*, force 3 channels and apply the transform."""
        path = os.path.join(self.data_dir, self.list_path[index])
        # The context manager closes the file handle promptly. convert("RGB")
        # loads the pixels and guarantees the 3 channels Normalize expects,
        # even for grayscale / RGBA / palette files.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))
class CelebADataModule(pl.LightningDataModule):
    """Lightning data module serving train/validation splits of a CelebA folder.

    The image folder is split with ``random_split`` into three parts:
    ``train_ratio`` for training, ``val_ratio`` for validation, and an unused
    remainder.

    Args:
        data_dir: Directory of images, passed to ``CelebADataset``.
        batch_size: Batch size for both loaders.
        num_workers: DataLoader worker processes; 0 loads in the main process.
        seed: Seed of the ``torch.Generator`` driving the split.
        train_ratio: Training fraction. It is capped at ``1 - val_ratio`` so
            the validation share always fits.
        img_dim: Side length images are resized to.
        val_ratio: Validation fraction (default 0.01, i.e. 1%).
    """

    def __init__(
        self,
        data_dir: str = "./",
        batch_size: int = 32,
        num_workers: int = 0,
        seed: int = 42,
        train_ratio: float = 0.99,
        img_dim: int = 64,
        val_ratio: float = 0.01,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_ratio = val_ratio
        self.train_ratio = min(train_ratio, 1 - val_ratio)
        self.img_dim = img_dim
        self.seed = seed

    @property
    def loader(self):
        """DataLoader factory reflecting the current ``batch_size``/``num_workers``.

        It is built on access rather than frozen in ``__init__``, so later
        changes to those attributes take effect (for example, a batch-size
        finder that sets ``batch_size`` on the datamodule).
        """
        return partial(
            DataLoader,
            batch_size=self.batch_size,
            pin_memory=True,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
        )

    def setup(self, stage: str):
        """Build the train/val splits for the stages that use them.

        Runs for ``"fit"`` and ``"validate"`` (both need ``val_dataloader``).
        Any other stage is a no-op.
        """
        if stage not in ("fit", "validate"):
            return
        dataset = CelebADataset(self.data_dir, self.img_dim)
        # The leftover share is discarded. Clamp float round-off so
        # random_split never receives a tiny negative fraction.
        remainder = max(0.0, 1 - self.val_ratio - self.train_ratio)
        self.CelebA_train, self.CelebA_val, _ = random_split(
            dataset=dataset,
            lengths=[self.train_ratio, self.val_ratio, remainder],
            generator=torch.Generator().manual_seed(self.seed),
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training split (requires ``setup``)."""
        # Reshuffle each epoch so batches are not drawn in a fixed order.
        return self.loader(dataset=self.CelebA_train, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split (requires ``setup``)."""
        return self.loader(dataset=self.CelebA_val)