Spaces:
Sleeping
Sleeping
Commit
·
1de9461
1
Parent(s):
49b098d
fix: Changed default dataset image type
Browse files- src/dataset.py +5 -4
src/dataset.py
CHANGED
|
@@ -2,10 +2,11 @@
|
|
| 2 |
# coding: utf-8
|
| 3 |
import gzip
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
from torch.utils.data import Dataset
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def load_mnist(download_dir):
|
|
@@ -37,7 +38,7 @@ class DatasetMNIST(Dataset):
|
|
| 37 |
def __getitem__(self, n):
|
| 38 |
if n > self.total:
|
| 39 |
raise ValueError(f"Dataset doesn't have enough elements to suffice request of {n} elements.")
|
| 40 |
-
return self.data[n]
|
| 41 |
|
| 42 |
def __len__(self):
|
| 43 |
return len(self.data)
|
|
|
|
| 2 |
# coding: utf-8
|
| 3 |
import gzip
|
| 4 |
|
| 5 |
+
import torch
|
|
|
|
|
|
|
| 6 |
from torch.utils.data import Dataset
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from src.downloader import download_dataset
|
| 10 |
|
| 11 |
|
| 12 |
def load_mnist(download_dir):
|
|
|
|
| 38 |
def __getitem__(self, n):
|
| 39 |
if n > self.total:
|
| 40 |
raise ValueError(f"Dataset doesn't have enough elements to suffice request of {n} elements.")
|
| 41 |
+
return torch.tensor(self.data[n][0].reshape(1, 28, 28), dtype=torch.float32), torch.tensor(self.data[n][1])
|
| 42 |
|
| 43 |
def __len__(self):
|
| 44 |
return len(self.data)
|