Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| import gzip | |
| import torch | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
| from src.downloader import download_dataset | |
def download_mnist(download_dir):
    """Download MNIST into ``download_dir`` and return the split file paths.

    Args:
        download_dir: Directory handed to ``download_dataset``. It may be given
            with or without a trailing path separator, as a string or as an
            ``os.PathLike`` object.

    Returns:
        A dict with the keys ``"train"`` and ``"test"``; each value is an
        ``(images_path, labels_path)`` tuple of strings located inside
        ``download_dir``.
    """
    # Local import keeps this function self-contained without touching the
    # file-level import block.
    import os

    download_dataset("mnist", download_dir)

    def _inside(name):
        # os.path.join only inserts a separator when one is missing, so the
        # result is unchanged for callers that already pass "dir/".
        return os.path.join(download_dir, name)

    return {
        "train": (_inside("train_images"), _inside("train_labels")),
        "test": (_inside("test_images"), _inside("test_labels")),
    }
class _IndexOutOfRange(IndexError, ValueError):
    """Raised by ``DatasetMNIST.__getitem__`` for an index outside the dataset.

    It derives from ``IndexError`` so Python's fallback sequence iteration
    (``for sample in dataset``) terminates cleanly, and from ``ValueError`` so
    callers that caught the previously raised ``ValueError`` keep working.
    """


class DatasetMNIST(Dataset):
    """Map-style dataset over gzip-compressed image and label files.

    The image file is expected to start with a 16-byte header (4 ignored
    bytes, then big-endian item count, row count and column count) followed by
    one unsigned byte per pixel. The label file is expected to start with an
    8-byte header followed by one unsigned byte per label.

    Attributes:
        total: Number of images announced by the image file header.
        images: ``uint8`` array of shape ``(total, rows, columns)``.
        labels: 1-D ``uint8`` array of labels.
        data: List of ``(image, label)`` pairs; if the two files disagree in
            length it is truncated to the shorter one.
    """

    def __init__(self, images, labels):
        """Load all images and labels into memory.

        Args:
            images: Path to the gzip-compressed image file.
            labels: Path to the gzip-compressed label file.
        """
        with gzip.open(images, "rb") as f:
            f.read(4)  # leading header field is not needed here
            self.total = int.from_bytes(f.read(4), "big")
            rows = int.from_bytes(f.read(4), "big")
            columns = int.from_bytes(f.read(4), "big")
            pixel_bytes = f.read()
        self.images = np.frombuffer(pixel_bytes, dtype=np.uint8).reshape(
            (self.total, rows, columns)
        )

        with gzip.open(labels, "rb") as f:
            f.read(8)  # skip the label file header
            label_bytes = f.read()
        self.labels = np.frombuffer(label_bytes, dtype=np.uint8)

        self.data = list(zip(self.images, self.labels))

    def __getitem__(self, n):
        """Return sample ``n`` as an ``(image, label)`` pair of tensors.

        Args:
            n: Sample index. Negative values count from the end.

        Returns:
            A ``float32`` tensor of shape ``(1, rows, columns)`` and a 0-d
            tensor holding the label.

        Raises:
            IndexError: If ``n`` is outside the dataset. The raised exception
                is also a ``ValueError`` for backward compatibility.
        """
        size = len(self.data)
        if not -size <= n < size:
            raise _IndexOutOfRange(
                f"Dataset doesn't have enough elements to suffice request of {n} elements."
            )
        image, label = self.data[n]
        # Prepend the channel axis using the image's real shape rather than a
        # hard-coded 28x28, so files with other dimensions load correctly.
        return (
            torch.tensor(image.reshape(1, *image.shape), dtype=torch.float32),
            torch.tensor(label),
        )

    def __len__(self):
        """Return the number of ``(image, label)`` pairs available."""
        return len(self.data)
if __name__ == "__main__":
    # Demo: download MNIST, load the training split and display one sample.
    import matplotlib.pyplot as plt

    download_dir = "downloads/mnist/"
    mnist = download_mnist(download_dir)
    dataset = DatasetMNIST(*mnist["train"])

    X, y = dataset[4]
    # The sample carries a leading channel axis (1, H, W); imshow only accepts
    # (H, W) or (H, W, 3|4), so drop that axis before plotting.
    plt.imshow(X.squeeze(0).numpy(), cmap="gray")
    # .item() gives the plain integer label instead of the tensor repr.
    plt.title(label="Annotated label: " + str(y.item()))
    plt.show()