Spaces:
Running
Running
| from typing import Callable, List, Optional, Union | |
| import torch | |
| from pytorch_lightning import LightningDataModule | |
| from torch.utils.data import DataLoader, Dataset | |
def _identity(sample):
    """Return *sample* unchanged; the default transform of ``SampleDataset``.

    This is a module-level function rather than a lambda so that a dataset
    using the default transform can be pickled (lambdas cannot be pickled).
    """
    return sample


class SampleDataset(Dataset):
    """Map-style dataset that pairs each input ``x[i]`` with its target ``y[i]``.

    Args:
        x: Indexable collection of inputs (list or tensor).
        y: Indexable collection of targets. It is indexed with the same
            indices as ``x``; the two lengths are not checked against
            each other.
        transforms: Optional callable applied to every input when it is
            fetched. Targets are returned untouched. ``None`` means the
            inputs are returned as stored.
    """

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None) -> None:
        super().__init__()
        self.x = x
        self.y = y
        if transforms is None:
            # Replace None with some default transforms.
            # If image, could be a Resize and ToTensor.
            self.transforms = _identity
        else:
            self.transforms = transforms

    def __len__(self) -> int:
        """Return the number of samples, taken from the length of ``x``."""
        return len(self.x)

    def __getitem__(self, index: int):
        """Return the ``(transformed input, target)`` pair stored at ``index``."""
        x = self.x[index]
        y = self.y[index]
        x = self.transforms(x)
        return x, y
class SampleDataModule(LightningDataModule):
    """Data module that splits in-memory ``x``/``y`` into train and validation sets.

    The split is positional and unshuffled: the first
    ``n - int(n * val_ratio)`` samples form the training set and the
    remaining samples form the validation set. If that leaves no validation
    samples (for example with the default ``val_ratio=0``), the last
    ``batch_size`` samples of the data are reused as the validation set so
    that ``val_dataloader`` always has a dataset to serve.

    Args:
        x: Indexable, sliceable collection of inputs (list or tensor).
        y: Indexable, sliceable collection of targets.
        transforms: Optional callable forwarded to both datasets and applied
            to each input when it is fetched.
        val_ratio: Fraction of the samples held out for validation; must
            satisfy ``0 <= val_ratio < 1``.
        batch_size: Positive batch size used by both dataloaders.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1)`` or ``batch_size``
            is not positive.
        TypeError: If ``batch_size`` is not an ``int``.
    """

    def __init__(self,
                 x: Union[List, torch.Tensor],
                 y: Union[List, torch.Tensor],
                 transforms: Optional[Callable] = None,
                 val_ratio: float = 0,
                 batch_size: int = 32) -> None:
        super().__init__()
        # Explicit exceptions instead of ``assert``: assertions are removed
        # when Python runs with -O, which would let invalid values through.
        if not 0 <= val_ratio < 1:
            raise ValueError(
                f"val_ratio must satisfy 0 <= val_ratio < 1, got {val_ratio!r}")
        if not isinstance(batch_size, int):
            raise TypeError(
                f"batch_size must be an int, got {type(batch_size).__name__}")
        if batch_size <= 0:
            # A zero batch size would make the fallback slice x[-0:] select
            # the whole dataset and only fail later inside DataLoader.
            raise ValueError(
                f"batch_size must be positive, got {batch_size!r}")
        self.x = x
        self.y = y
        self.transforms = transforms
        self.val_ratio = val_ratio
        self.batch_size = batch_size
        # Both hooks are invoked here, in the original order, so that the
        # datasets are available immediately after construction.
        self.setup()
        self.prepare_data()

    def setup(self, stage: Optional[str] = None) -> None:
        """Build ``train_dataset`` and ``val_dataset`` from the stored data.

        The split lives here rather than in ``prepare_data`` because
        Lightning's documentation recommends assigning state in ``setup``.
        ``stage`` is accepted for hook compatibility and ignored: the same
        split is produced for every stage.
        """
        n_samples: int = len(self.x)
        train_size: int = n_samples - int(n_samples * self.val_ratio)
        self.train_dataset = SampleDataset(x=self.x[:train_size],
                                           y=self.y[:train_size],
                                           transforms=self.transforms)
        if train_size < n_samples:
            self.val_dataset = SampleDataset(x=self.x[train_size:],
                                             y=self.y[train_size:],
                                             transforms=self.transforms)
        else:
            # No samples were held out: fall back to the last batch_size
            # samples (these overlap with the training set).
            self.val_dataset = SampleDataset(x=self.x[-self.batch_size:],
                                             y=self.y[-self.batch_size:],
                                             transforms=self.transforms)

    def prepare_data(self) -> None:
        """Do nothing; the data is already in memory and is split in ``setup``."""
        pass

    def train_dataloader(self) -> DataLoader:
        """Return a dataloader over ``train_dataset`` using ``batch_size``."""
        return DataLoader(dataset=self.train_dataset,
                          batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        """Return a dataloader over ``val_dataset`` using ``batch_size``."""
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size)